diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h
index 70b9f128..c8a6439f 100644
--- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h
+++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h
@@ -19,6 +19,13 @@
 #include "src/common/unroll_macro.h"
 #include "src/common/utils.h"
 #include "src/fallback/conv_bias/common.h"
+#if MEGDNN_ARMV7
+#include "src/armv7/matrix_mul/asm/common.h"
+#endif
+
+#if MGB_ENABLE_CPUINFO
+#include "cpuinfo.h"
+#endif
 
 using namespace megdnn;
 using namespace arm_common;
@@ -74,6 +81,10 @@ MEGDNN_ALWAYS_INLINE void cal_helper(T& c, T2& src, T3& weight) {
     ShiftCalHelper::impl(c, src, weight);
 };
+enum CpuTag {
+    DEFAULT_CPU_TAG = 0,
+    A7_TAG,
+};
 template <int oc>
 struct OCHelper {
 public:
@@ -95,7 +106,8 @@ public:
  * oc8_ow8(m = 8, n = 8) and oc4_ow8(m = 4, n = 8) gemm like kernel
  **/
 template <BiasMode bias_mode, typename Op, int remain_w, int filter_size,
-          int oc_block, int stride, int ow_block>
+          int oc_block, int stride, int ow_block,
+          int tag = CpuTag::DEFAULT_CPU_TAG>
 struct KerNeonXXs2NchwNchw44FP32 {
     static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
                      const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
@@ -261,6 +273,346 @@ struct KerNeonXXs2NchwNchw44FP32
+#if MEGDNN_ARMV7
+template <BiasMode bias_mode, typename Op>
+struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, 8, 3, 4, 2, 8,
+                                 CpuTag::A7_TAG> {
+    static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
+                     const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
+                     int ih, int iw, int ld_dst_oc, const Op& op) {
+        constexpr int oc_block = 4;
+        constexpr int stride = 2;
+        constexpr int remain_w = 8;
+        constexpr int ow_block = 8;
+        constexpr int loop_ic_step = 1;
+        constexpr int filter_size = 3;
+        constexpr int oc_step = 4;
+        constexpr int src_line_block = ow_block * stride + filter_size - stride;
+
+        const int iw_skip_bytes =
+                (iw - round_up(src_line_block, 2)) * sizeof(float);
+        const int ld_src_ic_skip_bytes =
+                iw * (ih - filter_size) * sizeof(float) + iw_skip_bytes;
+        constexpr int c_dim = OCHelper<oc_block>::val;
+        float32x4_t c[1][8];
+        init_ocx_ow8(c, bias_ptr, oc_step);
+        const int img_stride = ih * iw;
+        constexpr int filter_stride = filter_size * filter_size * oc_step;
+        megdnn::armv7::prefetch_2x(src_ptr);
+        megdnn::armv7::prefetch_2x(src_ptr + iw);
+        megdnn::armv7::prefetch_2x(src_ptr + 2 * iw);
+        megdnn::armv7::prefetch_2x(weight_ptr);
+
+        /**
+         * c        q8-q15
+         * src      q0-q4
+         * weight   q5-q7
+         * optimized for A7
+         */
+        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
+            megdnn::armv7::prefetch_2x(src_ptr + img_stride);
+            megdnn::armv7::prefetch_2x(src_ptr + img_stride + iw);
+            megdnn::armv7::prefetch_2x(src_ptr + img_stride + 2 * iw);
+            megdnn::armv7::prefetch_2x(weight_ptr + filter_stride);
+            asm volatile(
+                    "2:\n"
+                    //! row 0
+                    "vld1.32 {d10, d11}, [%[weight_ptr]]!\n"
+                    "vld1.32 {d0, d1}, [%[src_ptr]]!\n"
+                    "vld1.32 {d2, d3}, [%[src_ptr]]!\n"
+                    "vld1.32 {d4, d5}, [%[src_ptr]]!\n"
+                    "vld1.32 {d6, d7}, [%[src_ptr]]!\n"
+                    "vld1.32 {d8}, [%[src_ptr]]!\n"
+                    "vld1.32 {d12, d13}, [%[weight_ptr]]!\n"
+                    "vld1.32 {d14, d15}, [%[weight_ptr]]!\n"
+                    "add %[src_ptr], %[src_ptr], %[iw_skip_bytes]\n"
+
+                    "vmla.f32 %q[c0], q5, d0[0]\n"
+                    "vmla.f32 %q[c1], q5, d1[0]\n"
+                    "vmla.f32 %q[c2], q5, d2[0]\n"
+                    "vmla.f32 %q[c3], q5, d3[0]\n"
+                    "vmla.f32 %q[c4], q5, d4[0]\n"
+                    "vmla.f32 %q[c5], q5, d5[0]\n"
+                    "vmla.f32 %q[c6], q5, d6[0]\n"
+                    "vmla.f32 %q[c7], q5, d7[0]\n"
+
+                    "vmla.f32 %q[c0], q6, d0[1]\n"
+                    "vmla.f32 %q[c1], q6, d1[1]\n"
+                    "vmla.f32 %q[c2], q6, d2[1]\n"
+                    "vmla.f32 %q[c3], q6, d3[1]\n"
+                    "vmla.f32 %q[c4], q6, d4[1]\n"
+                    "vmla.f32 %q[c5], q6, d5[1]\n"
+                    "vmla.f32 %q[c6], q6, d6[1]\n"
+                    "vmla.f32 %q[c7], q6, d7[1]\n"
+
+                    "vmla.f32 %q[c0], q7, d1[0]\n"
+                    "vmla.f32 %q[c1], q7, d2[0]\n"
+                    "vmla.f32 %q[c2], q7, d3[0]\n"
+                    "vmla.f32 %q[c3], q7, d4[0]\n"
+                    "vmla.f32 %q[c4], q7, d5[0]\n"
+                    "vmla.f32 %q[c5], q7, d6[0]\n"
+                    "vmla.f32 %q[c6], q7, d7[0]\n"
+                    "vmla.f32 %q[c7], q7, d8[0]\n"
+
+                    //! row 1
+                    "vld1.32 {d10, d11}, [%[weight_ptr]]!\n"
+                    "vld1.32 {d0, d1}, [%[src_ptr]]!\n"
+                    "vld1.32 {d2, d3}, [%[src_ptr]]!\n"
+                    "vld1.32 {d4, d5}, [%[src_ptr]]!\n"
+                    "vld1.32 {d6, d7}, [%[src_ptr]]!\n"
+                    "vld1.32 {d8}, [%[src_ptr]]!\n"
+                    "vld1.32 {d12, d13}, [%[weight_ptr]]!\n"
+                    "vld1.32 {d14, d15}, [%[weight_ptr]]!\n"
+                    "add %[src_ptr], %[src_ptr], %[iw_skip_bytes]\n"
+
+                    "vmla.f32 %q[c0], q5, d0[0]\n"
+                    "vmla.f32 %q[c1], q5, d1[0]\n"
+                    "vmla.f32 %q[c2], q5, d2[0]\n"
+                    "vmla.f32 %q[c3], q5, d3[0]\n"
+                    "vmla.f32 %q[c4], q5, d4[0]\n"
+                    "vmla.f32 %q[c5], q5, d5[0]\n"
+                    "vmla.f32 %q[c6], q5, d6[0]\n"
+                    "vmla.f32 %q[c7], q5, d7[0]\n"
+
+                    "vmla.f32 %q[c0], q6, d0[1]\n"
+                    "vmla.f32 %q[c1], q6, d1[1]\n"
+                    "vmla.f32 %q[c2], q6, d2[1]\n"
+                    "vmla.f32 %q[c3], q6, d3[1]\n"
+                    "vmla.f32 %q[c4], q6, d4[1]\n"
+                    "vmla.f32 %q[c5], q6, d5[1]\n"
+                    "vmla.f32 %q[c6], q6, d6[1]\n"
+                    "vmla.f32 %q[c7], q6, d7[1]\n"
+
+                    "vmla.f32 %q[c0], q7, d1[0]\n"
+                    "vmla.f32 %q[c1], q7, d2[0]\n"
+                    "vmla.f32 %q[c2], q7, d3[0]\n"
+                    "vmla.f32 %q[c3], q7, d4[0]\n"
+                    "vmla.f32 %q[c4], q7, d5[0]\n"
+                    "vmla.f32 %q[c5], q7, d6[0]\n"
+                    "vmla.f32 %q[c6], q7, d7[0]\n"
+                    "vmla.f32 %q[c7], q7, d8[0]\n"
+
+                    //! row 2
+                    "vld1.32 {d10, d11}, [%[weight_ptr]]!\n"
+                    "vld1.32 {d0, d1}, [%[src_ptr]]!\n"
+                    "vld1.32 {d2, d3}, [%[src_ptr]]!\n"
+                    "vld1.32 {d4, d5}, [%[src_ptr]]!\n"
+                    "vld1.32 {d6, d7}, [%[src_ptr]]!\n"
+                    "vld1.32 {d8}, [%[src_ptr]]!\n"
+                    "vld1.32 {d12, d13}, [%[weight_ptr]]!\n"
+                    "vld1.32 {d14, d15}, [%[weight_ptr]]!\n"
+                    "add %[src_ptr], %[src_ptr], %[ld_src_ic_skip_bytes]\n"
+
+                    "vmla.f32 %q[c0], q5, d0[0]\n"
+                    "vmla.f32 %q[c1], q5, d1[0]\n"
+                    "vmla.f32 %q[c2], q5, d2[0]\n"
+                    "vmla.f32 %q[c3], q5, d3[0]\n"
+                    "vmla.f32 %q[c4], q5, d4[0]\n"
+                    "vmla.f32 %q[c5], q5, d5[0]\n"
+                    "vmla.f32 %q[c6], q5, d6[0]\n"
+                    "vmla.f32 %q[c7], q5, d7[0]\n"
+
+                    "vmla.f32 %q[c0], q6, d0[1]\n"
+                    "vmla.f32 %q[c1], q6, d1[1]\n"
+                    "vmla.f32 %q[c2], q6, d2[1]\n"
+                    "vmla.f32 %q[c3], q6, d3[1]\n"
+                    "vmla.f32 %q[c4], q6, d4[1]\n"
+                    "vmla.f32 %q[c5], q6, d5[1]\n"
+                    "vmla.f32 %q[c6], q6, d6[1]\n"
+                    "vmla.f32 %q[c7], q6, d7[1]\n"
+
+                    "vmla.f32 %q[c0], q7, d1[0]\n"
+                    "vmla.f32 %q[c1], q7, d2[0]\n"
+                    "vmla.f32 %q[c2], q7, d3[0]\n"
+                    "vmla.f32 %q[c3], q7, d4[0]\n"
+                    "vmla.f32 %q[c4], q7, d5[0]\n"
+                    "vmla.f32 %q[c5], q7, d6[0]\n"
+                    "vmla.f32 %q[c6], q7, d7[0]\n"
+                    "vmla.f32 %q[c7], q7, d8[0]\n"
+
+                    "6:\n"
+                    : [c0] "+w"(c[0][0]), [c1] "+w"(c[0][1]),
+                      [c2] "+w"(c[0][2]), [c3] "+w"(c[0][3]),
+                      [c4] "+w"(c[0][4]), [c5] "+w"(c[0][5]),
+                      [c6] "+w"(c[0][6]), [c7] "+w"(c[0][7]),
+                      [src_ptr] "+r"(src_ptr), [weight_ptr] "+r"(weight_ptr)
+                    : [ld_src_ic_skip_bytes] "r"(ld_src_ic_skip_bytes),
+                      [iw_skip_bytes] "r"(iw_skip_bytes)
+                    : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8",
+                      "d9", "d10", "d11", "d12", "d13", "d14", "d15", "r1",
+                      "r2", "cc", "memory");
+        }
+        store_ocx_ow8_remain_static(c, op, dst_ptr, ld_dst_oc);
+    }
+};
+
+template <BiasMode bias_mode, typename Op>
+struct KerNeonXXs2NchwNchw44FP32<bias_mode, Op, 8, 3, 4, 2, 8,
+                                 CpuTag::DEFAULT_CPU_TAG> {
+    static void impl(const float32_t* src_ptr, const float32_t* weight_ptr,
+                     const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
+                     int ih, int iw, int ld_dst_oc, const Op& op) {
+        constexpr int oc_block = 4;
+        constexpr int stride = 2;
+        constexpr int remain_w = 8;
+        constexpr int ow_block = 8;
+        constexpr int loop_ic_step = 1;
+        constexpr int filter_size = 3;
+        constexpr int oc_step = 4;
+        constexpr int src_line_block = ow_block * stride + filter_size - stride;
+
+        const int iw_skip_bytes =
+                (iw - round_up(src_line_block, 2)) * sizeof(float);
+        const int ld_src_ic_skip_bytes =
+                iw * (ih - filter_size) * sizeof(float) + iw_skip_bytes;
+        constexpr int c_dim = OCHelper<oc_block>::val;
+        float32x4_t c[1][8];
+        init_ocx_ow8(c, bias_ptr, oc_step);
+        /**
+         * c        q8-q15
+         * src      q0-q4
+         * weight   q5-q7
+         * optimized for big core
+         */
+        for (int ic_idx = 0; ic_idx < ic; ic_idx += loop_ic_step) {
+            asm volatile(
+                    "2:\n"
+                    //! row 0
+                    "vld1.32 {d10, d11}, [%[weight_ptr]]!\n"
+                    "vld1.32 {d0, d1}, [%[src_ptr]]!\n"
+                    "vld1.32 {d2, d3}, [%[src_ptr]]!\n"
+                    "vld1.32 {d4, d5}, [%[src_ptr]]!\n"
+                    "vmla.f32 %q[c0], q5, d0[0]\n"
+                    "vld1.32 {d6, d7}, [%[src_ptr]]!\n"
+                    "vmla.f32 %q[c1], q5, d1[0]\n"
+                    "vmla.f32 %q[c2], q5, d2[0]\n"
+                    "vld1.32 {d12, d13}, [%[weight_ptr]]!\n"
+                    "vmla.f32 %q[c3], q5, d3[0]\n"
+                    "vld1.32 {d14, d15}, [%[weight_ptr]]!\n"
+                    "vmla.f32 %q[c4], q5, d4[0]\n"
+                    "vld1.32 {d8}, [%[src_ptr]]!\n"
+                    "vmla.f32 %q[c5], q5, d5[0]\n"
+                    "add %[src_ptr], %[src_ptr], %[iw_skip_bytes]\n"
+                    "vmla.f32 %q[c6], q5, d6[0]\n"
+                    "vmla.f32 %q[c7], q5, d7[0]\n"
+                    "vld1.32 {d10, d11}, [%[weight_ptr]]!\n"
+
+                    "vmla.f32 %q[c0], q6, d0[1]\n"
+                    "vmla.f32 %q[c1], q6, d1[1]\n"
+                    "vmla.f32 %q[c2], q6, d2[1]\n"
+                    "vmla.f32 %q[c3], q6, d3[1]\n"
+                    "vmla.f32 %q[c4], q6, d4[1]\n"
+                    "vmla.f32 %q[c5], q6, d5[1]\n"
+                    "vmla.f32 %q[c6], q6, d6[1]\n"
+                    "vmla.f32 %q[c7], q6, d7[1]\n"
+                    "vld1.32 {d12, d13}, [%[weight_ptr]]!\n"
+
+                    "vmla.f32 %q[c0], q7, d1[0]\n"
+                    "vld1.32 {d0, d1}, [%[src_ptr]]!\n"
+                    "vmla.f32 %q[c1], q7, d2[0]\n"
+                    "vmla.f32 %q[c2], q7, d3[0]\n"
+                    "vld1.32 {d2, d3}, [%[src_ptr]]!\n"
+                    "vmla.f32 %q[c3], q7, d4[0]\n"
+                    "vmla.f32 %q[c4], q7, d5[0]\n"
+                    "vld1.32 {d4, d5}, [%[src_ptr]]!\n"
+                    "vmla.f32 %q[c5], q7, d6[0]\n"
+                    "vmla.f32 %q[c6], q7, d7[0]\n"
+                    "vld1.32 {d6, d7}, [%[src_ptr]]!\n"
+                    "vmla.f32 %q[c7], q7, d8[0]\n"
+                    "vld1.32 {d14, d15}, [%[weight_ptr]]!\n"
+
+                    //! row 1
+                    "vmla.f32 %q[c0], q5, d0[0]\n"
+                    "vld1.32 {d8}, [%[src_ptr]]!\n"
+                    "vmla.f32 %q[c1], q5, d1[0]\n"
+                    "add %[src_ptr], %[src_ptr], %[iw_skip_bytes]\n"
+                    "vmla.f32 %q[c2], q5, d2[0]\n"
+                    "vmla.f32 %q[c3], q5, d3[0]\n"
+                    "vmla.f32 %q[c4], q5, d4[0]\n"
+                    "vmla.f32 %q[c5], q5, d5[0]\n"
+                    "vmla.f32 %q[c6], q5, d6[0]\n"
+                    "vmla.f32 %q[c7], q5, d7[0]\n"
+                    "vld1.32 {d10, d11}, [%[weight_ptr]]!\n"
+
+                    "vmla.f32 %q[c0], q6, d0[1]\n"
+                    "vmla.f32 %q[c1], q6, d1[1]\n"
+                    "vmla.f32 %q[c2], q6, d2[1]\n"
+                    "vmla.f32 %q[c3], q6, d3[1]\n"
+                    "vmla.f32 %q[c4], q6, d4[1]\n"
+                    "vmla.f32 %q[c5], q6, d5[1]\n"
+                    "vmla.f32 %q[c6], q6, d6[1]\n"
+                    "vmla.f32 %q[c7], q6, d7[1]\n"
+                    "vld1.32 {d12, d13}, [%[weight_ptr]]!\n"
+
+                    "vmla.f32 %q[c0], q7, d1[0]\n"
+                    "vld1.32 {d0, d1}, [%[src_ptr]]!\n"
+                    "vmla.f32 %q[c1], q7, d2[0]\n"
+                    "vmla.f32 %q[c2], q7, d3[0]\n"
+                    "vld1.32 {d2, d3}, [%[src_ptr]]!\n"
+                    "vmla.f32 %q[c3], q7, d4[0]\n"
+                    "vmla.f32 %q[c4], q7, d5[0]\n"
+                    "vld1.32 {d4, d5}, [%[src_ptr]]!\n"
+                    "vmla.f32 %q[c5], q7, d6[0]\n"
+                    "vmla.f32 %q[c6], q7, d7[0]\n"
+                    "vld1.32 {d6, d7}, [%[src_ptr]]!\n"
+                    "vmla.f32 %q[c7], q7, d8[0]\n"
+                    "vld1.32 {d14, d15}, [%[weight_ptr]]!\n"
+
+                    //! row 2
+                    "vmla.f32 %q[c0], q5, d0[0]\n"
+                    "vld1.32 {d8}, [%[src_ptr]]!\n"
+                    "vmla.f32 %q[c1], q5, d1[0]\n"
+                    "add %[src_ptr], %[src_ptr], %[ld_src_ic_skip_bytes]\n"
+                    "vmla.f32 %q[c2], q5, d2[0]\n"
+                    "vmla.f32 %q[c3], q5, d3[0]\n"
+                    "vmla.f32 %q[c4], q5, d4[0]\n"
+                    "vmla.f32 %q[c5], q5, d5[0]\n"
+                    "vmla.f32 %q[c6], q5, d6[0]\n"
+                    "vmla.f32 %q[c7], q5, d7[0]\n"
+
+                    "vmla.f32 %q[c0], q6, d0[1]\n"
+                    "vmla.f32 %q[c1], q6, d1[1]\n"
+                    "vmla.f32 %q[c2], q6, d2[1]\n"
+                    "vmla.f32 %q[c3], q6, d3[1]\n"
+                    "vmla.f32 %q[c4], q6, d4[1]\n"
+                    "vmla.f32 %q[c5], q6, d5[1]\n"
+                    "vmla.f32 %q[c6], q6, d6[1]\n"
+                    "vmla.f32 %q[c7], q6, d7[1]\n"
+
+                    "vmla.f32 %q[c0], q7, d1[0]\n"
+                    "vmla.f32 %q[c1], q7, d2[0]\n"
+                    "vmla.f32 %q[c2], q7, d3[0]\n"
+                    "vmla.f32 %q[c3], q7, d4[0]\n"
+                    "vmla.f32 %q[c4], q7, d5[0]\n"
+                    "vmla.f32 %q[c5], q7, d6[0]\n"
+                    "vmla.f32 %q[c6], q7, d7[0]\n"
+                    "vmla.f32 %q[c7], q7, d8[0]\n"
+
+                    "6:\n"
+                    : [c0] "+w"(c[0][0]), [c1] "+w"(c[0][1]),
+                      [c2] "+w"(c[0][2]), [c3] "+w"(c[0][3]),
+                      [c4] "+w"(c[0][4]), [c5] "+w"(c[0][5]),
+                      [c6] "+w"(c[0][6]), [c7] "+w"(c[0][7]),
+                      [src_ptr] "+r"(src_ptr), [weight_ptr] "+r"(weight_ptr)
+                    : [ld_src_ic_skip_bytes] "r"(ld_src_ic_skip_bytes),
+                      [iw_skip_bytes] "r"(iw_skip_bytes)
+                    : "d0", "d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8",
+                      "d9", "d10", "d11", "d12", "d13", "d14", "d15", "r1",
+                      "r2", "cc", "memory");
+        }
+        store_ocx_ow8_remain_static(c, op, dst_ptr, ld_dst_oc);
+    }
+};
+
+#endif
 template
 struct KerNeonXXs2NchwNchw44FP32
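For readers less fluent in ARMv7 NEON mnemonics, the inner product both kernels compute per filter row is easier to see in intrinsics. The sketch below is not part of the patch; it is an illustrative rendering (the helper name `row3_oc4_ow8` is ours) of one row of the 3x3 stride-2 oc4_ow8 kernel, where each `vmla.f32 %q[cN], qW, dS[x]` above is a multiply-accumulate of a 4-channel weight vector by one scalar source element:

```cpp
#include <arm_neon.h>

// One filter row of the oc4_ow8 3x3 stride-2 kernel: w0/w1/w2 hold one
// filter tap each for 4 output channels (NCHW44 weight layout); s points
// at 18 consecutive source floats, i.e. round_up(8 * 2 + 3 - 2, 2).
static inline void row3_oc4_ow8(float32x4_t c[8], const float* s,
                                float32x4_t w0, float32x4_t w1,
                                float32x4_t w2) {
    for (int j = 0; j < 8; ++j) {
        // output column j reads source elements 2j, 2j+1, 2j+2 (stride 2)
        c[j] = vmlaq_n_f32(c[j], w0, s[2 * j + 0]);
        c[j] = vmlaq_n_f32(c[j], w1, s[2 * j + 1]);
        c[j] = vmlaq_n_f32(c[j], w2, s[2 * j + 2]);
    }
}
```

The two asm variants compute exactly this loop body; they differ only in scheduling. The A7 kernel groups all loads ahead of the multiply-accumulates and leans on explicit `prefetch_2x` calls, which suits an in-order pipeline, while the default kernel interleaves loads with `vmla` so an out-of-order big core can overlap them.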
-void fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44(
-        const float32_t* src, const float32_t* filter, const float32_t* bias,
-        float32_t*, float32_t* dst, const int oc, const int ic, const int ih,
-        const int iw, const int oh, const int oh_block, const int ow,
-        const Op& op, const int, const int) {
-    constexpr int fh = filter_size;
-    constexpr int fw = filter_size;
-    constexpr int ic_step = 1;
-    constexpr int big_oc_step = 8;
-    constexpr int oc_step = 4;
-    constexpr int ih_step = 1;
-    constexpr int oh_step = 1;
-    constexpr int ow_step = 8;
-    constexpr int stride_h = stride;
-    constexpr int stride_w = stride;
-    constexpr int pack_iw_len = 1;
-
-    const int img_stride = oh * ow;
-    const int ow_end = ow / ow_step * ow_step;
-    const int ow_remain = ow - ow_end;
-    const int oc_end = oc / big_oc_step * big_oc_step;
-    const int oc_remain = oc - oc_end;
-    const int ld_dst_oc = oc_step * img_stride;
-
-    using remain_fun = std::function<void(
-            const float32_t* src_ptr, const float32_t* weight_ptr,
-            const float32_t* bias_ptr, float32_t* dst_ptr, int ic, int ih,
-            int iw, int ld_dst_oc, const Op& op)>;
-    remain_fun kern_big_oc_remain = nullptr;
-    remain_fun kern_small_oc_remain = nullptr;
-
-    switch (ow_remain) {
+namespace {
+template <BiasMode bias_mode, typename Op, int filter_size, int stride>
+struct ConvDirectFp32NchwNchw44 {
+    static MEGDNN_ALWAYS_INLINE void impl(
+            const float32_t* src, const float32_t* filter,
+            const float32_t* bias, float32_t*, float32_t* dst, const int oc,
+            const int ic, const int ih, const int iw, const int oh,
+            const int oh_block, const int ow, const Op& op) {
+        constexpr int fh = filter_size;
+        constexpr int fw = filter_size;
+        constexpr int ic_step = 1;
+#if MEGDNN_ARMV7
+        constexpr int big_oc_step = 4;
+#else
+        constexpr int big_oc_step = 8;
+#endif
+        constexpr int oc_step = 4;
+        constexpr int ih_step = 1;
+        constexpr int oh_step = 1;
+        constexpr int ow_step = 8;
+        constexpr int stride_h = stride;
+        constexpr int stride_w = stride;
+        constexpr int pack_iw_len = 1;
+
+        const int img_stride = oh * ow;
+        const int ow_end = ow / ow_step * ow_step;
+        const int ow_remain = ow - ow_end;
+        const int oc_end = oc / big_oc_step * big_oc_step;
+        const int oc_remain = oc - oc_end;
+        const int ld_dst_oc = oc_step * img_stride;
+
+        using remain_fun = std::function<void(
+                const float32_t* src_ptr, const float32_t* weight_ptr,
+                const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
+                int ih, int iw, int ld_dst_oc, const Op& op)>;
+        remain_fun kern_big_oc_remain = nullptr;
+        remain_fun kern_small_oc_remain = nullptr;
+
+        switch (ow_remain) {
 #define cb(step)                                                            \
     case step:                                                              \
         kern_big_oc_remain =                                                \
@@ -356,69 +711,209 @@ void fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44(
                 oc_step, stride, ow_step>::impl;                            \
         break;
 
-        UNROLL_CALL_RAW(8, cb);
-        default:
-            megdnn_assert(0, "no remain %d for kern", ow_remain);
-    }
-    for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
-        const int weight_offset = oc_idx * ic * fh * fw;
-        for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
-            for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
-                const int src_offset =
-                        (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
-                        ic_step * pack_iw_len;
-                const int dst_offset =
-                        oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
-                KerNeonXXs2NchwNchw44FP32<bias_mode, Op, ow_step, filter_size,
-                                          big_oc_step, stride,
-                                          ow_step>::impl(src + src_offset,
-                                                         filter + weight_offset,
-                                                         bias + oc_idx,
-                                                         dst + dst_offset, ic,
-                                                         ih, iw, ld_dst_oc, op);
-            }
-            if (ow_remain > 0) {
-                const int src_offset =
-                        (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
-                        ic_step * pack_iw_len;
-                const int dst_offset =
-                        oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
-                kern_big_oc_remain(src + src_offset, filter + weight_offset,
-                                   bias + oc_idx, dst + dst_offset, ic, ih, iw,
-                                   ld_dst_oc, op);
-            }
-        }
-    }
-    if (oc_remain > 0) {
-        int oc_idx = oc_end;
-        const int weight_offset = oc_idx * ic * fh * fw;
-        for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
-            for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
-                const int src_offset =
-                        (oh_idx * stride_h * iw + ow_idx * stride_w * ih_step) *
-                        ic_step * pack_iw_len;
-                const int dst_offset =
-                        oc_idx * img_stride + (oh_idx * ow + ow_idx) * oc_step;
-                KerNeonXXs2NchwNchw44FP32<bias_mode, Op, ow_step, filter_size,
-                                          oc_step, stride,
-                                          ow_step>::impl(src + src_offset,
-                                                         filter + weight_offset,
-                                                         bias + oc_idx,
-                                                         dst + dst_offset, ic,
-                                                         ih, iw, ld_dst_oc, op);
-            }
-            if (ow_remain > 0) {
-                const int src_offset =
-                        (oh_idx * stride_h * iw + ow_end * stride_w * ih_step) *
-                        ic_step * pack_iw_len;
-                const int dst_offset =
-                        oc_idx * img_stride + (oh_idx * ow + ow_end) * oc_step;
-                kern_small_oc_remain(src + src_offset, filter + weight_offset,
-                                     bias + oc_idx, dst + dst_offset, ic, ih,
-                                     iw, ld_dst_oc, op);
-            }
-        }
-    }
+            UNROLL_CALL_RAW(8, cb);
+            default:
+                megdnn_assert(0, "no remain %d for kern", ow_remain);
+        }
+#undef cb
+        for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
+            const int weight_offset = oc_idx * ic * fh * fw;
+            for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
+                for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
+                    const int src_offset = (oh_idx * stride_h * iw +
+                                            ow_idx * stride_w * ih_step) *
+                                           ic_step * pack_iw_len;
+                    const int dst_offset = oc_idx * img_stride +
+                                           (oh_idx * ow + ow_idx) * oc_step;
+                    KerNeonXXs2NchwNchw44FP32<
+                            bias_mode, Op, ow_step, filter_size, big_oc_step,
+                            stride, ow_step>::impl(src + src_offset,
+                                                   filter + weight_offset,
+                                                   bias + oc_idx,
+                                                   dst + dst_offset, ic, ih,
+                                                   iw, ld_dst_oc, op);
+                }
+                if (ow_remain > 0) {
+                    const int src_offset = (oh_idx * stride_h * iw +
+                                            ow_end * stride_w * ih_step) *
+                                           ic_step * pack_iw_len;
+                    const int dst_offset = oc_idx * img_stride +
+                                           (oh_idx * ow + ow_end) * oc_step;
+                    kern_big_oc_remain(src + src_offset, filter + weight_offset,
+                                       bias + oc_idx, dst + dst_offset, ic, ih,
+                                       iw, ld_dst_oc, op);
+                }
+            }
+        }
+        if (oc_remain > 0) {
+            int oc_idx = oc_end;
+            const int weight_offset = oc_idx * ic * fh * fw;
+            for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
+                for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
+                    const int src_offset = (oh_idx * stride_h * iw +
+                                            ow_idx * stride_w * ih_step) *
+                                           ic_step * pack_iw_len;
+                    const int dst_offset = oc_idx * img_stride +
+                                           (oh_idx * ow + ow_idx) * oc_step;
+                    KerNeonXXs2NchwNchw44FP32<
+                            bias_mode, Op, ow_step, filter_size, oc_step,
+                            stride, ow_step>::impl(src + src_offset,
+                                                   filter + weight_offset,
+                                                   bias + oc_idx,
+                                                   dst + dst_offset, ic, ih,
+                                                   iw, ld_dst_oc, op);
+                }
+                if (ow_remain > 0) {
+                    const int src_offset = (oh_idx * stride_h * iw +
+                                            ow_end * stride_w * ih_step) *
+                                           ic_step * pack_iw_len;
+                    const int dst_offset = oc_idx * img_stride +
+                                           (oh_idx * ow + ow_end) * oc_step;
+                    kern_small_oc_remain(src + src_offset,
+                                         filter + weight_offset, bias + oc_idx,
+                                         dst + dst_offset, ic, ih, iw,
+                                         ld_dst_oc, op);
+                }
+            }
+        }
+    }
+};
+
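To make the blocking above concrete: the main loops walk whole `ow_step = 8` column blocks and whole `big_oc_step` channel blocks, and whatever is left over goes through the `std::function` remainder kernels selected by the `switch`. A minimal standalone sketch of the same arithmetic (sizes chosen arbitrarily; only the names from the patch are megdnn's):

```cpp
#include <cassert>
#include <cstdio>

// Mirrors the tiling arithmetic in ConvDirectFp32NchwNchw44::impl above.
int main() {
    constexpr int ow_step = 8;      // output columns per kernel call
#if defined(__arm__)
    constexpr int big_oc_step = 4;  // armv7 in this patch: oc4 kernels only
#else
    constexpr int big_oc_step = 8;  // aarch64: oc8 kernel for the main loop
#endif
    const int ow = 28, oc = 12;                  // example sizes
    const int ow_end = ow / ow_step * ow_step;   // 24
    const int ow_remain = ow - ow_end;           // 4 -> remain_w = 4 kernel
    const int oc_end = oc / big_oc_step * big_oc_step;
    const int oc_remain = oc - oc_end;           // handled by the oc4 kernel
    assert(oc % 4 == 0);  // NCHW44 output channels come in groups of 4
    printf("ow: %d + %d, oc: %d + %d\n", ow_end, ow_remain, oc_end, oc_remain);
}
```

The `switch (ow_remain)` with `UNROLL_CALL_RAW(8, cb)` simply instantiates one remainder kernel per possible `ow_remain` value and binds the matching one once, outside the hot loops.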
+#if MEGDNN_ARMV7
+template <BiasMode bias_mode, typename Op>
+struct ConvDirectFp32NchwNchw44<bias_mode, Op, 3, 2> {
+    static MEGDNN_ALWAYS_INLINE void impl(
+            const float32_t* src, const float32_t* filter,
+            const float32_t* bias, float32_t*, float32_t* dst, const int oc,
+            const int ic, const int ih, const int iw, const int oh,
+            const int oh_block, const int ow, const Op& op) {
+        constexpr int filter_size = 3;
+        constexpr int stride = 2;
+        constexpr int fh = filter_size;
+        constexpr int fw = filter_size;
+        constexpr int ic_step = 1;
+        constexpr int oc_step = 4;
+        constexpr int big_oc_step = oc_step;
+        constexpr int ih_step = 1;
+        constexpr int oh_step = 1;
+        constexpr int ow_step = 8;
+        constexpr int stride_h = stride;
+        constexpr int stride_w = stride;
+        constexpr int pack_iw_len = 1;
+
+        const int img_stride = oh * ow;
+        const int ow_end = ow / ow_step * ow_step;
+        const int ow_remain = ow - ow_end;
+        const int oc_end = oc / big_oc_step * big_oc_step;
+        const int ld_dst_oc = oc_step * img_stride;
+
+        using remain_fun = std::function<void(
+                const float32_t* src_ptr, const float32_t* weight_ptr,
+                const float32_t* bias_ptr, float32_t* dst_ptr, int ic,
+                int ih, int iw, int ld_dst_oc, const Op& op)>;
+        remain_fun kern_big_oc_remain = nullptr;
+
+        switch (ow_remain) {
+#define cb(step)                                                            \
+    case step:                                                              \
+        kern_big_oc_remain =                                                \
+                KerNeonXXs2NchwNchw44FP32::impl;                            \
+        break;
+
+            UNROLL_CALL_RAW(8, cb);
+            default:
+                megdnn_assert(0, "no remain %d for kern", ow_remain);
+        }
+#undef cb
+#if MGB_ENABLE_CPUINFO
+        auto arch_tag =
+                cpuinfo_get_current_core()->uarch == cpuinfo_uarch_cortex_a7
+                        ? CpuTag::A7_TAG
+                        : CpuTag::DEFAULT_CPU_TAG;
+#else
+        auto arch_tag = CpuTag::A7_TAG;
+#endif
+        if (arch_tag == CpuTag::A7_TAG) {
+            for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
+                const int weight_offset = oc_idx * ic * fh * fw;
+                for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
+                    for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
+                        const int src_offset = (oh_idx * stride_h * iw +
+                                                ow_idx * stride_w * ih_step) *
+                                               ic_step * pack_iw_len;
+                        const int dst_offset = oc_idx * img_stride +
+                                               (oh_idx * ow + ow_idx) * oc_step;
+                        KerNeonXXs2NchwNchw44FP32<
+                                bias_mode, Op, ow_step, filter_size,
+                                big_oc_step, stride, ow_step,
+                                CpuTag::A7_TAG>::impl(src + src_offset,
+                                                      filter + weight_offset,
+                                                      bias + oc_idx,
+                                                      dst + dst_offset, ic, ih,
+                                                      iw, ld_dst_oc, op);
+                    }
+                    if (ow_remain > 0) {
+                        const int src_offset = (oh_idx * stride_h * iw +
+                                                ow_end * stride_w * ih_step) *
+                                               ic_step * pack_iw_len;
+                        const int dst_offset = oc_idx * img_stride +
+                                               (oh_idx * ow + ow_end) * oc_step;
+                        kern_big_oc_remain(src + src_offset,
+                                           filter + weight_offset,
+                                           bias + oc_idx, dst + dst_offset, ic,
+                                           ih, iw, ld_dst_oc, op);
+                    }
+                }
+            }
+        } else {
+            for (int oc_idx = 0; oc_idx < oc_end; oc_idx += big_oc_step) {
+                const int weight_offset = oc_idx * ic * fh * fw;
+                for (int oh_idx = 0; oh_idx < oh_block; oh_idx += oh_step) {
+                    for (int ow_idx = 0; ow_idx < ow_end; ow_idx += ow_step) {
+                        const int src_offset = (oh_idx * stride_h * iw +
+                                                ow_idx * stride_w * ih_step) *
+                                               ic_step * pack_iw_len;
+                        const int dst_offset = oc_idx * img_stride +
+                                               (oh_idx * ow + ow_idx) * oc_step;
+                        KerNeonXXs2NchwNchw44FP32<
+                                bias_mode, Op, ow_step, filter_size,
+                                big_oc_step, stride,
+                                ow_step>::impl(src + src_offset,
+                                               filter + weight_offset,
+                                               bias + oc_idx, dst + dst_offset,
+                                               ic, ih, iw, ld_dst_oc, op);
+                    }
+                    if (ow_remain > 0) {
+                        const int src_offset = (oh_idx * stride_h * iw +
+                                                ow_end * stride_w * ih_step) *
+                                               ic_step * pack_iw_len;
+                        const int dst_offset = oc_idx * img_stride +
+                                               (oh_idx * ow + ow_end) * oc_step;
+                        kern_big_oc_remain(src + src_offset,
+                                           filter + weight_offset,
+                                           bias + oc_idx, dst + dst_offset, ic,
+                                           ih, iw, ld_dst_oc, op);
+                    }
+                }
+            }
+        }
+    }
+};
+
+#endif
+
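The specialization above resolves the microarchitecture at kernel-selection time rather than at compile time. Outside of megdnn, the same cpuinfo calls look like this; a minimal sketch, assuming cpuinfo is built into the binary, with error handling reduced to the init check:

```cpp
#include <cpuinfo.h>
#include <cstdio>

int main() {
    // cpuinfo_initialize() is idempotent; the HandleImpl constructor in
    // this patch calls it once so later queries are cheap.
    if (!cpuinfo_initialize()) {
        fprintf(stderr, "cpuinfo init failed\n");
        return 1;
    }
    const struct cpuinfo_core* core = cpuinfo_get_current_core();
    if (core && core->uarch == cpuinfo_uarch_cortex_a7) {
        puts("in-order Cortex-A7: pick the prefetch-heavy kernel");
    } else {
        puts("default: pick the load/MLA-interleaved kernel");
    }
}
```

Note the design choice visible in the `#else` branch of the patch: when `MGB_ENABLE_CPUINFO` is off, the ARMv7 build falls back to the A7-tuned kernel unconditionally rather than to the big-core one.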
+} // namespace
+
+template <BiasMode bias_mode, typename Op, int filter_size, int stride>
+void fp32_direct_nchw_nchw44::conv_direct_fp32_nchw_nchw44(
+        const float32_t* src, const float32_t* filter, const float32_t* bias,
+        float32_t*, float32_t* dst, const int oc, const int ic, const int ih,
+        const int iw, const int oh, const int oh_block, const int ow,
+        const Op& op, const int, const int) {
+    ConvDirectFp32NchwNchw44<bias_mode, Op, filter_size, stride>::impl(
+            src, filter, bias, nullptr, dst, oc, ic, ih, iw, oh, oh_block, ow,
+            op);
 }
 
 #define INSTANTIATION(stride, filter_size, bias_mode, Op) \
diff --git a/dnn/src/arm_common/handle.h b/dnn/src/arm_common/handle.h
index ffc63712..eaa46f8d 100644
--- a/dnn/src/arm_common/handle.h
+++ b/dnn/src/arm_common/handle.h
@@ -10,6 +10,9 @@
  */
 #pragma once
 #include "src/fallback/handle.h"
+#if MGB_ENABLE_CPUINFO
+#include "cpuinfo.h"
+#endif
 
 namespace megdnn {
 namespace arm_common {
@@ -20,6 +23,9 @@ class HandleImpl: public fallback::HandleImpl {
             HandleType type = HandleType::ARM_COMMON):
             fallback::HandleImpl::HandleImpl(computing_handle, type) {
+#if MGB_ENABLE_CPUINFO
+        cpuinfo_initialize();
+#endif
     }
 
     template
diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp
index 2a784762..c299ae07 100644
--- a/dnn/test/arm_common/conv_bias.cpp
+++ b/dnn/test/arm_common/conv_bias.cpp
@@ -243,6 +243,7 @@ static void benchmark_convbias(Handle* handle, std::string int_name,
         if (is_fp32) {
             run(1, 1, 4, 112, 112, 2, 2, true);
+            run(1, 3, 24, 224, 224, 3, 2, true);
             run(1, 3, 32, 224, 224, 3, 2, true);
             run(1, 3, 64, 224, 224, 7, 2, true);
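As a closing sanity check on the kernels' address arithmetic: each asm row consumes 18 floats (four quad loads plus one `vld1.32 {d8}`), which is exactly `round_up(src_line_block, 2)` for `src_line_block = 8 * 2 + 3 - 2 = 17`. The standalone sketch below (with a local `round_up` standing in for megdnn's helper, and arbitrary example sizes) verifies that the two skip constants return the source pointer to the next row and then to the next input channel:

```cpp
#include <cassert>

static int round_up(int x, int align) {
    return (x + align - 1) / align * align;
}

int main() {
    const int ih = 31, iw = 31, filter_size = 3;       // example sizes
    const int src_line_block = 8 * 2 + filter_size - 2;  // 17
    const int loaded = round_up(src_line_block, 2);       // 18 floats per row
    const int iw_skip_bytes = (iw - loaded) * (int)sizeof(float);
    const int ld_src_ic_skip_bytes =
            iw * (ih - filter_size) * (int)sizeof(float) + iw_skip_bytes;

    int pos = 0;  // byte offset of src_ptr within one input channel
    for (int row = 0; row < 3; ++row) {
        pos += loaded * (int)sizeof(float);  // the vld1.32 post-increments
        pos += (row < 2) ? iw_skip_bytes : ld_src_ic_skip_bytes;
    }
    // after three rows the pointer sits at the start of the next channel
    assert(pos == ih * iw * (int)sizeof(float));
    return 0;
}
```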