|
|
@@ -29,53 +29,51 @@ __device__ __forceinline__ int pack_int8_to_int8x4(int8_t x, int8_t y, int8_t z, |
|
|
|
return ix; |
|
|
|
} |
|
|
|
|
|
|
|
// Primary declaration of pack_int8: takes a reference to an array of `regs` |
// int8_t values and returns them packed as OutDtype. It is only declared |
// here; the explicit specializations that follow provide the bodies. |
// NOTE(review): two template headers are listed back to back. The first keys |
// the template on an element type (`Dtype`), the second on an integer bit |
// width (`dtype_bits`). They look like the before/after forms of one change; |
// only one of them can be in effect. |
template <int regs, typename Dtype, typename OutDtype> |
|
|
|
template <int regs, int dtype_bits, typename OutDtype> |
|
|
|
__device__ __forceinline__ OutDtype pack_int8(int8_t (&x)[regs]); |
|
|
|
|
|
|
|
// Specialization for 4 input values and an `int` result: forwards x[0]..x[3], |
// in index order, to pack_int8_to_int8x4 and returns its result. |
// NOTE(review): two signature rows are listed. One selects the |
// specialization with the type `int8_t`, the other with the bit width `8`. |
template <> |
|
|
|
__device__ __forceinline__ int pack_int8<4, int8_t, int>(int8_t (&x)[4]) { |
|
|
|
__device__ __forceinline__ int pack_int8<4, 8, int>(int8_t (&x)[4]) { |
|
|
|
return pack_int8_to_int8x4(x[0], x[1], x[2], x[3]); |
|
|
|
} |
|
|
|
|
|
|
|
// Specialization for 8 input values and an `int2` result. The input is split |
// into two 4-element arrays (x[0..3] and x[4..7]); each is packed with the |
// 4-element pack_int8 specialization and the two ints are combined with |
// ::make_int2, lower half first. |
// NOTE(review): the signature and the return statement each appear twice, |
// once spelled with the type `int8_t` and once with the bit width `8`. |
template <> |
|
|
|
__device__ __forceinline__ int2 pack_int8<8, int8_t, int2>(int8_t (&x)[8]) { |
|
|
|
__device__ __forceinline__ int2 pack_int8<8, 8, int2>(int8_t (&x)[8]) { |
|
|
|
// Local copies are needed because the 4-element overload takes int8_t(&)[4]. |
int8_t x0[4]{x[0], x[1], x[2], x[3]}; |
|
|
|
int8_t x1[4]{x[4], x[5], x[6], x[7]}; |
|
|
|
return ::make_int2(pack_int8<4, int8_t, int>(x0), |
|
|
|
pack_int8<4, int8_t, int>(x1)); |
|
|
|
return ::make_int2(pack_int8<4, 8, int>(x0), pack_int8<4, 8, int>(x1)); |
|
|
|
} |
|
|
|
|
|
|
|
// Specialization for 16 input values and an `int4` result. The input is cut |
// into four consecutive 4-element arrays; each is packed with the 4-element |
// pack_int8 specialization and the four ints are combined with ::make_int4 |
// in the order x[0..3], x[4..7], x[8..11], x[12..15]. |
// NOTE(review): the signature and the return statement each appear twice, |
// once spelled with the type `int8_t` and once with the bit width `8`. |
template <> |
|
|
|
__device__ __forceinline__ int4 pack_int8<16, int8_t, int4>(int8_t (&x)[16]) { |
|
|
|
__device__ __forceinline__ int4 pack_int8<16, 8, int4>(int8_t (&x)[16]) { |
|
|
|
int8_t x0[4]{x[0], x[1], x[2], x[3]}; |
|
|
|
int8_t x1[4]{x[4], x[5], x[6], x[7]}; |
|
|
|
int8_t x2[4]{x[8], x[9], x[10], x[11]}; |
|
|
|
int8_t x3[4]{x[12], x[13], x[14], x[15]}; |
|
|
|
return ::make_int4( |
|
|
|
pack_int8<4, int8_t, int>(x0), pack_int8<4, int8_t, int>(x1), |
|
|
|
pack_int8<4, int8_t, int>(x2), pack_int8<4, int8_t, int>(x3)); |
|
|
|
return ::make_int4(pack_int8<4, 8, int>(x0), pack_int8<4, 8, int>(x1), |
|
|
|
pack_int8<4, 8, int>(x2), pack_int8<4, 8, int>(x3)); |
|
|
|
} |
|
|
|
|
|
|
|
// Packs two 4-bit values into one byte: the low 4 bits of x0 go into the low |
// nibble and x1 is shifted into the high nibble. |
// Only x0 is masked; x1 is shifted as-is. The expression is evaluated as int |
// after integral promotion and narrowed to int8_t by the return type. |
__device__ __forceinline__ int8_t pack_int8_to_int4x2(int8_t x0, int8_t x1) { |
|
|
|
return (x0 & 0xf) | (x1 << 4); |
|
|
|
} |
|
|
|
// Specialization for 8 input values holding 4-bit data and an `int` result. |
// Adjacent inputs are paired, (x[0],x[1]) .. (x[6],x[7]), and each pair is |
// packed into one byte with pack_int8_to_int4x2 (even index in the low |
// nibble). The four bytes are then packed with pack_int8_to_int8x4. |
// NOTE(review): two signature rows are listed. One selects the |
// specialization with the type `dt_qint4`, the other with the bit width `4`. |
template <> |
|
|
|
__device__ __forceinline__ int pack_int8<8, dt_qint4, int>(int8_t (&x)[8]) { |
|
|
|
__device__ __forceinline__ int pack_int8<8, 4, int>(int8_t (&x)[8]) { |
|
|
|
int8_t x0 = pack_int8_to_int4x2(x[0], x[1]); |
|
|
|
int8_t x1 = pack_int8_to_int4x2(x[2], x[3]); |
|
|
|
int8_t x2 = pack_int8_to_int4x2(x[4], x[5]); |
|
|
|
int8_t x3 = pack_int8_to_int4x2(x[6], x[7]); |
|
|
|
return pack_int8_to_int8x4(x0, x1, x2, x3); |
|
|
|
} |
|
|
|
|
|
|
|
// Specialization for 32 input values holding 4-bit data and an `int4` |
// result. The input is cut into four consecutive 8-element arrays; each is |
// packed into one int with the 8-element 4-bit pack_int8 specialization and |
// the four ints are combined with ::make_int4 in index order. |
// NOTE(review): the signature and the return statement each appear twice, |
// once spelled with the type `dt_qint4` and once with the bit width `4`. |
template <> |
|
|
|
__device__ __forceinline__ int4 pack_int8<32, dt_qint4, int4>(int8_t (&x)[32]) { |
|
|
|
__device__ __forceinline__ int4 pack_int8<32, 4, int4>(int8_t (&x)[32]) { |
|
|
|
int8_t x0[8]{x[0], x[1], x[2], x[3], x[4], x[5], x[6], x[7]}; |
|
|
|
int8_t x1[8]{x[8], x[9], x[10], x[11], x[12], x[13], x[14], x[15]}; |
|
|
|
int8_t x2[8]{x[16], x[17], x[18], x[19], x[20], x[21], x[22], x[23]}; |
|
|
|
int8_t x3[8]{x[24], x[25], x[26], x[27], x[28], x[29], x[30], x[31]}; |
|
|
|
return ::make_int4( |
|
|
|
pack_int8<8, dt_qint4, int>(x0), pack_int8<8, dt_qint4, int>(x1), |
|
|
|
pack_int8<8, dt_qint4, int>(x2), pack_int8<8, dt_qint4, int>(x3)); |
|
|
|
return ::make_int4(pack_int8<8, 4, int>(x0), pack_int8<8, 4, int>(x1), |
|
|
|
pack_int8<8, 4, int>(x2), pack_int8<8, 4, int>(x3)); |
|
|
|
} |
|
|
|
|
|
|
|
template <typename Dtype> |
|
|
@@ -88,6 +86,7 @@ struct TypeTrait<int8_t> { |
|
|
|
static constexpr int8_t min = -128; |
|
|
|
static constexpr int elem_per_32bit = 32 / bit_width; |
|
|
|
static constexpr int shift_fix_sign = 0; |
|
|
|
static constexpr bool need_zero_pad = false; |
|
|
|
}; |
|
|
|
|
|
|
|
template <> |
|
|
@@ -97,6 +96,16 @@ struct TypeTrait<dt_qint4> { |
|
|
|
static constexpr int8_t min = -8; |
|
|
|
static constexpr int elem_per_32bit = 32 / bit_width; |
|
|
|
static constexpr int shift_fix_sign = 4; |
|
|
|
static constexpr bool need_zero_pad = false; |
|
|
|
}; |
|
|
|
// Compile-time constants for the dt_quint4 element type. |
template <> |
|
|
|
struct TypeTrait<dt_quint4> { |
|
|
|
// Width of one element in bits. |
static constexpr int bit_width = 4; |
|
|
|
// Bit mask covering one element (the low 4 bits). |
static constexpr int mask = 0xf; |
|
|
|
// Smallest value listed for this type. |
static constexpr int8_t min = 0; |
|
|
|
// Number of elements in one 32-bit word: 32 / 4 = 8. |
static constexpr int elem_per_32bit = 32 / bit_width; |
|
|
|
// 0 here, whereas the dt_qint4 trait in this file uses 4. |
// NOTE(review): presumably no sign-fixing shift is needed for unsigned |
// data; confirm at the use sites, which are not visible here. |
static constexpr int shift_fix_sign = 0; |
|
|
|
// Read by MeanIncludeRoundedPooler to turn on its zero_pad / real_fi_count |
// handling. |
static constexpr bool need_zero_pad = true; |
|
|
|
}; |
|
|
|
|
|
|
|
template <typename src_type, typename _feed_type> |
|
|
@@ -108,7 +117,7 @@ struct MaxPooler { |
|
|
|
static constexpr int shift_fix_sign = TypeTrait<src_type>::shift_fix_sign; |
|
|
|
int8_t res[nr_results]; |
|
|
|
|
|
|
|
__device__ MaxPooler(int) {} |
|
|
|
__device__ MaxPooler(int, int) {} |
|
|
|
__device__ __forceinline__ void init() { |
|
|
|
#pragma unroll |
|
|
|
for (int i = 0; i < nr_results; ++i) { |
|
|
@@ -137,7 +146,7 @@ struct MaxPooler { |
|
|
|
} |
|
|
|
// Returns the per-element results in `res`, packed into one feed_type value |
// by pack_int8. |
// NOTE(review): two assignment rows are listed. One passes `src_type` as the |
// second template argument, the other passes `bit_width`. |
__device__ __forceinline__ feed_type get_ans() { |
|
|
|
feed_type ans; |
|
|
|
ans = pack_int8<nr_results, src_type, feed_type>(res); |
|
|
|
ans = pack_int8<nr_results, bit_width, feed_type>(res); |
|
|
|
return ans; |
|
|
|
} |
|
|
|
}; |
|
|
@@ -149,21 +158,27 @@ struct MeanIncludeRoundedPooler { |
|
|
|
static constexpr int nr_results = sizeof(feed_type) * 8 / bit_width; |
|
|
|
static constexpr int elem_per_32bit = TypeTrait<src_type>::elem_per_32bit; |
|
|
|
static constexpr int shift_fix_sign = TypeTrait<src_type>::shift_fix_sign; |
|
|
|
static constexpr bool need_zero_pad = TypeTrait<src_type>::need_zero_pad; |
|
|
|
|
|
|
|
int32_t res[nr_results]; |
|
|
|
const int count; |
|
|
|
const float fi_count; |
|
|
|
int real_fi_count; |
|
|
|
const int zero_pad; |
|
|
|
|
|
|
|
// Constructors. Both store `count` and its reciprocal `fi_count`; the |
// two-argument form also stores `zero_point` in `zero_pad`. |
// NOTE(review): the one-argument form leaves the const member `zero_pad` |
// without an initializer. The two forms look like the before/after versions |
// of one change; confirm which one is in effect. |
__device__ MeanIncludeRoundedPooler(int count) |
|
|
|
: count{count}, fi_count{1.f / count} {} |
|
|
|
__device__ MeanIncludeRoundedPooler(int count, int zero_point) |
|
|
|
: count{count}, fi_count{1.f / count}, zero_pad{zero_point} {} |
|
|
|
|
|
|
|
// Resets the pooler: zeroes every accumulator in `res` and, when |
// need_zero_pad is set, resets the count of fed values (`real_fi_count`). |
__device__ __forceinline__ void init() { |
|
|
|
#pragma unroll |
|
|
|
for (int i = 0; i < nr_results; ++i) { |
|
|
|
res[i] = 0; |
|
|
|
} |
|
|
|
// real_fi_count is only written here when the element type asks for zero |
// padding; otherwise it is left untouched. |
if (need_zero_pad) { |
|
|
|
real_fi_count = 0; |
|
|
|
} |
|
|
|
} |
|
|
|
__device__ __forceinline__ void feed(int x, int idx = 0) { |
|
|
|
__device__ __forceinline__ void feed(int x, int idx) { |
|
|
|
constexpr int unroll_n = sizeof(int) * 8 / bit_width; |
|
|
|
#pragma unroll |
|
|
|
for (int i = 0; i < unroll_n; i++) { |
|
|
@@ -173,15 +188,27 @@ struct MeanIncludeRoundedPooler { |
|
|
|
res[idx + i] += static_cast<int32_t>(temp); |
|
|
|
} |
|
|
|
} |
|
|
|
// Feeds one 32-bit word: accumulates it via feed(x, 0), i.e. starting at |
// accumulator index 0, then counts the call in real_fi_count when |
// need_zero_pad is set. |
__device__ __forceinline__ void feed(int x) { |
|
|
|
feed(x, 0); |
|
|
|
// One increment per feed call, not per element. |
if (need_zero_pad) { |
|
|
|
real_fi_count++; |
|
|
|
} |
|
|
|
} |
|
|
|
// Feeds a two-word vector: x.x is accumulated starting at index 0 and x.y |
// starting at index elem_per_32bit, via the two-argument feed. When |
// need_zero_pad is set the call is counted once in real_fi_count. |
__device__ __forceinline__ void feed(int2 x) { |
|
|
|
feed(x.x, 0 * elem_per_32bit); |
|
|
|
feed(x.y, 1 * elem_per_32bit); |
|
|
|
// One increment per feed call, regardless of vector width. |
if (need_zero_pad) { |
|
|
|
real_fi_count++; |
|
|
|
} |
|
|
|
} |
|
|
|
// Feeds a four-word vector: x.x, x.y, x.z and x.w are accumulated starting |
// at indices 0, 1, 2 and 3 times elem_per_32bit, via the two-argument feed. |
// When need_zero_pad is set the call is counted once in real_fi_count. |
__device__ __forceinline__ void feed(int4 x) { |
|
|
|
feed(x.x, 0 * elem_per_32bit); |
|
|
|
feed(x.y, 1 * elem_per_32bit); |
|
|
|
feed(x.z, 2 * elem_per_32bit); |
|
|
|
feed(x.w, 3 * elem_per_32bit); |
|
|
|
// One increment per feed call, regardless of vector width. |
if (need_zero_pad) { |
|
|
|
real_fi_count++; |
|
|
|
} |
|
|
|
} |
|
|
|
__device__ __forceinline__ feed_type get_ans() { |
|
|
|
feed_type ans; |
|
|
@@ -189,13 +216,18 @@ struct MeanIncludeRoundedPooler { |
|
|
|
#pragma unroll |
|
|
|
for (int i = 0; i < nr_results; i++) { |
|
|
|
float f32_res = roundf(static_cast<float>(res[i]) * fi_count); |
|
|
|
if (need_zero_pad) { |
|
|
|
f32_res = roundf((static_cast<float>(res[i]) + |
|
|
|
(count - real_fi_count) * zero_pad) * |
|
|
|
fi_count); |
|
|
|
} |
|
|
|
int i8_res; |
|
|
|
asm volatile("cvt.rni.s8.f32 %0, %1;" |
|
|
|
: "=r"(i8_res) |
|
|
|
: "f"(f32_res)); |
|
|
|
out_res[i] = i8_res; |
|
|
|
} |
|
|
|
ans = pack_int8<nr_results, src_type, feed_type>(out_res); |
|
|
|
ans = pack_int8<nr_results, bit_width, feed_type>(out_res); |
|
|
|
return ans; |
|
|
|
} |
|
|
|
}; |
|
|
@@ -209,7 +241,7 @@ struct MeanExcludeRoundedPooler { |
|
|
|
static constexpr int shift_fix_sign = TypeTrait<src_type>::shift_fix_sign; |
|
|
|
int32_t res[nr_results]; |
|
|
|
int count; |
|
|
|
__device__ MeanExcludeRoundedPooler(int) {} |
|
|
|
__device__ MeanExcludeRoundedPooler(int, int) {} |
|
|
|
|
|
|
|
__device__ __forceinline__ void init() { |
|
|
|
#pragma unroll |
|
|
@@ -257,7 +289,7 @@ struct MeanExcludeRoundedPooler { |
|
|
|
: "f"(f32_res)); |
|
|
|
out_res[i] = i8_res; |
|
|
|
} |
|
|
|
ans = pack_int8<nr_results, src_type, feed_type>(out_res); |
|
|
|
ans = pack_int8<nr_results, bit_width, feed_type>(out_res); |
|
|
|
return ans; |
|
|
|
} |
|
|
|
}; |
|
|
@@ -290,7 +322,7 @@ __global__ void pooling2d_device_template_int8_cdiv4hwn4( |
|
|
|
packed_ch * output_pixels * npack + |
|
|
|
(ho * param.wo + wo) * npack; |
|
|
|
|
|
|
|
Pooler pooler(param.window_h * param.window_w); |
|
|
|
Pooler pooler(param.window_h * param.window_w, 0); |
|
|
|
pooler.init(); |
|
|
|
for (int fh = 0; fh < param.window_h; fh++) { |
|
|
|
uint32_t ih = ho * param.sh + fh - param.ph; |
|
|
@@ -313,7 +345,7 @@ template <typename Pooler, int pack_size, int pack_byte, |
|
|
|
int ldg_width_assert = 4> |
|
|
|
__global__ void pooling2d_device_template_nchwc(const int8_t* __restrict__ src, |
|
|
|
int8_t* __restrict__ dst, |
|
|
|
Param param) { |
|
|
|
Param param, int zero_point) { |
|
|
|
const int tid = blockIdx.x * blockDim.x + threadIdx.x; |
|
|
|
using ldg_type = typename Pooler::feed_type; |
|
|
|
static int constexpr ldg_width = sizeof(ldg_type) / sizeof(int32_t); |
|
|
@@ -348,7 +380,7 @@ __global__ void pooling2d_device_template_nchwc(const int8_t* __restrict__ src, |
|
|
|
dst + (batch * out_batch_stride + oc * out_channel_stride + |
|
|
|
(oh * param.wo + ow) * pack_byte + sec * ldg_width_bytes); |
|
|
|
|
|
|
|
Pooler pooler(param.window_h * param.window_w); |
|
|
|
Pooler pooler(param.window_h * param.window_w, zero_point); |
|
|
|
pooler.init(); |
|
|
|
for (int fh = 0; fh < param.window_h; fh++) { |
|
|
|
uint32_t ih = oh * param.sh + fh - param.ph; |
|
|
@@ -418,13 +450,12 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_cdiv4hwn4(const int8_t* d_src, |
|
|
|
after_kernel_launch(); |
|
|
|
} |
|
|
|
|
|
|
|
void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src, |
|
|
|
int8_t* d_dst, |
|
|
|
const Param& param, |
|
|
|
cudaStream_t stream, |
|
|
|
uint32_t mode) { |
|
|
|
void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4( |
|
|
|
const int8_t* d_src, int8_t* d_dst, const Param& param, |
|
|
|
cudaStream_t stream, uint32_t mode, bool uint_case, int zero_point) { |
|
|
|
using Mode = megdnn::param_enumv::Pooling::Mode; |
|
|
|
void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param); |
|
|
|
void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param, |
|
|
|
int zero_point); |
|
|
|
constexpr int ldg_byte = 4; |
|
|
|
constexpr int elem_per_byte = 1; |
|
|
|
constexpr int pack_size = 4; |
|
|
@@ -455,17 +486,16 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv4hw4(const int8_t* d_src, |
|
|
|
uint32_t nr_threads = query_blocksize_for_kernel(kern); |
|
|
|
nr_threads = std::min(nr_threads, vthreads); |
|
|
|
uint32_t nr_blocks = DIVUP(vthreads, nr_threads); |
|
|
|
kern<<<nr_blocks, nr_threads, 0, stream>>>(d_src, d_dst, param); |
|
|
|
kern<<<nr_blocks, nr_threads, 0, stream>>>(d_src, d_dst, param, zero_point); |
|
|
|
after_kernel_launch(); |
|
|
|
} |
|
|
|
|
|
|
|
void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv32hw32(const int8_t* d_src, |
|
|
|
int8_t* d_dst, |
|
|
|
const Param& param, |
|
|
|
cudaStream_t stream, |
|
|
|
uint32_t mode) { |
|
|
|
void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv32hw32( |
|
|
|
const int8_t* d_src, int8_t* d_dst, const Param& param, |
|
|
|
cudaStream_t stream, uint32_t mode, bool uint_case, int zero_point) { |
|
|
|
using Mode = megdnn::param_enumv::Pooling::Mode; |
|
|
|
void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param); |
|
|
|
void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param, |
|
|
|
int zero_point); |
|
|
|
constexpr int ldg_byte = 16; |
|
|
|
constexpr int elem_per_byte = 1; |
|
|
|
constexpr int pack_size = 32; |
|
|
@@ -494,17 +524,16 @@ void megdnn::cuda::pooling2d::do_pooling2d_int8_ncdiv32hw32(const int8_t* d_src, |
|
|
|
uint32_t nr_threads = query_blocksize_for_kernel(kern); |
|
|
|
nr_threads = std::min(nr_threads, vthreads); |
|
|
|
uint32_t nr_blocks = DIVUP(vthreads, nr_threads); |
|
|
|
kern<<<nr_blocks, nr_threads, 0, stream>>>(d_src, d_dst, param); |
|
|
|
kern<<<nr_blocks, nr_threads, 0, stream>>>(d_src, d_dst, param, zero_point); |
|
|
|
after_kernel_launch(); |
|
|
|
} |
|
|
|
|
|
|
|
void megdnn::cuda::pooling2d::do_pooling2d_int4_ncdiv64hw64(const int8_t* d_src, |
|
|
|
int8_t* d_dst, |
|
|
|
const Param& param, |
|
|
|
cudaStream_t stream, |
|
|
|
uint32_t mode) { |
|
|
|
void megdnn::cuda::pooling2d::do_pooling2d_int4_ncdiv64hw64( |
|
|
|
const int8_t* d_src, int8_t* d_dst, const Param& param, |
|
|
|
cudaStream_t stream, uint32_t mode, bool uint_case, int zero_point) { |
|
|
|
using Mode = megdnn::param_enumv::Pooling::Mode; |
|
|
|
void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param); |
|
|
|
void (*kern)(const int8_t* __restrict__, int8_t* __restrict__, Param param, |
|
|
|
int zero_point); |
|
|
|
constexpr int ldg_byte = 16; |
|
|
|
constexpr int elem_per_byte = 2; |
|
|
|
constexpr int pack_size = 64; |
|
|
@@ -512,28 +541,50 @@ void megdnn::cuda::pooling2d::do_pooling2d_int4_ncdiv64hw64(const int8_t* d_src, |
|
|
|
constexpr int elem_per_thread = ldg_byte * elem_per_byte; |
|
|
|
uint32_t vthreads = |
|
|
|
param.n * param.c * param.ho * param.wo / elem_per_thread; |
|
|
|
switch (mode) { |
|
|
|
case Mode::MAX: |
|
|
|
kern = pooling2d_device_template_nchwc<MaxPooler<dt_qint4, int4>, |
|
|
|
pack_size, pack_byte>; |
|
|
|
break; |
|
|
|
case Mode::AVERAGE: |
|
|
|
kern = pooling2d_device_template_nchwc< |
|
|
|
MeanIncludeRoundedPooler<dt_qint4, int4, int32_t>, |
|
|
|
pack_size, pack_byte>; |
|
|
|
break; |
|
|
|
case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: |
|
|
|
kern = pooling2d_device_template_nchwc< |
|
|
|
MeanExcludeRoundedPooler<dt_qint4, int4, int32_t>, |
|
|
|
pack_size, pack_byte>; |
|
|
|
break; |
|
|
|
default: |
|
|
|
megdnn_assert(false, "invalid pooling mode"); |
|
|
|
if (uint_case) { |
|
|
|
switch (mode) { |
|
|
|
case Mode::MAX: |
|
|
|
kern = pooling2d_device_template_nchwc< |
|
|
|
MaxPooler<dt_quint4, int4>, pack_size, pack_byte>; |
|
|
|
break; |
|
|
|
case Mode::AVERAGE: |
|
|
|
kern = pooling2d_device_template_nchwc< |
|
|
|
MeanIncludeRoundedPooler<dt_quint4, int4, int32_t>, |
|
|
|
pack_size, pack_byte>; |
|
|
|
break; |
|
|
|
case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: |
|
|
|
kern = pooling2d_device_template_nchwc< |
|
|
|
MeanExcludeRoundedPooler<dt_quint4, int4, int32_t>, |
|
|
|
pack_size, pack_byte>; |
|
|
|
break; |
|
|
|
default: |
|
|
|
megdnn_assert(false, "invalid pooling mode"); |
|
|
|
} |
|
|
|
|
|
|
|
} else { |
|
|
|
switch (mode) { |
|
|
|
case Mode::MAX: |
|
|
|
kern = pooling2d_device_template_nchwc<MaxPooler<dt_qint4, int4>, |
|
|
|
pack_size, pack_byte>; |
|
|
|
break; |
|
|
|
case Mode::AVERAGE: |
|
|
|
kern = pooling2d_device_template_nchwc< |
|
|
|
MeanIncludeRoundedPooler<dt_qint4, int4, int32_t>, |
|
|
|
pack_size, pack_byte>; |
|
|
|
break; |
|
|
|
case Mode::AVERAGE_COUNT_EXCLUDE_PADDING: |
|
|
|
kern = pooling2d_device_template_nchwc< |
|
|
|
MeanExcludeRoundedPooler<dt_qint4, int4, int32_t>, |
|
|
|
pack_size, pack_byte>; |
|
|
|
break; |
|
|
|
default: |
|
|
|
megdnn_assert(false, "invalid pooling mode"); |
|
|
|
} |
|
|
|
} |
|
|
|
uint32_t nr_threads = query_blocksize_for_kernel(kern); |
|
|
|
nr_threads = std::min(nr_threads, vthreads); |
|
|
|
uint32_t nr_blocks = DIVUP(vthreads, nr_threads); |
|
|
|
kern<<<nr_blocks, nr_threads, 0, stream>>>(d_src, d_dst, param); |
|
|
|
kern<<<nr_blocks, nr_threads, 0, stream>>>(d_src, d_dst, param, zero_point); |
|
|
|
after_kernel_launch(); |
|
|
|
} |
|
|
|
|
|
|
|