diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h index b40e1688..3bba1637 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s1.h @@ -24,82 +24,75 @@ using namespace megdnn; using namespace arm_common; namespace { -template +template struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); }; -template -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step, lane) \ - c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ - src[(step + src_idx) % 8], lane); \ - c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ - src[(step + src_idx) % 8], lane); - - UNROLL_CALL_RAW(8, cb, 0); - UNROLL_CALL_RAW(8, cb, 1); - UNROLL_CALL_RAW(8, cb, 2); - UNROLL_CALL_RAW(8, cb, 3); -#undef cb - } -}; -template -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step, lane) \ - c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ - src[(step + src_idx) % 4], lane); \ - c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ - src[(step + src_idx) % 4], lane); - - UNROLL_CALL_RAW(4, cb, 0); - UNROLL_CALL_RAW(4, cb, 1); - UNROLL_CALL_RAW(4, cb, 2); - UNROLL_CALL_RAW(4, cb, 3); -#undef cb - } +template +struct ShiftCalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&) {} }; -template -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step, lane) \ - c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ - src[(step + src_idx) % 8], lane); - UNROLL_CALL_RAW(8, cb, 0); - UNROLL_CALL_RAW(8, cb, 1); - UNROLL_CALL_RAW(8, cb, 2); - UNROLL_CALL_RAW(8, cb, 3); -#undef cb - } -}; -template -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step, lane) \ - c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ - src[(step + src_idx) % 4], lane); +#define cb2(step, lane, ow_block) \ + c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ + src[(step + src_idx) % ow_block], lane); \ + c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ + src[(step + src_idx) % ow_block], lane); - UNROLL_CALL_RAW(4, cb, 0); - UNROLL_CALL_RAW(4, cb, 1); - UNROLL_CALL_RAW(4, cb, 2); - UNROLL_CALL_RAW(4, cb, 3); +#define cb(step, lane, ow_block) \ + c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ + src[(step + src_idx) % ow_block], lane); + +#define SHIFT_CAL_HELPER(ow_block, remain_w) \ + template \ + struct ShiftCalHelper { \ + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ + UNROLL_CALL_RAW(remain_w, cb2, 0, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb2, 1, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb2, 2, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb2, 3, ow_block); \ + } \ + }; \ + template \ + struct ShiftCalHelper { \ + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ + UNROLL_CALL_RAW(remain_w, cb, 0, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb, 1, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb, 2, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb, 3, ow_block); \ + } \ + }; + +SHIFT_CAL_HELPER(8, 1); +SHIFT_CAL_HELPER(8, 2); +SHIFT_CAL_HELPER(8, 3); +SHIFT_CAL_HELPER(8, 4); +SHIFT_CAL_HELPER(8, 5); +SHIFT_CAL_HELPER(8, 6); +SHIFT_CAL_HELPER(8, 7); +SHIFT_CAL_HELPER(8, 8); + +SHIFT_CAL_HELPER(4, 1); +SHIFT_CAL_HELPER(4, 2); +SHIFT_CAL_HELPER(4, 3); +SHIFT_CAL_HELPER(4, 4); + +#undef SHIFT_CAL_HELPER #undef cb - } -}; +#undef cb2 -template +template MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { - ShiftCalHelper::impl( - c, src, weight); + ShiftCalHelper::impl(c, src, weight); }; template struct OCHelper { @@ -151,7 +144,7 @@ struct KerNeonXXs1Nchw44FP32 { const int ld_src_iw = iw * oc_step; constexpr int c_dim = OCHelper::val; float32x4_t c[c_dim][ow_block]; - init_ocx_ow8(c, bias_ptr, ld_bias); + init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; @@ -162,11 +155,11 @@ struct KerNeonXXs1Nchw44FP32 { 0); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, ow_block>(c, src, weight); + cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); src_ptr += ld_src_iw; weight_ptr += ld_weight_fh; } @@ -196,7 +189,7 @@ struct KerNeonXXs1Nchw44FP32 { const int ld_src_iw = iw * oc_step; constexpr int c_dim = OCHelper::val; float32x4_t c[c_dim][ow_block]; - init_ocx_ow8(c, bias_ptr, ld_bias); + init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; @@ -207,15 +200,15 @@ struct KerNeonXXs1Nchw44FP32 { 0); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, ow_block>(c, src, weight); + cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<2, 0, c_dim, ow_block>(c, src, weight); + cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); src_ptr += ld_src_iw; weight_ptr += ld_weight_fh; } @@ -244,7 +237,7 @@ struct KerNeonXXs1Nchw44FP32 { const int ld_src_iw = iw * oc_step; constexpr int c_dim = OCHelper::val; float32x4_t c[c_dim][ow_block]; - init_ocx_ow8(c, bias_ptr, ld_bias); + init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; @@ -255,27 +248,27 @@ struct KerNeonXXs1Nchw44FP32 { 0); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, ow_block>(c, src, weight); + cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<2, 0, c_dim, ow_block>(c, src, weight); + cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<3, 0, c_dim, ow_block>(c, src, weight); + cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight); src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<4, 0, c_dim, ow_block>(c, src, weight); + cal_helper<4, 0, c_dim, ow_block, remain_w>(c, src, weight); src_ptr += ld_src_iw; weight_ptr += ld_weight_fh; } @@ -305,7 +298,7 @@ struct KerNeonXXs1Nchw44FP32 { const int ld_src_iw = iw * oc_step; constexpr int c_dim = OCHelper::val; float32x4_t c[c_dim][ow_block]; - init_ocx_ow8(c, bias_ptr, ld_bias); + init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; @@ -316,37 +309,37 @@ struct KerNeonXXs1Nchw44FP32 { 0); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src[0] = vld1q_f32(src_ptr + (ow_block)*ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, ow_block>(c, src, weight); + cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); src[1] = vld1q_f32(src_ptr + (ow_block + 1) * ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<2, 0, c_dim, ow_block>(c, src, weight); + cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); src[2] = vld1q_f32(src_ptr + (ow_block + 2) * ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<3, 0, c_dim, ow_block>(c, src, weight); + cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight); src[3] = vld1q_f32(src_ptr + (ow_block + 3) * ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<4, 0, c_dim, ow_block>(c, src, weight); + cal_helper<4, 0, c_dim, ow_block, remain_w>(c, src, weight); src[4] = vld1q_f32(src_ptr + (ow_block + 4) * ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<5, 0, c_dim, ow_block>(c, src, weight); + cal_helper<5, 0, c_dim, ow_block, remain_w>(c, src, weight); src[5] = vld1q_f32(src_ptr + (ow_block + 5) * ic_step); load_helper( weight, weight_ptr, ld_weight_oc); - cal_helper<6, 0, c_dim, ow_block>(c, src, weight); + cal_helper<6, 0, c_dim, ow_block, remain_w>(c, src, weight); src_ptr += ld_src_iw; weight_ptr += ld_weight_fh; } diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h index 73141077..f7057e3b 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw44_kern_common_s2.h @@ -24,83 +24,77 @@ using namespace megdnn; using namespace arm_common; namespace { -template +template struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); }; -template -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step, lane) \ - c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ - src[(step + src_idx) % 8], lane); \ - c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ - src[(step + src_idx) % 8], lane); - - UNROLL_CALL_RAW(8, cb, 0); - UNROLL_CALL_RAW(8, cb, 1); - UNROLL_CALL_RAW(8, cb, 2); - UNROLL_CALL_RAW(8, cb, 3); -#undef cb - } -}; -template -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step, lane) \ - c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ - src[(step + src_idx) % 4], lane); \ - c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ - src[(step + src_idx) % 4], lane); - - UNROLL_CALL_RAW(4, cb, 0); - UNROLL_CALL_RAW(4, cb, 1); - UNROLL_CALL_RAW(4, cb, 2); - UNROLL_CALL_RAW(4, cb, 3); -#undef cb - } +template +struct ShiftCalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&) {} }; -template -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step, lane) \ - c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ - src[(step + src_idx) % 8], lane); - UNROLL_CALL_RAW(8, cb, 0); - UNROLL_CALL_RAW(8, cb, 1); - UNROLL_CALL_RAW(8, cb, 2); - UNROLL_CALL_RAW(8, cb, 3); -#undef cb - } -}; -template -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step, lane) \ - c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ - src[(step + src_idx) % 4], lane); +#define cb2(step, lane, ow_block) \ + c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ + src[(step + src_idx) % ow_block], lane); \ + c[1][step] = vfmaq_laneq_f32(c[1][step], weight[1][lane], \ + src[(step + src_idx) % ow_block], lane); - UNROLL_CALL_RAW(4, cb, 0); - UNROLL_CALL_RAW(4, cb, 1); - UNROLL_CALL_RAW(4, cb, 2); - UNROLL_CALL_RAW(4, cb, 3); +#define cb(step, lane, ow_block) \ + c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][lane], \ + src[(step + src_idx) % ow_block], lane); + +#define SHIFT_CAL_HELPER(ow_block, remain_w) \ + template \ + struct ShiftCalHelper { \ + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ + UNROLL_CALL_RAW(remain_w, cb2, 0, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb2, 1, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb2, 2, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb2, 3, ow_block); \ + } \ + }; \ + template \ + struct ShiftCalHelper { \ + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ + UNROLL_CALL_RAW(remain_w, cb, 0, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb, 1, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb, 2, ow_block); \ + UNROLL_CALL_RAW(remain_w, cb, 3, ow_block); \ + } \ + }; + +SHIFT_CAL_HELPER(8, 1); +SHIFT_CAL_HELPER(8, 2); +SHIFT_CAL_HELPER(8, 3); +SHIFT_CAL_HELPER(8, 4); +SHIFT_CAL_HELPER(8, 5); +SHIFT_CAL_HELPER(8, 6); +SHIFT_CAL_HELPER(8, 7); +SHIFT_CAL_HELPER(8, 8); + +SHIFT_CAL_HELPER(4, 1); +SHIFT_CAL_HELPER(4, 2); +SHIFT_CAL_HELPER(4, 3); +SHIFT_CAL_HELPER(4, 4); + +#undef SHIFT_CAL_HELPER #undef cb - } -}; +#undef cb2 -template +template MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { - ShiftCalHelper::impl( - c, src, weight); + ShiftCalHelper::impl(c, src, weight); }; + template struct OCHelper { public: @@ -151,7 +145,7 @@ struct KerNeonXXs2Nchw44FP32 { const int ld_src_iw = iw * oc_step; constexpr int c_dim = OCHelper::val; float32x4_t c[c_dim][ow_block]; - init_ocx_ow8(c, bias_ptr, ld_bias); + init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; @@ -163,13 +157,13 @@ struct KerNeonXXs2Nchw44FP32 { load_helper(src, src_ptr, 0); load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); load_helper(src, src_ptr_odd, 0); load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src_ptr += ld_src_iw; src_ptr_odd += ld_src_iw; weight_ptr += ld_weight_fh; @@ -177,13 +171,13 @@ struct KerNeonXXs2Nchw44FP32 { load_helper(src, src_ptr, 0); load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); load_helper(src, src_ptr_odd, 0); load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src_ptr += ld_src_iw; src_ptr_odd += ld_src_iw; weight_ptr += ld_weight_fh; @@ -213,7 +207,7 @@ struct KerNeonXXs2Nchw44FP32 { const int ld_src_iw = iw * oc_step; constexpr int c_dim = OCHelper::val; float32x4_t c[c_dim][ow_block]; - init_ocx_ow8(c, bias_ptr, ld_bias); + init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; const float* src_ptr_odd = src_ptr_odd_origin + ic_idx * ld_src_ic; @@ -224,18 +218,18 @@ struct KerNeonXXs2Nchw44FP32 { load_helper(src, src_ptr, 0); load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src[0] = vld1q_f32(src_ptr + ow_block * simd_len); load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, ow_block>(c, src, weight); + cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); load_helper(src, src_ptr_odd, 0); load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src_ptr += ld_src_iw; src_ptr_odd += ld_src_iw; weight_ptr += ld_weight_fh; @@ -243,17 +237,17 @@ struct KerNeonXXs2Nchw44FP32 { load_helper(src, src_ptr, 0); load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src[0] = vld1q_f32(src_ptr + ow_block * simd_len); load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, ow_block>(c, src, weight); + cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); load_helper(src, src_ptr_odd, 0); load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src_ptr += ld_src_iw; src_ptr_odd += ld_src_iw; weight_ptr += ld_weight_fh; @@ -261,18 +255,18 @@ struct KerNeonXXs2Nchw44FP32 { load_helper(src, src_ptr, 0); load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src[0] = vld1q_f32(src_ptr + ow_block * simd_len); load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, ow_block>(c, src, weight); + cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); load_helper(src, src_ptr_odd, 0); load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src_ptr += ld_src_iw; src_ptr_odd += ld_src_iw; weight_ptr += ld_weight_fh; @@ -302,7 +296,7 @@ struct KerNeonXXs2Nchw44FP32 { const int ld_src_iw = iw * oc_step; constexpr int c_dim = OCHelper::val; float32x4_t c[c_dim][ow_block]; - init_ocx_ow8(c, bias_ptr, ld_bias); + init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; @@ -316,25 +310,25 @@ struct KerNeonXXs2Nchw44FP32 { 0); load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src[0] = vld1q_f32(src_ptr + ow_block * simd_len); load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, ow_block>(c, src, weight); + cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<2, 0, c_dim, ow_block>(c, src, weight); + cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); // odd element load_helper( src, src_ptr_odd, 0); load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, ow_block>(c, src, weight); + cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); src_ptr += ld_src_iw; src_ptr_odd += ld_src_iw; @@ -371,7 +365,7 @@ struct KerNeonXXs2Nchw44FP32 { const int ld_src_iw = iw * oc_step; constexpr int c_dim = OCHelper::val; float32x4_t c[c_dim][ow_block]; - init_ocx_ow8(c, bias_ptr, ld_bias); + init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { const float* src_ptr = src_ptr_origin + ic_idx * ld_src_ic; @@ -385,33 +379,33 @@ struct KerNeonXXs2Nchw44FP32 { 0); load_helper<4, 0, oc_step, c_dim, Vld1q_f32>(weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src[0] = vld1q_f32(src_ptr + ow_block * simd_len); load_helper<4, 2 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, ow_block>(c, src, weight); + cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); src[1] = vld1q_f32(src_ptr + (ow_block + 1) * simd_len); load_helper<4, 4 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<2, 0, c_dim, ow_block>(c, src, weight); + cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); src[2] = vld1q_f32(src_ptr + (ow_block + 2) * simd_len); load_helper<4, 6 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<3, 0, c_dim, ow_block>(c, src, weight); + cal_helper<3, 0, c_dim, ow_block, remain_w>(c, src, weight); // odd element load_helper( src, src_ptr_odd, 0); load_helper<4, 1 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, ow_block>(c, src, weight); + cal_helper<0, 0, c_dim, ow_block, remain_w>(c, src, weight); src[0] = vld1q_f32(src_ptr_odd + ow_block * simd_len); load_helper<4, 3 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<1, 0, c_dim, ow_block>(c, src, weight); + cal_helper<1, 0, c_dim, ow_block, remain_w>(c, src, weight); src[1] = vld1q_f32(src_ptr_odd + (ow_block + 1) * simd_len); load_helper<4, 5 * ld_weight, oc_step, c_dim, Vld1q_f32>( weight, weight_ptr, ld_weight_oc); - cal_helper<2, 0, c_dim, ow_block>(c, src, weight); + cal_helper<2, 0, c_dim, ow_block, remain_w>(c, src, weight); src_ptr += ld_src_iw; src_ptr_odd += ld_src_iw; diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h index c8a6439f..321fe2a5 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h @@ -39,16 +39,18 @@ namespace { *\tparam T2 is type of src regs *\tparam T3 is type of weight regs */ -template +template struct ShiftCalHelper { static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight); }; -template -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { +template +struct ShiftCalHelper { + static MEGDNN_ALWAYS_INLINE void impl(T&, T2&, T3&) {} +}; + #define cb(step) \ c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][weight_idx], \ src[(step * stride + src_idx) / 4], \ @@ -57,29 +59,47 @@ struct ShiftCalHelper { src[(step * stride + src_idx) / 4], \ (step * stride + src_idx) % 4); - UNROLL_CALL_RAW(8, cb); -#undef cb - } -}; -template -struct ShiftCalHelper { - static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { -#define cb(step) \ +#define cb2(step) \ c[0][step] = vfmaq_laneq_f32(c[0][step], weight[0][weight_idx], \ src[(step * stride + src_idx) / 4], \ (step * stride + src_idx) % 4); - UNROLL_CALL_RAW(8, cb); +#define SHIFT_CAL_HELPER(ow_remain) \ + template \ + struct ShiftCalHelper { \ + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ + UNROLL_CALL_RAW(ow_remain, cb); \ + } \ + }; \ + template \ + struct ShiftCalHelper { \ + static MEGDNN_ALWAYS_INLINE void impl(T& c, T2& src, T3& weight) { \ + UNROLL_CALL_RAW(ow_remain, cb2); \ + } \ + }; + +SHIFT_CAL_HELPER(1) +SHIFT_CAL_HELPER(2) +SHIFT_CAL_HELPER(3) +SHIFT_CAL_HELPER(4) +SHIFT_CAL_HELPER(5) +SHIFT_CAL_HELPER(6) +SHIFT_CAL_HELPER(7) +SHIFT_CAL_HELPER(8) + +#undef SHIFT_CAL_HELPER #undef cb - } -}; +#undef cb2 -template +template MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) { - ShiftCalHelper::impl(c, src, - weight); + ShiftCalHelper::impl(c, src, weight); }; enum CpuTag { DEFAULT_CPU_TAG = 0, @@ -134,7 +154,7 @@ struct KerNeonXXs2NchwNchw44FP32::val; float32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { float32x4_t src[src_reg_size]; @@ -145,13 +165,13 @@ struct KerNeonXXs2NchwNchw44FP32( \ weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ - cal_helper<0, 0, c_dim, stride>(c, src, weight); \ - cal_helper<1, 1, c_dim, stride>(c, src, weight); \ - cal_helper<2, 2, c_dim, stride>(c, src, weight); \ - cal_helper<3, 3, c_dim, stride>(c, src, weight); \ - cal_helper<4, 4, c_dim, stride>(c, src, weight); \ - cal_helper<5, 5, c_dim, stride>(c, src, weight); \ - cal_helper<6, 6, c_dim, stride>(c, src, weight); + cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \ + cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \ + cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); \ + cal_helper<3, 3, c_dim, stride, remain_w>(c, src, weight); \ + cal_helper<4, 4, c_dim, stride, remain_w>(c, src, weight); \ + cal_helper<5, 5, c_dim, stride, remain_w>(c, src, weight); \ + cal_helper<6, 6, c_dim, stride, remain_w>(c, src, weight); UNROLL_CALL_RAW(7, KERNEL_CB) #undef KERNEL_CB @@ -185,7 +205,7 @@ struct KerNeonXXs2NchwNchw44FP32::val; float32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { float32x4_t src[src_reg_size]; @@ -196,11 +216,11 @@ struct KerNeonXXs2NchwNchw44FP32( \ weight, weight_ptr + step * ld_weight_fw, ld_weight_oc); \ - cal_helper<0, 0, c_dim, stride>(c, src, weight); \ - cal_helper<1, 1, c_dim, stride>(c, src, weight); \ - cal_helper<2, 2, c_dim, stride>(c, src, weight); \ - cal_helper<3, 3, c_dim, stride>(c, src, weight); \ - cal_helper<4, 4, c_dim, stride>(c, src, weight); + cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); \ + cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); \ + cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); \ + cal_helper<3, 3, c_dim, stride, remain_w>(c, src, weight); \ + cal_helper<4, 4, c_dim, stride, remain_w>(c, src, weight); UNROLL_CALL_RAW(5, KERNEL_CB) #undef KERNEL_CB @@ -233,7 +253,7 @@ struct KerNeonXXs2NchwNchw44FP32::val; float32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { float32x4_t src[src_reg_size]; @@ -243,27 +263,27 @@ struct KerNeonXXs2NchwNchw44FP32( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, stride>(c, src, weight); - cal_helper<1, 1, c_dim, stride>(c, src, weight); - cal_helper<2, 2, c_dim, stride>(c, src, weight); + cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); + cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); + cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); // row 1 load_helper( src, src_ptr + iw, 0); load_helper( weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); - cal_helper<0, 0, c_dim, stride>(c, src, weight); - cal_helper<1, 1, c_dim, stride>(c, src, weight); - cal_helper<2, 2, c_dim, stride>(c, src, weight); + cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); + cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); + cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); // row 2 load_helper( src, src_ptr + 2 * iw, 0); load_helper( weight, weight_ptr + 2 * ld_weight_fw, ld_weight_oc); - cal_helper<0, 0, c_dim, stride>(c, src, weight); - cal_helper<1, 1, c_dim, stride>(c, src, weight); - cal_helper<2, 2, c_dim, stride>(c, src, weight); + cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); + cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); + cal_helper<2, 2, c_dim, stride, remain_w>(c, src, weight); src_ptr += ld_src_ic; weight_ptr += ld_weight_ic; @@ -634,7 +654,7 @@ struct KerNeonXXs2NchwNchw44FP32::val; float32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { float32x4_t src[src_reg_size]; @@ -644,16 +664,16 @@ struct KerNeonXXs2NchwNchw44FP32( weight, weight_ptr, ld_weight_oc); - cal_helper<0, 0, c_dim, stride>(c, src, weight); - cal_helper<1, 1, c_dim, stride>(c, src, weight); + cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); + cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); // row 1 load_helper( src, src_ptr + iw, 0); load_helper( weight, weight_ptr + 1 * ld_weight_fw, ld_weight_oc); - cal_helper<0, 0, c_dim, stride>(c, src, weight); - cal_helper<1, 1, c_dim, stride>(c, src, weight); + cal_helper<0, 0, c_dim, stride, remain_w>(c, src, weight); + cal_helper<1, 1, c_dim, stride, remain_w>(c, src, weight); src_ptr += ld_src_ic; weight_ptr += ld_weight_ic; diff --git a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp b/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp index 75e73f91..68b698c4 100644 --- a/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/do_conv_stride1.cpp @@ -6,14 +6,15 @@ * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. */ #include #include "src/arm_common/conv_bias/fp32/do_conv_stride1.h" -#include "src/arm_common/simd_macro/neon_helper.h" #include "src/arm_common/conv_bias/postprocess_helper.h" +#include "src/arm_common/simd_macro/neon_helper.h" #include "midout.h" @@ -27,10 +28,9 @@ using namespace conv_stride1; using NCBKernSizeParam = fallback::ConvBiasImpl::NCBKernSizeParam; using NCBKernParam = fallback::ConvBiasImpl::NCBKernParam; - -void conv_stride1::do_conv_2x2_stride1(const float* src, const float* filter, float* dst, - size_t IH, size_t IW, size_t OH, size_t OW, - size_t IC) { +void conv_stride1::do_conv_2x2_stride1(const float* src, const float* filter, + float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - OW; //! unroll of 2 size_t ic = 0; @@ -143,9 +143,9 @@ void conv_stride1::do_conv_2x2_stride1(const float* src, const float* filter, fl } } -void conv_stride1::do_conv_3x3_stride1(const float* src, const float* filter, float* dst, - size_t IH, size_t IW, size_t OH, size_t OW, - size_t IC) { +void conv_stride1::do_conv_3x3_stride1(const float* src, const float* filter, + float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - OW; rep(ic, IC) { @@ -193,7 +193,7 @@ void conv_stride1::do_conv_3x3_stride1(const float* src, const float* filter, fl MEGDNN_SIMD_TYPE _r22 = MEGDNN_SIMD_EXT(_r20, _r20n, 2); MEGDNN_SIMD_TYPE _r30 = MEGDNN_SIMD_LOADU(r3); - MEGDNN_SIMD_TYPE _r30n = MEGDNN_SIMD_LOADU(r3 + 4); + MEGDNN_SIMD_TYPE _r30n = MEGDNN_SIMD_LOADU_2(r3 + 4); MEGDNN_SIMD_TYPE _r31 = MEGDNN_SIMD_EXT(_r30, _r30n, 1); MEGDNN_SIMD_TYPE _r32 = MEGDNN_SIMD_EXT(_r30, _r30n, 2); @@ -290,9 +290,9 @@ void conv_stride1::do_conv_3x3_stride1(const float* src, const float* filter, fl } } -void conv_stride1::do_conv_5x5_stride1(const float* src, const float* filter, float* dst, - size_t IH, size_t IW, size_t OH, size_t OW, - size_t IC) { +void conv_stride1::do_conv_5x5_stride1(const float* src, const float* filter, + float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - OW; rep(ic, IC) { @@ -530,9 +530,9 @@ void conv_stride1::do_conv_5x5_stride1(const float* src, const float* filter, fl } } -void conv_stride1::do_conv_7x7_stride1(const float* src, const float* filter, float* dst, - size_t IH, size_t IW, size_t OH, size_t OW, - size_t IC) { +void conv_stride1::do_conv_7x7_stride1(const float* src, const float* filter, + float* dst, size_t IH, size_t IW, + size_t OH, size_t OW, size_t IC) { const size_t tail_step = IW - OW; rep(ic, IC) { @@ -688,7 +688,7 @@ void conv_stride1::do_conv_7x7_stride1(const float* src, const float* filter, fl _sum = MEGDNN_SIMD_FMA_LANE(_sum, _r56, _k39404142, 2); MEGDNN_SIMD_TYPE _k42434445 = MEGDNN_SIMD_LOADU(k6); - MEGDNN_SIMD_TYPE _k46474849 = MEGDNN_SIMD_LOADU(k6 + 4); + MEGDNN_SIMD_TYPE _k46474849 = MEGDNN_SIMD_LOADU_3(k6 + 4); MEGDNN_SIMD_TYPE _r60 = MEGDNN_SIMD_LOADU(r6); MEGDNN_SIMD_TYPE _r64 = MEGDNN_SIMD_LOADU(r6 + 4); diff --git a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp index 687f817e..8ace555f 100644 --- a/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/fp32/f32_direct_nchw44_algo.cpp @@ -126,7 +126,8 @@ static void do_conv_kern(const WorkspaceBundle& bundle, ? oh_idx * oh_block * ow * pack_c : oc_idx; const float* bptr = - kern_param.bias(batch_id, group_id) + bias_offset; + kern_param.bias(batch_id, group_id, oc_idx, 1, pack_c) + + bias_offset; Op op; conv_bias::conv_direct_fp32_nchw44( diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp index 1c879d91..81f53770 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp @@ -69,7 +69,7 @@ struct KerNeonDotXXs2Nchw44Int8::val; int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, ld_bias); + init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { int8x16_t src[src_reg]; int8x16_t weight[c_dim][weight_reg]; @@ -117,7 +117,7 @@ struct KerNeonDotXXs2Nchw44Int8::val; int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, ld_bias); + init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { int8x16_t src[src_reg]; int8x16_t weight[c_dim][weight_reg]; @@ -171,7 +171,7 @@ struct KerNeonDotXXs2Nchw44Int8::val; int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, ld_bias); + init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { int8x16_t src[src_reg]; @@ -220,7 +220,7 @@ struct KerNeonDotXXs2Nchw44Int8::val; int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, ld_bias); + init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { int8x16_t src[src_reg]; diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp index 46e0177f..2ea726a1 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp @@ -80,7 +80,7 @@ struct KerNeonDotXXs2Nchw44Int8::val; int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, ld_bias); + init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { int8x16_t src[2][src_reg]; int8x16_t weight[c_dim][weight_reg]; @@ -131,7 +131,7 @@ struct KerNeonDotXXs2Nchw44Int8::val; int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, ld_bias); + init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { int8x16_t src[2][src_reg]; int8x16_t weight[c_dim][weight_reg]; @@ -189,7 +189,7 @@ struct KerNeonDotXXs2Nchw44Int8::val; int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, ld_bias); + init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { int8x16_t src[2][src_reg]; @@ -244,7 +244,7 @@ struct KerNeonDotXXs2Nchw44Int8::val; int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, ld_bias); + init_ocx_ow8(c, bias_ptr, ld_bias); for (int ic_idx = 0; ic_idx < ic; ic_idx += ic_step) { int8x16_t src[2][src_reg]; diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp index c7736149..8b82b932 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s1.cpp @@ -45,7 +45,7 @@ static void ker_neon_dirctconv_2x2s1_oc8_ow8(const int8_t* src_ptr, int8x16_t src[8 + 1]; int16x8_t temp_c[4]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { @@ -135,7 +135,7 @@ static void ker_neon_dirctconv_2x2s1_oc4_ow8(const int8_t* src_ptr, int8x16_t weight[1][2]; int8x16_t src[8 + 1]; int16x8_t temp_c[2]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { @@ -224,7 +224,7 @@ struct KerNeonDirectStride1Int8 { int8x16_t weight[3]; int8x16_t src[8 + 2]; int16x8_t temp_c[2]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { @@ -306,7 +306,7 @@ struct KerNeonDirectStride1Int8 { int8x16_t weight[5]; int8x16_t src[8 + 2]; int16x8_t temp_c[2]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { @@ -409,7 +409,7 @@ struct KerNeonDirectStride1Int8 { int8x16_t weight[7]; int8x16_t src[8 + 2]; int16x8_t temp_c[2]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { @@ -569,7 +569,7 @@ void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src, (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; const size_t dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_2x2s1_oc8_ow8( src + src_offset, filter + weight_offset, bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_oc, op); @@ -594,7 +594,7 @@ void conv_direct_stride1_2x2_int8_nchw44(const int8_t* src, (oh_idx * iw + ow_idx) * ic_step * pack_iw_len; const size_t dst_offset = oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step; - ker_neon_dirctconv_2x2s1_oc4_ow8( src + src_offset, filter + weight_offset, bias + oc_idx, dst + dst_offset, ic, ih, iw, ld_oc, op); diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp index c202512b..b6959dc0 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw44_s2.cpp @@ -54,7 +54,7 @@ static void ker_neon_dirctconv_2x2s2_oc8_ow8(const int8_t* src_ptr, int8x16_t src[8 + 1]; int16x8_t temp_c[4]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { @@ -151,7 +151,7 @@ static void ker_neon_dirctconv_2x2s2_oc4_ow8(const int8_t* src_ptr, int8x16_t weight[2]; int8x16_t src[8 + 1]; int16x8_t temp_c[2]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { @@ -239,7 +239,7 @@ struct KerNeonDirectStride2Int8 { int8x16_t weight[3]; int8x16_t src[8 + 2]; int16x8_t temp_c[4]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { @@ -327,7 +327,7 @@ struct KerNeonDirectStride2Int8 { int8x16_t weight[5]; int8x16_t src[8 + 2]; int16x8_t temp_c[4]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + @@ -435,7 +435,7 @@ struct KerNeonDirectStride2Int8 { int8x16_t weight[7]; int8x16_t src[8 + 2]; int16x8_t temp_c[4]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { for (int fh_idx = 0; fh_idx < fh; ++fh_idx) { const int8_t* src_ic_0_3 = src_ptr + ic_idx * ic_stride + diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp index 2722a94b..0dacdf83 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/int8_direct_nchw_nchw44_s1.cpp @@ -131,7 +131,7 @@ struct KerNeonXXs2NchwNchw44 { const int ld_weight_oc = oc_step * filter_height * filter_width * ic; constexpr int c_dim = OCHelper::val; int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; @@ -178,7 +178,7 @@ struct KerNeonXXs2NchwNchw44 { const int ld_weight_oc = oc_step * filter_height * filter_width * ic; constexpr int c_dim = OCHelper::val; int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; @@ -232,7 +232,7 @@ struct KerNeonXXs2NchwNchw44 { const int ld_weight_oc = oc_step * filter_height * filter_width * ic; constexpr int c_dim = OCHelper::val; int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; @@ -279,7 +279,7 @@ struct KerNeonXXs2NchwNchw44 { const int ld_weight_oc = oc_step * filter_height * filter_width * ic; constexpr int c_dim = OCHelper::val; int32x4_t c[c_dim][8]; - init_ocx_ow8(c, bias_ptr, oc_step); + init_ocx_ow8(c, bias_ptr, oc_step); for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) { const int8_t* nchw_src_ptr = src_ptr + ic_idx * ic_stride; diff --git a/dnn/src/arm_common/conv_bias/intrinsic_helper.h b/dnn/src/arm_common/conv_bias/intrinsic_helper.h index 772d3f8e..8d80e7f4 100644 --- a/dnn/src/arm_common/conv_bias/intrinsic_helper.h +++ b/dnn/src/arm_common/conv_bias/intrinsic_helper.h @@ -643,128 +643,95 @@ __ai int32x4_t neon_vld1q(const int* ptr) { __ai int16x8_t neon_vld1q(const int16_t* ptr) { return vld1q_s16(ptr); } - -template +template +struct NeonLdqSimd; +template <> +struct NeonLdqSimd { + static constexpr int simd_len = 4; +}; +template <> +struct NeonLdqSimd { + static constexpr int simd_len = 4; +}; +template <> +struct NeonLdqSimd { + static constexpr int simd_len = 8; +}; +template struct InitOcxOw8 { static __ai void impl(T& c, const T2* bias_ptr, int oc_step); }; -template -struct InitOcxOw8<2, BiasMode::NO_BIAS, 8, T, T2> { - static __ai void impl(T& c, const T2*, int) { -#define BAIS_INIT(step) \ - c[0][step] = neon_vdupq_n(static_cast(0)); \ - c[1][step] = neon_vdupq_n(static_cast(0)); - UNROLL_CALL_RAW(8, BAIS_INIT); -#undef BAIS_INIT - } +template +struct InitOcxOw8 { + static __ai void impl(T&, const T2*, int) {} }; -template -struct InitOcxOw8<2, BiasMode::NO_BIAS, 4, T, T2> { - static __ai void impl(T& c, const T2*, int) { -#define BAIS_INIT(step) \ + +#define BAIS_INIT_NO_BIAS_C2(step) \ c[0][step] = neon_vdupq_n(static_cast(0)); \ c[1][step] = neon_vdupq_n(static_cast(0)); - UNROLL_CALL_RAW(4, BAIS_INIT); -#undef BAIS_INIT - } -}; -template -struct InitOcxOw8<2, BiasMode::BROADCAST_CHANNEL_BIAS, 8, T, T2> { - static __ai void impl(T& c, const T2* bias_ptr, int oc_step) { -#define BAIS_INIT(step) \ - c[0][step] = neon_vld1q(bias_ptr); \ - c[1][step] = neon_vld1q(bias_ptr + oc_step); - UNROLL_CALL_RAW(8, BAIS_INIT); -#undef BAIS_INIT - } -}; -template -struct InitOcxOw8<2, BiasMode::BROADCAST_CHANNEL_BIAS, 4, T, T2> { - static __ai void impl(T& c, const T2* bias_ptr, int oc_step) { -#define BAIS_INIT(step) \ +#define BAIS_INIT_NO_BIAS_C1(step) \ + c[0][step] = neon_vdupq_n(static_cast(0)); + +#define BAIS_INIT_BROADCAST_C2(step) \ c[0][step] = neon_vld1q(bias_ptr); \ c[1][step] = neon_vld1q(bias_ptr + oc_step); - UNROLL_CALL_RAW(4, BAIS_INIT); -#undef BAIS_INIT - } -}; -template -struct InitOcxOw8<2, BiasMode::BIAS, 8, T, T2> { - static __ai void impl(T& c, const T2* bias_ptr, int oc_step) { - constexpr int simd_len = 4; -#define BAIS_INIT(step) \ - c[0][step] = neon_vld1q(bias_ptr + step * simd_len); \ - c[1][step] = neon_vld1q(bias_ptr + oc_step + step * simd_len); - UNROLL_CALL_RAW(8, BAIS_INIT); -#undef BAIS_INIT - } -}; -template -struct InitOcxOw8<2, BiasMode::BIAS, 4, T, T2> { - static __ai void impl(T& c, const T2* bias_ptr, int oc_step) { - constexpr int simd_len = 4; -#define BAIS_INIT(step) \ +#define BAIS_INIT_BROADCAST_C1(step) c[0][step] = neon_vld1q(bias_ptr); + +#define BAIS_INIT_BIAS_C2(step) \ c[0][step] = neon_vld1q(bias_ptr + step * simd_len); \ c[1][step] = neon_vld1q(bias_ptr + oc_step + step * simd_len); - UNROLL_CALL_RAW(4, BAIS_INIT); -#undef BAIS_INIT - } -}; - -template -struct InitOcxOw8<1, BiasMode::NO_BIAS, 8, T, T2> { - static __ai void impl(T& c, const T2*, int) { -#define BAIS_INIT(step) c[0][step] = neon_vdupq_n(static_cast(0)); - UNROLL_CALL_RAW(8, BAIS_INIT); -#undef BAIS_INIT - } -}; -template -struct InitOcxOw8<1, BiasMode::NO_BIAS, 4, T, T2> { - static __ai void impl(T& c, const T2*, int) { -#define BAIS_INIT(step) c[0][step] = neon_vdupq_n(static_cast(0)); - UNROLL_CALL_RAW(4, BAIS_INIT); -#undef BAIS_INIT - } -}; -template -struct InitOcxOw8<1, BiasMode::BROADCAST_CHANNEL_BIAS, 8, T, T2> { - static __ai void impl(T& c, const T2* bias_ptr, int) { -#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr); - UNROLL_CALL_RAW(8, BAIS_INIT); -#undef BAIS_INIT - } -}; -template -struct InitOcxOw8<1, BiasMode::BROADCAST_CHANNEL_BIAS, 4, T, T2> { - static __ai void impl(T& c, const T2* bias_ptr, int) { -#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr); - UNROLL_CALL_RAW(4, BAIS_INIT); -#undef BAIS_INIT - } -}; -template -struct InitOcxOw8<1, BiasMode::BIAS, 8, T, T2> { - static __ai void impl(T& c, const T2* bias_ptr, int) { - constexpr int simd_len = 4; -#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr + step * simd_len); - UNROLL_CALL_RAW(8, BAIS_INIT); -#undef BAIS_INIT - } -}; -template -struct InitOcxOw8<1, BiasMode::BIAS, 4, T, T2> { - static __ai void impl(T& c, const T2* bias_ptr, int) { - constexpr int simd_len = 4; -#define BAIS_INIT(step) c[0][step] = neon_vld1q(bias_ptr + step * simd_len); - UNROLL_CALL_RAW(4, BAIS_INIT); -#undef BAIS_INIT - } -}; -template +#define BAIS_INIT_BIAS_C1(step) \ + c[0][step] = neon_vld1q(bias_ptr + step * simd_len); + +#define INSTANCE_InitOcxOw8(ow_remain, cdim) \ + template \ + struct InitOcxOw8 { \ + static __ai void impl(T& c, const T2*, int) { \ + UNROLL_CALL_RAW(ow_remain, BAIS_INIT_NO_BIAS_C##cdim); \ + } \ + }; \ + template \ + struct InitOcxOw8 { \ + static __ai void impl(T& c, const T2* bias_ptr, int oc_step) { \ + (void)oc_step; \ + UNROLL_CALL_RAW(ow_remain, BAIS_INIT_BROADCAST_C##cdim); \ + } \ + }; \ + template \ + struct InitOcxOw8 { \ + static __ai void impl(T& c, const T2* bias_ptr, int oc_step) { \ + constexpr int simd_len = NeonLdqSimd::simd_len; \ + (void)oc_step; \ + UNROLL_CALL_RAW(ow_remain, BAIS_INIT_BIAS_C##cdim); \ + } \ + }; +#define INSTANCE_InitOcxOw8_C(ow_remain) \ + INSTANCE_InitOcxOw8(ow_remain, 2); \ + INSTANCE_InitOcxOw8(ow_remain, 1); + +INSTANCE_InitOcxOw8_C(1); +INSTANCE_InitOcxOw8_C(2); +INSTANCE_InitOcxOw8_C(3); +INSTANCE_InitOcxOw8_C(4); +INSTANCE_InitOcxOw8_C(5); +INSTANCE_InitOcxOw8_C(6); +INSTANCE_InitOcxOw8_C(7); +INSTANCE_InitOcxOw8_C(8); + +#undef INSTANCE_InitOcxOw8 +#undef INSTANCE_InitOcxOw8_C +#undef BAIS_INIT_BIAS_C1 +#undef BAIS_INIT_BIAS_C2 +#undef BAIS_INIT_BROADCAST_C1 +#undef BAIS_INIT_BROADCAST_C2 +#undef BAIS_INIT_NO_BIAS_C1 +#undef BAIS_INIT_NO_BIAS_C2 + +template __ai void init_ocx_ow8(T& c, const T2* bias_ptr, int oc_step) { - InitOcxOw8::impl(c, bias_ptr, oc_step); + InitOcxOw8::impl(c, bias_ptr, oc_step); } /////////////////////init_ocx_ow4///////////////////// template diff --git a/dnn/src/arm_common/simd_macro/neon_helper.h b/dnn/src/arm_common/simd_macro/neon_helper.h index 3f0cb92b..b77105b2 100644 --- a/dnn/src/arm_common/simd_macro/neon_helper.h +++ b/dnn/src/arm_common/simd_macro/neon_helper.h @@ -18,6 +18,8 @@ #define MEGDNN_SIMD_TYPE float32x4_t #define MEGDNN_SIMD_TYPE2 float32x4x2_t #define MEGDNN_SIMD_LOADU(addr) vld1q_f32(addr) +#define MEGDNN_SIMD_LOADU_2(addr) vcombine_f32(vld1_f32(addr), vdup_n_f32(0.f)) +#define MEGDNN_SIMD_LOADU_3(addr) vld1q_lane_f32(addr + 2, vcombine_f32(vld1_f32(addr), vdup_n_f32(0.f)), 2) #define MEGDNN_SIMD_STOREU(addr, reg) vst1q_f32(addr, reg) #define MEGDNN_SIMD_SETZERO() vdupq_n_f32(0.0f) #define MEGDNN_SIMD_SET1(num) vdupq_n_f32(num) diff --git a/dnn/test/common/checker.h b/dnn/test/common/checker.h index 3ede3b38..e6b355fb 100644 --- a/dnn/test/common/checker.h +++ b/dnn/test/common/checker.h @@ -23,6 +23,20 @@ #include #include +// clang-format off +#if defined(__has_feature) + #if __has_feature(address_sanitizer) + #define MEGDNN_TEST_ASAN 1 + #else + #define MEGDNN_TEST_ASAN 0 + #endif +#elif defined(__SANITIZE_ADDRESS__) + #define MEGDNN_TEST_ASAN 1 +#else + #define MEGDNN_TEST_ASAN 0 +#endif +// clang-format on + namespace megdnn { namespace test { diff --git a/dnn/test/fallback/warp_perspective.cpp b/dnn/test/fallback/warp_perspective.cpp index ce557e23..97951a57 100644 --- a/dnn/test/fallback/warp_perspective.cpp +++ b/dnn/test/fallback/warp_perspective.cpp @@ -76,6 +76,9 @@ TEST_F(FALLBACK, WARP_PERSPECTIVE) { checker.set_param(param); checker.exec({{1000, 2, 10, 11}, {1000, 3, 3}, {1000, 2, 12, 13}}); } +#if MEGDNN_TEST_ASAN +//! asan detect nan will make test failed +#else // resize nan case UniformFloatRNG rng_zero(0, 0); checker.set_rng(1, &rng_zero); @@ -85,6 +88,7 @@ TEST_F(FALLBACK, WARP_PERSPECTIVE) { checker.set_param(param); checker.exec({{1000, 2, 10, 11}, {1000, 3, 3}, {1000, 2, 12, 13}}); } +#endif } TEST_F(FALLBACK, WARP_PERSPECTIVE_MAT_IDX) { diff --git a/dnn/test/naive/warp_perspective.cpp b/dnn/test/naive/warp_perspective.cpp index 0d8f6ae8..2168853b 100644 --- a/dnn/test/naive/warp_perspective.cpp +++ b/dnn/test/naive/warp_perspective.cpp @@ -352,6 +352,9 @@ TEST_F(NAIVE_MULTI_THREADS, WARP_PERSPECTIVE_FORWARD_HWCD4) { checker.execs({{22, 10, 1, 11, 4}, {22, 3, 3}, {22, 11, 1, 12, 4}}); } } +#if MEGDNN_TEST_ASAN +//! asan detect nan will make test failed +#else // nan case NanMatRNG rng_nan; UniformFloatRNG rng_zero(0, 0); @@ -369,6 +372,7 @@ TEST_F(NAIVE_MULTI_THREADS, WARP_PERSPECTIVE_FORWARD_HWCD4) { checker.set_param(param); checker.exec({{10, 10, 1, 11, 4}, {10, 3, 3}, {10, 12, 1, 13, 4}}); } +#endif } #if MEGDNN_WITH_BENCHMARK