|
|
@@ -399,6 +399,59 @@ __global__ void pooling2d_device_template_nchwc(const int8_t* __restrict__ src, |
|
|
|
*(reinterpret_cast<ldg_type*>(g_dst_ptr)) = res; |
|
|
|
} |
|
|
|
|
|
|
|
/*!
 * \brief 2D pooling kernel for NHWC-ordered packed data.
 *
 * Each thread produces a single `Pooler::feed_type` worth of output, i.e. one
 * packed channel group of one (batch, oh, ow) output position. The flat
 * thread id is decomposed, slowest to fastest, as
 * (batch, oh, ow, channel-group).
 *
 * \tparam Pooler reduction policy; must expose `feed_type`, a
 *         `(window_size, zero_point)` constructor, `init()`, `feed()` and
 *         `get_ans()`.
 * \tparam pack_size number of channel elements handled per thread.
 * \tparam pack_byte number of bytes occupied by `pack_size` elements.
 * \tparam ldg_width_assert expected width of `Pooler::feed_type` in 32-bit
 *         words, checked at compile time.
 *
 * \param src input tensor, addressed in bytes.
 * \param dst output tensor, addressed in bytes.
 * \param param pooling geometry (sizes, window, stride, padding).
 * \param zero_point forwarded unchanged to the Pooler constructor.
 */
template <typename Pooler, int pack_size, int pack_byte,
          int ldg_width_assert = 4>
__global__ void pooling2d_device_template_nhwc(const int8_t* __restrict__ src,
                                               int8_t* __restrict__ dst,
                                               Param param, int zero_point) {
    using ldg_type = typename Pooler::feed_type;
    static int constexpr ldg_width = sizeof(ldg_type) / sizeof(int32_t);
    static int constexpr ldg_width_bytes = sizeof(ldg_type);
    MEGDNN_STATIC_ASSERT(
            ldg_width == ldg_width_assert,
            "pooling2d (NHWC) kernel must ldg_width == ldg_width_assert");

    const int tid = blockIdx.x * blockDim.x + threadIdx.x;

    // Split the flat id into (batch, oh, ow, sec); sec is the index of the
    // packed channel group inside one pixel.
    const int c_packed = param.c / pack_size;
    const int per_row = param.wo * c_packed;
    const int per_image = param.ho * per_row;
    const int batch = tid / per_image;
    const int in_image = tid % per_image;
    const int oh = in_image / per_row;
    const int in_row = in_image % per_row;
    const int ow = in_row / c_packed;
    const int sec = in_row % c_packed;

    // Threads of the last block that fall outside the output do nothing.
    if (!(batch < param.n && oh < param.ho && ow < param.wo))
        return;

    // All strides below are expressed in bytes.
    const int w_stride = param.c * pack_byte / pack_size;
    const int in_batch_stride =
            param.hi * param.wi * param.c * pack_byte / pack_size;
    const int out_batch_stride =
            param.ho * param.wo * param.c * pack_byte / pack_size;

    const int sec_offset = sec * ldg_width_bytes;
    const int8_t* __restrict__ g_src_ptr =
            src + (batch * in_batch_stride + sec_offset);
    int8_t* __restrict__ g_dst_ptr =
            dst + (batch * out_batch_stride + (oh * param.wo + ow) * w_stride +
                   sec_offset);

    Pooler pooler(param.window_h * param.window_w, zero_point);
    pooler.init();

    for (int fh = 0; fh < param.window_h; ++fh) {
        // Held as unsigned so that a single `<` comparison rejects both ends
        // of the range: a coordinate that lands in the leading padding wraps
        // to a large value (assuming the Param fields are signed ints).
        const uint32_t ih = oh * param.sh + fh - param.ph;
        if (ih >= param.hi)
            continue;  // whole window row lies in the padding
        for (int fw = 0; fw < param.window_w; ++fw) {
            const uint32_t iw = ow * param.sw + fw - param.pw;
            if (iw >= param.wi)
                continue;  // padding column contributes nothing
            const int8_t* __restrict__ cur_src_ptr =
                    g_src_ptr + (ih * param.wi + iw) * w_stride;
            pooler.feed(__ldg(reinterpret_cast<const ldg_type*>(cur_src_ptr)));
        }
    }

    *(reinterpret_cast<ldg_type*>(g_dst_ptr)) = pooler.get_ans();
}
|
|
|
|
|
|
|
}; // namespace |
|
|
|
|
|
|
|
void megdnn::cuda::pooling2d::do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, |
|
|
@@ -588,4 +641,68 @@ void megdnn::cuda::pooling2d::do_pooling2d_int4_ncdiv64hw64( |
|
|
|
after_kernel_launch(); |
|
|
|
} |
|
|
|
|
|
|
|
/*!
 * \brief Select and launch the NHWC pooling kernel for 4-bit quantized data.
 *
 * Every thread loads/stores 4 bytes (8 packed 4-bit elements), so the channel
 * count must be a multiple of 8.
 *
 * \param d_src device pointer to the input tensor.
 * \param d_dst device pointer to the output tensor.
 * \param param pooling geometry; `param.c % 8 == 0` is asserted.
 * \param stream CUDA stream the kernel is enqueued on.
 * \param mode one of `Mode::MAX`, `Mode::AVERAGE`,
 *        `Mode::AVERAGE_COUNT_EXCLUDE_PADDING`; any other value trips an
 *        assertion.
 * \param uint_case true selects the `dt_quint4` poolers, false the
 *        `dt_qint4` ones.
 * \param zero_point forwarded unchanged to the kernel.
 */
void megdnn::cuda::pooling2d::do_pooling2d_int4_nhwc(
        const int8_t* d_src, int8_t* d_dst, const Param& param,
        cudaStream_t stream, uint32_t mode, bool uint_case, int zero_point) {
    using Mode = megdnn::param_enumv::Pooling::Mode;
    // Initialised so that no path can ever observe an indeterminate pointer.
    void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param,
                 int zero_point) = nullptr;

    megdnn_assert(param.c % 8 == 0);
    constexpr int ldg_byte = 4;
    constexpr int elem_per_byte = 2;
    constexpr int ldg_width_assert = 1;
    constexpr int pack_size = ldg_byte * elem_per_byte;
    constexpr int pack_byte = pack_size / elem_per_byte;
    constexpr int elem_per_thread = ldg_byte * elem_per_byte;
    // One thread per packed channel group of every output position. The
    // channel count is divided first: this is exact because of the assertion
    // above, mirrors the kernel's own `c / pack_size` decomposition, and
    // keeps the intermediate product 8x smaller than multiplying everything
    // out before dividing.
    uint32_t vthreads =
            param.n * (param.c / elem_per_thread) * param.ho * param.wo;
    if (uint_case) {
        switch (mode) {
            case Mode::MAX:
                kern = pooling2d_device_template_nhwc<
                        MaxPooler<dt_quint4, int32_t>, pack_size, pack_byte,
                        ldg_width_assert>;
                break;
            case Mode::AVERAGE:
                kern = pooling2d_device_template_nhwc<
                        MeanIncludeRoundedPooler<dt_quint4, int32_t, int32_t>,
                        pack_size, pack_byte, ldg_width_assert>;
                break;
            case Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
                kern = pooling2d_device_template_nhwc<
                        MeanExcludeRoundedPooler<dt_quint4, int32_t, int32_t>,
                        pack_size, pack_byte, ldg_width_assert>;
                break;
            default:
                megdnn_assert(false, "invalid pooling mode");
        }

    } else {
        switch (mode) {
            case Mode::MAX:
                kern = pooling2d_device_template_nhwc<
                        MaxPooler<dt_qint4, int32_t>, pack_size, pack_byte,
                        ldg_width_assert>;
                break;
            case Mode::AVERAGE:
                kern = pooling2d_device_template_nhwc<
                        MeanIncludeRoundedPooler<dt_qint4, int32_t, int32_t>,
                        pack_size, pack_byte, ldg_width_assert>;
                break;
            case Mode::AVERAGE_COUNT_EXCLUDE_PADDING:
                kern = pooling2d_device_template_nhwc<
                        MeanExcludeRoundedPooler<dt_qint4, int32_t, int32_t>,
                        pack_size, pack_byte, ldg_width_assert>;
                break;
            default:
                megdnn_assert(false, "invalid pooling mode");
        }
    }
    // Empty output: nothing to compute. Without this guard nr_threads would
    // be clamped to 0 below and DIVUP would divide by zero. The check sits
    // after kernel selection so an invalid mode is still reported.
    if (vthreads == 0)
        return;
    uint32_t nr_threads = query_blocksize_for_kernel(kern);
    nr_threads = std::min(nr_threads, vthreads);
    uint32_t nr_blocks = DIVUP(vthreads, nr_threads);
    kern<<<nr_blocks, nr_threads, 0, stream>>>(d_src, d_dst, param, zero_point);
    after_kernel_launch();
}
|
|
|
// vim: syntax=cuda.doxygen |