GitOrigin-RevId: 84e0815a59
release-1.10
@@ -21,9 +21,13 @@ public: | |||
DirectConvRunner(size_t flt_size, size_t stride) { | |||
if (flt_size == 9 && stride == 1) { | |||
m_func = megdnn_dot_nchw_large_chanwise_direct_conv_9x9s1_oh4_ow16; | |||
} else { | |||
megdnn_assert(flt_size == 9 && stride == 2); | |||
} else if (flt_size == 9 && stride == 2) { | |||
m_func = megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16; | |||
} else if (flt_size == 11 && stride == 1) { | |||
m_func = megdnn_dot_nchw_large_chanwise_direct_conv_11x11s1_oh4_ow16; | |||
} else { | |||
megdnn_assert(flt_size == 11 && stride == 2); | |||
m_func = megdnn_dot_nchw_large_chanwise_direct_conv_11x11s2_oh4_ow16; | |||
} | |||
} | |||
size_t get_round_fw(const ConvBiasImpl::NCBKernSizeParam& param) const { | |||
@@ -208,8 +212,8 @@ bool ConvBiasImpl::AlgoDotS8DirectChanWiseLarge::usable( | |||
(bias_mode == BiasMode::NO_BIAS || | |||
bias_mode == BiasMode::BROADCAST_CHANNEL_BIAS) && | |||
fm.spatial_ndim == 2 && fm.dilation[0] == 1 && fm.dilation[1] == 1 && | |||
SH == SW && (SH == 1 || SH == 2) && FH == FW && (FH == 9) && fm.icpg == 1 && | |||
fm.ocpg == 1; | |||
SH == SW && (SH == 1 || SH == 2) && FH == FW && (FH == 9 || FH == 11) && | |||
fm.icpg == 1 && fm.ocpg == 1; | |||
return avaible; | |||
} | |||
@@ -12,4 +12,13 @@ void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16( | |||
size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, | |||
int8_t relu_val); | |||
void megdnn_dot_nchw_large_chanwise_direct_conv_11x11s1_oh4_ow16( | |||
const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh, | |||
size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, | |||
int8_t relu_val); | |||
void megdnn_dot_nchw_large_chanwise_direct_conv_11x11s2_oh4_ow16( | |||
const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh, | |||
size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, | |||
int8_t relu_val); | |||
#endif |
@@ -0,0 +1,240 @@ | |||
#include "megdnn/arch.h" | |||
#if MGB_ENABLE_DOT | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h" | |||
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h" | |||
#include "src/common/unroll_macro.h" | |||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
void megdnn_dot_nchw_large_chanwise_direct_conv_11x11s1_oh4_ow16( | |||
const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh, | |||
size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, | |||
int8_t relu_val) { | |||
//! 4x16 | |||
const size_t SH = 1; | |||
const size_t SW = 1; | |||
static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 1, 2, 3, 4, | |||
2, 3, 4, 5, 3, 4, 5, 6}; | |||
static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 5, 6, 7, 8, | |||
6, 7, 8, 9, 7, 8, 9, 10}; | |||
static const uint8_t tbl_array_2[16] = {8, 9, 10, 11, 9, 10, 11, 12, | |||
10, 11, 12, 13, 11, 12, 13, 14}; | |||
uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]); | |||
uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]); | |||
uint8x16_t tbl_reg_2 = vld1q_u8(&tbl_array_2[0]); | |||
const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; | |||
//! init | |||
int32x4_t c[4][4]; | |||
#define cb(step) \ | |||
c[step][0] = vdupq_n_s32(bias); \ | |||
c[step][1] = vdupq_n_s32(bias); \ | |||
c[step][2] = vdupq_n_s32(bias); \ | |||
c[step][3] = vdupq_n_s32(bias); | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
#define flt_reg 4 | |||
int8x16_t flt[flt_reg]; | |||
flt[0] = vld1q_s8(weight + 0 * 16); | |||
flt[1] = vld1q_s8(weight + 1 * 16); | |||
flt[2] = vld1q_s8(weight + 2 * 16); | |||
flt[3] = vld1q_s8(weight + 3 * 16); | |||
//! row 0 | |||
int8x16_t read_w[2]; | |||
read_w[0] = vld1q_s8(src_n + 0 * pad_iw); | |||
read_w[1] = vld1q_s8(src_n + 0 * pad_iw + 16); | |||
int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); | |||
int8x16_t n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1); | |||
int8x16_t n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2); | |||
int8x16_t ncdef_0 = vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0); | |||
int8x16_t n0123_1 = n4567_0; | |||
int8x16_t n4567_1 = n89ab_0; | |||
int8x16_t n89ab_1 = ncdef_0; | |||
int8x16_t ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0); | |||
int8x16_t n0123_2 = n89ab_0; | |||
int8x16_t n4567_2 = ncdef_0; | |||
int8x16_t n89ab_2 = ncdef_1; | |||
int8x16_t ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1); | |||
#define CAL_C(oh, flt_start) \ | |||
c[oh][0] = vdotq_laneq_s32( \ | |||
c[oh][0], n0123_0, flt[(flt_start + 0) / 4 % flt_reg], \ | |||
(flt_start + 0) % 4); \ | |||
c[oh][1] = vdotq_laneq_s32( \ | |||
c[oh][1], n4567_0, flt[(flt_start + 0) / 4 % flt_reg], \ | |||
(flt_start + 0) % 4); \ | |||
c[oh][2] = vdotq_laneq_s32( \ | |||
c[oh][2], n89ab_0, flt[(flt_start + 0) / 4 % flt_reg], \ | |||
(flt_start + 0) % 4); \ | |||
c[oh][3] = vdotq_laneq_s32( \ | |||
c[oh][3], ncdef_0, flt[(flt_start + 0) / 4 % flt_reg], \ | |||
(flt_start + 0) % 4); \ | |||
c[oh][0] = vdotq_laneq_s32( \ | |||
c[oh][0], n0123_1, flt[(flt_start + 1) / 4 % flt_reg], \ | |||
(flt_start + 1) % 4); \ | |||
c[oh][1] = vdotq_laneq_s32( \ | |||
c[oh][1], n4567_1, flt[(flt_start + 1) / 4 % flt_reg], \ | |||
(flt_start + 1) % 4); \ | |||
c[oh][2] = vdotq_laneq_s32( \ | |||
c[oh][2], n89ab_1, flt[(flt_start + 1) / 4 % flt_reg], \ | |||
(flt_start + 1) % 4); \ | |||
c[oh][3] = vdotq_laneq_s32( \ | |||
c[oh][3], ncdef_1, flt[(flt_start + 1) / 4 % flt_reg], \ | |||
(flt_start + 1) % 4); \ | |||
c[oh][0] = vdotq_laneq_s32( \ | |||
c[oh][0], n0123_2, flt[(flt_start + 2) / 4 % flt_reg], \ | |||
(flt_start + 2) % 4); \ | |||
c[oh][1] = vdotq_laneq_s32( \ | |||
c[oh][1], n4567_2, flt[(flt_start + 2) / 4 % flt_reg], \ | |||
(flt_start + 2) % 4); \ | |||
c[oh][2] = vdotq_laneq_s32( \ | |||
c[oh][2], n89ab_2, flt[(flt_start + 2) / 4 % flt_reg], \ | |||
(flt_start + 2) % 4); \ | |||
c[oh][3] = vdotq_laneq_s32( \ | |||
c[oh][3], ncdef_2, flt[(flt_start + 2) / 4 % flt_reg], \ | |||
(flt_start + 2) % 4); | |||
CAL_C(0, 0); | |||
//! row 1 | |||
#define LOAD_SRC(row_id) \ | |||
read_w[0] = vld1q_s8(src_n + row_id * pad_iw); \ | |||
read_w[1] = vld1q_s8(src_n + row_id * pad_iw + 16); \ | |||
n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); \ | |||
n4567_0 = vqtbl1q_s8(read_w[0], tbl_reg_1); \ | |||
n89ab_0 = vqtbl1q_s8(read_w[0], tbl_reg_2); \ | |||
ncdef_0 = vqtbl1q_s8(vextq_s8(read_w[0], read_w[1], 12), tbl_reg_0); \ | |||
n0123_1 = n4567_0; \ | |||
n4567_1 = n89ab_0; \ | |||
n89ab_1 = ncdef_0; \ | |||
ncdef_1 = vqtbl1q_s8(read_w[1], tbl_reg_0); \ | |||
n0123_2 = n89ab_0; \ | |||
n4567_2 = ncdef_0; \ | |||
n89ab_2 = ncdef_1; \ | |||
ncdef_2 = vqtbl1q_s8(read_w[1], tbl_reg_1); | |||
LOAD_SRC(1); | |||
CAL_C(0, 3); | |||
CAL_C(1, 0); | |||
//! row 2 | |||
LOAD_SRC(2); | |||
CAL_C(0, 3 * 2); | |||
CAL_C(1, 3 * 1); | |||
CAL_C(2, 3 * 0); | |||
//! row 3 | |||
LOAD_SRC(3); | |||
CAL_C(0, 3 * 3); | |||
CAL_C(1, 3 * 2); | |||
CAL_C(2, 3 * 1); | |||
CAL_C(3, 3 * 0); | |||
//! row 4 | |||
LOAD_SRC(4); | |||
CAL_C(0, 3 * 4); | |||
CAL_C(1, 3 * 3); | |||
CAL_C(2, 3 * 2); | |||
CAL_C(3, 3 * 1); | |||
//! update flt 4 -> 0 | |||
flt[0] = vld1q_s8(weight + 4 * 16); | |||
//! row 5 | |||
LOAD_SRC(5); | |||
CAL_C(0, 3 * 5); | |||
CAL_C(1, 3 * 4); | |||
CAL_C(2, 3 * 3); | |||
CAL_C(3, 3 * 2); | |||
//! update flt 5 -> 1 | |||
flt[1] = vld1q_s8(weight + 5 * 16); | |||
//! row 6 | |||
LOAD_SRC(6); | |||
CAL_C(0, 3 * 6); | |||
CAL_C(1, 3 * 5); | |||
CAL_C(2, 3 * 4); | |||
CAL_C(3, 3 * 3); | |||
//! update flt 6 -> 2 | |||
flt[2] = vld1q_s8(weight + 6 * 16); | |||
//! row 7 | |||
LOAD_SRC(7); | |||
CAL_C(0, 3 * 7); | |||
CAL_C(1, 3 * 6); | |||
CAL_C(2, 3 * 5); | |||
CAL_C(3, 3 * 4); | |||
//! row 8 | |||
LOAD_SRC(8); | |||
CAL_C(3, 3 * 5); | |||
//! update flt 7 -> 3 | |||
flt[3] = vld1q_s8(weight + 7 * 16); | |||
CAL_C(2, 3 * 6); | |||
CAL_C(1, 3 * 7); | |||
CAL_C(0, 3 * 8); | |||
//! row 9 | |||
LOAD_SRC(9); | |||
CAL_C(0, 3 * 9); | |||
CAL_C(1, 3 * 8); | |||
CAL_C(2, 3 * 7); | |||
CAL_C(3, 3 * 6); | |||
//! row 10 | |||
LOAD_SRC(10); | |||
//! update flt 8 -> 0 | |||
flt[0] = vld1q_s8(weight + 8 * 16); | |||
CAL_C(3, 3 * 7); | |||
CAL_C(2, 3 * 8); | |||
CAL_C(1, 3 * 9); | |||
CAL_C(0, 3 * 10); | |||
//! row 11 | |||
LOAD_SRC(11); | |||
CAL_C(1, 3 * 10); | |||
CAL_C(2, 3 * 9); | |||
CAL_C(3, 3 * 8); | |||
//! row 12 | |||
LOAD_SRC(12); | |||
CAL_C(2, 3 * 10); | |||
CAL_C(3, 3 * 9); | |||
//! row 13 | |||
LOAD_SRC(13); | |||
CAL_C(3, 3 * 10); | |||
float32x4_t dst_reg[4][4]; | |||
#define cb(step) \ | |||
dst_reg[step][0] = vcvtq_f32_s32(c[step][0]); \ | |||
dst_reg[step][1] = vcvtq_f32_s32(c[step][1]); \ | |||
dst_reg[step][2] = vcvtq_f32_s32(c[step][2]); \ | |||
dst_reg[step][3] = vcvtq_f32_s32(c[step][3]); | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
#define cb(step) \ | |||
dst_reg[step][0] = vmulq_n_f32(dst_reg[step][0], scale); \ | |||
dst_reg[step][1] = vmulq_n_f32(dst_reg[step][1], scale); \ | |||
dst_reg[step][2] = vmulq_n_f32(dst_reg[step][2], scale); \ | |||
dst_reg[step][3] = vmulq_n_f32(dst_reg[step][3], scale); | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
int8_t* dst_store = dst + oh * OW + ow; | |||
int8x16_t relu_reg = vdupq_n_s8(relu_val); | |||
#define cb(step) \ | |||
quant_store_s8( \ | |||
dst_reg[step][0], dst_reg[step][1], dst_reg[step][2], dst_reg[step][3], \ | |||
dst_store + step * OW, relu_reg); | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
} | |||
#endif |
@@ -0,0 +1,249 @@ | |||
#include "megdnn/arch.h" | |||
#if MGB_ENABLE_DOT | |||
#include "src/arm_common/simd_macro/marm_neon.h" | |||
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large.h" | |||
#include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_large_common.h" | |||
#include "src/common/unroll_macro.h" | |||
MEGDNN_ATTRIBUTE_TARGET("dotprod") | |||
void megdnn_dot_nchw_large_chanwise_direct_conv_11x11s2_oh4_ow16( | |||
const int8_t* src, const int8_t* weight, int32_t bias, int8_t* dst, size_t oh, | |||
size_t ow, size_t OH, size_t OW, size_t pad_iw, const float scale, | |||
int8_t relu_val) { | |||
//! 4x16 | |||
const size_t SH = 2; | |||
const size_t SW = 2; | |||
static const uint8_t tbl_array_0[16] = {0, 1, 2, 3, 2, 3, 4, 5, | |||
4, 5, 6, 7, 6, 7, 8, 9}; | |||
static const uint8_t tbl_array_1[16] = {4, 5, 6, 7, 6, 7, 8, 9, | |||
8, 9, 10, 11, 10, 11, 12, 13}; | |||
uint8x16_t tbl_reg_0 = vld1q_u8(&tbl_array_0[0]); | |||
uint8x16_t tbl_reg_1 = vld1q_u8(&tbl_array_1[0]); | |||
const int8_t* src_n = src + oh * SH * pad_iw + ow * SW; | |||
//! init | |||
int32x4_t c[4][4]; | |||
#define cb(step) \ | |||
c[step][0] = vdupq_n_s32(bias); \ | |||
c[step][1] = vdupq_n_s32(bias); \ | |||
c[step][2] = vdupq_n_s32(bias); \ | |||
c[step][3] = vdupq_n_s32(bias); | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
#define flt_reg 9 | |||
#define flt_per_reg 4 | |||
int8x16_t flt[flt_reg]; | |||
#define cb(step) flt[step] = vld1q_s8(weight + step * 16); | |||
UNROLL_CALL_RAW(flt_reg, cb); | |||
#undef cb | |||
#define CAL_C(oh, flt_start) \ | |||
c[oh][0] = vdotq_laneq_s32( \ | |||
c[oh][0], n0123_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \ | |||
(flt_start + 0) % flt_per_reg); \ | |||
c[oh][1] = vdotq_laneq_s32( \ | |||
c[oh][1], n4567_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \ | |||
(flt_start + 0) % flt_per_reg); \ | |||
c[oh][2] = vdotq_laneq_s32( \ | |||
c[oh][2], n89ab_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \ | |||
(flt_start + 0) % flt_per_reg); \ | |||
c[oh][3] = vdotq_laneq_s32( \ | |||
c[oh][3], ncdef_0, flt[(flt_start + 0) / flt_per_reg % flt_reg], \ | |||
(flt_start + 0) % flt_per_reg); \ | |||
c[oh][0] = vdotq_laneq_s32( \ | |||
c[oh][0], n0123_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \ | |||
(flt_start + 1) % flt_per_reg); \ | |||
c[oh][1] = vdotq_laneq_s32( \ | |||
c[oh][1], n4567_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \ | |||
(flt_start + 1) % flt_per_reg); \ | |||
c[oh][2] = vdotq_laneq_s32( \ | |||
c[oh][2], n89ab_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \ | |||
(flt_start + 1) % flt_per_reg); \ | |||
c[oh][3] = vdotq_laneq_s32( \ | |||
c[oh][3], ncdef_1, flt[(flt_start + 1) / flt_per_reg % flt_reg], \ | |||
(flt_start + 1) % flt_per_reg); \ | |||
c[oh][0] = vdotq_laneq_s32( \ | |||
c[oh][0], n0123_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \ | |||
(flt_start + 2) % flt_per_reg); \ | |||
c[oh][1] = vdotq_laneq_s32( \ | |||
c[oh][1], n4567_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \ | |||
(flt_start + 2) % flt_per_reg); \ | |||
c[oh][2] = vdotq_laneq_s32( \ | |||
c[oh][2], n89ab_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \ | |||
(flt_start + 2) % flt_per_reg); \ | |||
c[oh][3] = vdotq_laneq_s32( \ | |||
c[oh][3], ncdef_2, flt[(flt_start + 2) / flt_per_reg % flt_reg], \ | |||
(flt_start + 2) % flt_per_reg); | |||
#define LOAD_SRC(row_id) \ | |||
read_w[0] = vld1q_s8(src_n + row_id * pad_iw); \ | |||
read_w[1] = vld1q_s8(src_n + row_id * pad_iw + 16); \ | |||
read_w[2] = vld1q_s8(src_n + row_id * pad_iw + 32); \ | |||
ext_8 = vextq_s8(read_w[0], read_w[1], 8); \ | |||
ext_24 = vextq_s8(read_w[1], read_w[2], 8); \ | |||
n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); \ | |||
n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0); \ | |||
n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0); \ | |||
ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0); \ | |||
n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1); \ | |||
n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1); \ | |||
n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1); \ | |||
ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1); \ | |||
n0123_2 = n4567_0; \ | |||
n4567_2 = n89ab_0; \ | |||
n89ab_2 = ncdef_0; \ | |||
ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0); | |||
//! row 0 | |||
int8x16_t read_w[3]; | |||
read_w[0] = vld1q_s8(src_n); | |||
read_w[1] = vld1q_s8(src_n + 16); | |||
read_w[2] = vld1q_s8(src_n + 32); | |||
int8x16_t ext_8 = vextq_s8(read_w[0], read_w[1], 8); | |||
int8x16_t ext_24 = vextq_s8(read_w[1], read_w[2], 8); | |||
int8x16_t n0123_0 = vqtbl1q_s8(read_w[0], tbl_reg_0); | |||
int8x16_t n4567_0 = vqtbl1q_s8(ext_8, tbl_reg_0); | |||
int8x16_t n89ab_0 = vqtbl1q_s8(read_w[1], tbl_reg_0); | |||
int8x16_t ncdef_0 = vqtbl1q_s8(ext_24, tbl_reg_0); | |||
int8x16_t n0123_1 = vqtbl1q_s8(read_w[0], tbl_reg_1); | |||
int8x16_t n4567_1 = vqtbl1q_s8(ext_8, tbl_reg_1); | |||
int8x16_t n89ab_1 = vqtbl1q_s8(read_w[1], tbl_reg_1); | |||
int8x16_t ncdef_1 = vqtbl1q_s8(ext_24, tbl_reg_1); | |||
int8x16_t n0123_2 = n4567_0; | |||
int8x16_t n4567_2 = n89ab_0; | |||
int8x16_t n89ab_2 = ncdef_0; | |||
int8x16_t ncdef_2 = vqtbl1q_s8(read_w[2], tbl_reg_0); | |||
CAL_C(0, 0); | |||
//! row 1 | |||
LOAD_SRC(1); | |||
CAL_C(0, 3 * 1); | |||
//! row 2 | |||
LOAD_SRC(2); | |||
CAL_C(0, 3 * 2); | |||
CAL_C(1, 3 * 0); | |||
//! row 3 | |||
LOAD_SRC(3); | |||
CAL_C(0, 3 * 3); | |||
CAL_C(1, 3 * 1); | |||
//! row 4 | |||
LOAD_SRC(4); | |||
CAL_C(0, 3 * 4); | |||
CAL_C(1, 3 * 2); | |||
CAL_C(2, 3 * 0); | |||
//! row 5 | |||
LOAD_SRC(5); | |||
CAL_C(0, 3 * 5); | |||
CAL_C(1, 3 * 3); | |||
CAL_C(2, 3 * 1); | |||
//! row 6 | |||
LOAD_SRC(6); | |||
CAL_C(0, 3 * 6); | |||
CAL_C(1, 3 * 4); | |||
CAL_C(2, 3 * 2); | |||
CAL_C(3, 3 * 0); | |||
//! row 7 | |||
LOAD_SRC(7); | |||
CAL_C(0, 3 * 7); | |||
CAL_C(1, 3 * 5); | |||
CAL_C(2, 3 * 3); | |||
CAL_C(3, 3 * 1); | |||
//! row 8 | |||
LOAD_SRC(8); | |||
CAL_C(0, 3 * 8); | |||
CAL_C(1, 3 * 6); | |||
CAL_C(2, 3 * 4); | |||
CAL_C(3, 3 * 2); | |||
//! row 9 | |||
LOAD_SRC(9); | |||
CAL_C(0, 3 * 9); | |||
CAL_C(1, 3 * 7); | |||
CAL_C(2, 3 * 5); | |||
CAL_C(3, 3 * 3); | |||
//! row 10 | |||
LOAD_SRC(10); | |||
CAL_C(0, 3 * 10); | |||
CAL_C(1, 3 * 8); | |||
CAL_C(2, 3 * 6); | |||
CAL_C(3, 3 * 4); | |||
//! row 11 | |||
LOAD_SRC(11); | |||
CAL_C(1, 3 * 9); | |||
CAL_C(2, 3 * 7); | |||
CAL_C(3, 3 * 5); | |||
//! row 12 | |||
LOAD_SRC(12); | |||
CAL_C(1, 3 * 10); | |||
CAL_C(2, 3 * 8); | |||
CAL_C(3, 3 * 6); | |||
//! row 13 | |||
LOAD_SRC(13); | |||
CAL_C(2, 3 * 9); | |||
CAL_C(3, 3 * 7); | |||
//! row 14 | |||
LOAD_SRC(14); | |||
CAL_C(2, 3 * 10); | |||
CAL_C(3, 3 * 8); | |||
//! row 15 | |||
LOAD_SRC(15); | |||
CAL_C(3, 3 * 9); | |||
//! row 16 | |||
LOAD_SRC(16); | |||
CAL_C(3, 3 * 10); | |||
float32x4_t dst_reg[4][4]; | |||
#define cb(step) \ | |||
dst_reg[step][0] = vcvtq_f32_s32(c[step][0]); \ | |||
dst_reg[step][1] = vcvtq_f32_s32(c[step][1]); \ | |||
dst_reg[step][2] = vcvtq_f32_s32(c[step][2]); \ | |||
dst_reg[step][3] = vcvtq_f32_s32(c[step][3]); | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
#define cb(step) \ | |||
dst_reg[step][0] = vmulq_n_f32(dst_reg[step][0], scale); \ | |||
dst_reg[step][1] = vmulq_n_f32(dst_reg[step][1], scale); \ | |||
dst_reg[step][2] = vmulq_n_f32(dst_reg[step][2], scale); \ | |||
dst_reg[step][3] = vmulq_n_f32(dst_reg[step][3], scale); | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
int8_t* dst_store = dst + oh * OW + ow; | |||
int8x16_t relu_reg = vdupq_n_s8(relu_val); | |||
#define cb(step) \ | |||
quant_store_s8( \ | |||
dst_reg[step][0], dst_reg[step][1], dst_reg[step][2], dst_reg[step][3], \ | |||
dst_store + step * OW, relu_reg); | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
} | |||
#endif |
@@ -36,16 +36,14 @@ void megdnn_dot_nchw_large_chanwise_direct_conv_9x9s2_oh4_ow16( | |||
UNROLL_CALL_RAW(4, cb); | |||
#undef cb | |||
constexpr int flt_reg = 7; | |||
constexpr int flt_per_reg = 4; | |||
int8x16_t flt[7]; | |||
flt[0] = vld1q_s8(weight + 0 * 16); | |||
flt[1] = vld1q_s8(weight + 1 * 16); | |||
flt[2] = vld1q_s8(weight + 2 * 16); | |||
flt[3] = vld1q_s8(weight + 3 * 16); | |||
flt[4] = vld1q_s8(weight + 4 * 16); | |||
flt[5] = vld1q_s8(weight + 5 * 16); | |||
flt[6] = vld1q_s8(weight + 6 * 16); | |||
#define flt_reg 7 | |||
#define flt_per_reg 4 | |||
int8x16_t flt[flt_reg]; | |||
#define cb(step) flt[step] = vld1q_s8(weight + step * 16); | |||
UNROLL_CALL_RAW(flt_reg, cb); | |||
#undef cb | |||
#define CAL_C(oh, flt_start) \ | |||
c[oh][0] = vdotq_laneq_s32( \ | |||
@@ -2060,6 +2060,16 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_LARGE_KERN_NCHW_DOT) { | |||
benchmark1.set_display(false); | |||
benchmark1.set_times(RUN); | |||
Benchmarker<ConvBias> benchmark2(handle()); | |||
benchmark2.set_dtype(0, dtype::QuantizedS8(2.5f)) | |||
.set_dtype(1, dtype::QuantizedS8(2.5f)) | |||
.set_dtype(2, dtype::QuantizedS32(6.25f)) | |||
.set_dtype(4, dtype::QuantizedS8(60.25f)); | |||
benchmark2.set_before_exec_callback( | |||
conv_bias::ConvBiasAlgoChecker<ConvBiasForward>("ARMDOTS8")); | |||
benchmark2.set_display(false); | |||
benchmark2.set_times(RUN); | |||
for (auto&& arg : args) { | |||
TensorLayout dst_layout; | |||
auto opr = handle()->create_operator<ConvBias>(); | |||
@@ -2070,6 +2080,12 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_LARGE_KERN_NCHW_DOT) { | |||
//! dst.nr_elems * FH * FW * 2 | |||
float computations = | |||
dst_layout.total_nr_elems() * arg.filter[3] * arg.filter[4] * 2.0 / 1e6; | |||
float computations_5x5 = dst_layout.total_nr_elems() * 5 * 5 * 2.0 / 1e6; | |||
float computations_11x11 = dst_layout.total_nr_elems() * 11 * 11 * 2.0 / 1e6; | |||
param::ConvBias param_5x5 = arg.param; | |||
param_5x5.pad_h = param_5x5.pad_w = 5 / 2; | |||
param::ConvBias param_11x11 = arg.param; | |||
param_11x11.pad_h = param_11x11.pad_w = 11 / 2; | |||
auto used0 = benchmark0.set_param(arg.param).exec( | |||
{arg.src, arg.filter, arg.bias, {}, {}}) / | |||
@@ -2077,11 +2093,26 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_LARGE_KERN_NCHW_DOT) { | |||
auto used1 = benchmark1.set_param(arg.param).exec( | |||
{arg.src, arg.filter, arg.bias, {}, {}}) / | |||
RUN; | |||
TensorShape flt_5x5_shape = arg.filter; | |||
flt_5x5_shape[3] = flt_5x5_shape[4] = 5; | |||
auto used5x5 = benchmark2.set_param(param_5x5).exec( | |||
{arg.src, flt_5x5_shape, arg.bias, {}, {}}) / | |||
RUN; | |||
TensorShape flt_11x11_shape = arg.filter; | |||
flt_11x11_shape[3] = flt_11x11_shape[4] = 11; | |||
auto used11x11 = benchmark0.set_param(param_11x11) | |||
.exec({arg.src, flt_11x11_shape, arg.bias, {}, {}}) / | |||
RUN; | |||
printf("%s %s: Direct use: %f ms %f Gflops im2col: %f ms %f GFlops " | |||
"speedup: %f\n", | |||
arg.src.to_string().c_str(), arg.filter.to_string().c_str(), used0, | |||
computations / used0, used1, computations / used1, used1 / used0); | |||
printf("%s %s s %u: Direct use: %f ms %f Gflops im2col: %f ms %f GFlops " | |||
"speedup: %f, compare 5x5 %f ms %f GFlops speedup %f, compare 11x11 %f " | |||
"ms %f GFops speedup %f\n", | |||
arg.src.to_string().c_str(), arg.filter.to_string().c_str(), | |||
arg.param.stride_h, used0, computations / used0, used1, | |||
computations / used1, used1 / used0, used5x5, computations_5x5 / used5x5, | |||
used5x5 / used0, used11x11, computations_11x11 / used11x11, | |||
used11x11 / used0); | |||
} | |||
} | |||
@@ -612,13 +612,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) { | |||
#if MGB_ENABLE_DOT | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S1) { | |||
checker_conv_bias_qint8x8x8( | |||
get_channel_wise_args({9}, 1, false, true, true, true), handle(), | |||
get_channel_wise_args({9, 11}, 1, false, true, true, true), handle(), | |||
"ARMDOTS8_DIRECT_CHANWISE_LARGE"); | |||
} | |||
TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_DIRECT_LARGE_S2) { | |||
checker_conv_bias_qint8x8x8( | |||
get_channel_wise_args({9}, 2, false, true, true, true), handle(), | |||
get_channel_wise_args({9, 11}, 2, false, true, true, true), handle(), | |||
"ARMDOTS8_DIRECT_CHANWISE_LARGE"); | |||
} | |||