From f2b42bf09e2fdf43b824c10752a5127eef463991 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Sun, 7 Feb 2021 10:02:26 +0800
Subject: [PATCH] chore(dotprod): add arm dotprod attribute for easy use

GitOrigin-RevId: 78c3e72218b8db009542b00e3688315d058d37fa
---
 dnn/src/aarch64/conv_bias/int8/algos.cpp | 42 +++++++++--
 dnn/src/aarch64/conv_bias/int8/strategy.cpp | 17 ++---
 dnn/src/aarch64/conv_bias/int8/strategy.h | 5 +-
 dnn/src/aarch64/conv_bias/quint8/algos.cpp | 42 ++++++++++-
 dnn/src/aarch64/conv_bias/quint8/strategy.cpp | 88 ++++++++++++----------
 dnn/src/aarch64/conv_bias/quint8/strategy.h | 46 +++++++----
 dnn/src/aarch64/matrix_mul/algos.cpp | 37 +++++----
 dnn/src/aarch64/matrix_mul/algos.h | 12 +--
 dnn/src/aarch64/matrix_mul/fp32/strategy.cpp | 3 -
 dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h | 2 -
 dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h | 2 -
 .../aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h | 3 -
 dnn/src/aarch64/matrix_mul/int8/strategy.cpp | 3 -
 dnn/src/aarch64/matrix_mul/int8/strategy.h | 2 -
 .../aarch64/matrix_mul/int8_dot/kernel_8x12x4.h | 12 +--
 .../matrix_mul/int8_dot/kernel_mk4_8x12x4.h | 9 ++-
 dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp | 2 +-
 dnn/src/aarch64/matrix_mul/int8_dot/strategy.h | 2 +-
 dnn/src/aarch64/matrix_mul/opr_impl.cpp | 20 ++---
 dnn/src/aarch64/matrix_mul/opr_impl.h | 10 +--
 dnn/src/aarch64/matrix_mul/quint8/kernel_8x8x8.h | 2 -
 dnn/src/aarch64/matrix_mul/quint8/strategy.cpp | 2 -
 dnn/src/aarch64/matrix_mul/quint8/strategy.h | 2 -
 dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp | 9 +--
 dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h | 5 +-
 .../aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h | 8 +-
 dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp | 10 +--
 dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h | 4 +-
 .../f32_direct_nchw_nchw44_kern_common.h | 3 -
 dnn/src/arm_common/conv_bias/int8/algos.cpp | 8 +-
 dnn/src/arm_common/conv_bias/int8/algos.h | 2 +-
 .../arm_common/conv_bias/int8/direct_dotprod.cpp | 11 ++-
 dnn/src/arm_common/conv_bias/int8/direct_dotprod.h | 2 +-
 .../conv_bias/int8/direct_dotprod_nchw44.cpp | 3 +-
 .../conv_bias/int8/direct_dotprod_nchw44.h | 5 +-
 .../conv_bias/int8/direct_dotprod_nchw44_algo.cpp | 6 +-
 .../int8/direct_kernels/dot_direct_nchw44_common.h | 6 +-
 .../int8/direct_kernels/dot_direct_nchw44_s1.cpp | 6 +-
 .../int8/direct_kernels/dot_direct_nchw44_s2.cpp | 6 +-
 .../direct_kernels/dot_direct_nchw_nchw44_s1.cpp | 11 ++-
 .../direct_kernels/dot_direct_nchw_nchw44_s2.cpp | 11 ++-
 .../conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp | 5 +-
 .../conv_bias/int8/dot_direct_nchw_nchw44_kern.h | 2 +-
 .../arm_common/conv_bias/int8/stride1_dotprod.cpp | 2 +-
 .../arm_common/conv_bias/int8/stride1_dotprod.h | 2 +-
 .../arm_common/conv_bias/int8/stride2_dotprod.cpp | 2 +-
 .../arm_common/conv_bias/int8/stride2_dotprod.h | 2 +-
 dnn/src/arm_common/conv_bias/opr_impl.cpp | 4 +-
 dnn/src/arm_common/conv_bias/opr_impl.h | 2 +-
 dnn/src/arm_common/conv_bias/quint8/algos.cpp | 11 ++-
 dnn/src/arm_common/conv_bias/quint8/algos.h | 2 +-
 .../arm_common/conv_bias/quint8/direct_dotprod.cpp | 10 ++-
 .../arm_common/conv_bias/quint8/direct_dotprod.h | 2 +-
 .../conv_bias/quint8/stride1_dotprod.cpp | 2 +-
 .../arm_common/conv_bias/quint8/stride1_dotprod.h | 2 +-
 .../conv_bias/quint8/stride2_dotprod.cpp | 2 +-
 .../arm_common/conv_bias/quint8/stride2_dotprod.h | 2 +-
 dnn/src/arm_common/convolution/int8x8x32/algos.cpp | 10 ++-
 dnn/src/arm_common/convolution/int8x8x32/algos.h | 2 +-
 .../int8x8x32/conv_backdata_stride1.cpp | 9 ++-
.../convolution/int8x8x32/conv_backdata_stride1.h | 2 +- .../int8x8x32/conv_backdata_stride2.cpp | 8 +- .../convolution/int8x8x32/conv_backdata_stride2.h | 2 +- dnn/src/arm_common/convolution/opr_impl.cpp | 4 +- dnn/src/arm_common/convolution/opr_impl.h | 2 +- dnn/src/arm_common/convolution/quint8/algos.cpp | 10 ++- dnn/src/arm_common/convolution/quint8/algos.h | 2 +- .../convolution/quint8/conv_backdata_stride1.cpp | 9 ++- .../convolution/quint8/conv_backdata_stride1.h | 5 +- .../convolution/quint8/conv_backdata_stride2.cpp | 8 +- .../convolution/quint8/conv_backdata_stride2.h | 5 +- dnn/src/arm_common/matrix_mul/algos.cpp | 8 +- dnn/src/arm_common/matrix_mul/algos.h | 2 +- dnn/src/arm_common/matrix_mul/int8/gemv.cpp | 35 +++++++-- dnn/src/arm_common/matrix_mul/int8/gemv.h | 2 +- dnn/src/arm_common/matrix_mul/opr_impl.cpp | 4 +- dnn/src/arm_common/matrix_mul/opr_impl.h | 2 +- dnn/src/arm_common/neon_struct.h | 5 +- dnn/src/arm_common/simd_macro/marm_neon.h | 27 +++++-- dnn/src/armv7/matrix_mul/algos.cpp | 14 +++- dnn/src/armv7/matrix_mul/algos.h | 2 +- dnn/src/armv7/matrix_mul/asm/common.h | 1 - dnn/src/armv7/matrix_mul/fp32/strategy_4x12.cpp | 1 - dnn/src/armv7/matrix_mul/int8/kernel_6x8x4.h | 4 +- .../armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h | 5 +- dnn/src/armv7/matrix_mul/int8/strategy.cpp | 2 +- dnn/src/armv7/matrix_mul/int8/strategy.h | 2 +- dnn/src/armv7/matrix_mul/opr_impl.cpp | 4 +- dnn/src/armv7/matrix_mul/opr_impl.h | 2 +- dnn/src/armv7/matrix_mul/quint8/kernel_dot_4x8x4.h | 5 +- dnn/src/armv7/matrix_mul/quint8/strategy.cpp | 2 +- dnn/src/armv7/matrix_mul/quint8/strategy.h | 2 +- dnn/src/common/utils.h | 7 ++ .../conv_bias/conv1x1/algos_conv1x1_gemv.cpp | 2 +- dnn/test/aarch64/matrix_mul.cpp | 6 +- dnn/test/arm_common/conv_bias.cpp | 11 ++- dnn/test/arm_common/conv_bias_multi_thread.cpp | 2 +- .../conv_bias_multi_thread_benchmark.cpp | 8 +- .../arm_common/conv_bias_multi_thread_conv1x1.cpp | 20 +++-- .../arm_common/conv_bias_multi_thread_im2col.cpp | 21 +++--- .../conv_bias_multi_thread_weight_preprocess.cpp | 37 +++++---- dnn/test/arm_common/convolution.cpp | 4 +- dnn/test/arm_common/matrix_mul.cpp | 7 +- dnn/test/armv7/matrix_mul.cpp | 6 +- src/megbrain_build_config.h.in | 22 ++++++ 105 files changed, 553 insertions(+), 344 deletions(-) diff --git a/dnn/src/aarch64/conv_bias/int8/algos.cpp b/dnn/src/aarch64/conv_bias/int8/algos.cpp index c7124f94..d6d0d53d 100644 --- a/dnn/src/aarch64/conv_bias/int8/algos.cpp +++ b/dnn/src/aarch64/conv_bias/int8/algos.cpp @@ -67,6 +67,23 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( size_t K = IC * FH * FW; size_t N = OH * OW; +#if MGB_ENABLE_DOT +#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ + _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ + matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ + M, N, K, param.filter_type, param.src_type, param.dst_type); \ + part2 = megdnn::matmul::GemmInterleaved< \ + matmul::gemm_##_gemm##_##_bias##_##_nonline>( \ + M, N, K, false, false, strategy) \ + .get_workspace_size(); + + if (cpuinfo_has_arm_neon_dot()) { + DISPATCH_GEMM_BIAS(s8_8x12, 1) + } else { + DISPATCH_GEMM_BIAS(s8_4x4, 0) + } +#else #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ _bias_midout_enum, _nonline, \ _nonline_midout_enum) \ @@ -80,11 +97,7 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle( .get_workspace_size(); \ } \ MIDOUT_END() - -#if !(__ARM_FEATURE_DOTPROD) DISPATCH_GEMM_BIAS(s8_4x4, 0) -#else - DISPATCH_GEMM_BIAS(s8_8x12, 1) #endif #undef 
DISPATCH_GEMM_STRATEGY } @@ -158,6 +171,23 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, size_t K = IC * FH * FW; size_t N = OH * OW; +#if MGB_ENABLE_DOT +#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ + _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ + matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ + M, N, K, param.filter_type, param.src_type, param.dst_type); \ + megdnn::matmul::GemmInterleaved< \ + matmul::gemm_##_gemm##_##_bias##_##_nonline> \ + gemm_interleaved(M, N, K, false, false, strategy); \ + gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); + + if (cpuinfo_has_arm_neon_dot()) { + DISPATCH_GEMM_BIAS(s8_8x12, 1) + } else { + DISPATCH_GEMM_BIAS(s8_4x4, 0) + } +#else #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ _bias_midout_enum, _nonline, \ _nonline_midout_enum) \ @@ -172,11 +202,7 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param, bias); \ } \ MIDOUT_END() - -#if !(__ARM_FEATURE_DOTPROD) DISPATCH_GEMM_BIAS(s8_4x4, 0) -#else - DISPATCH_GEMM_BIAS(s8_8x12, 1) #endif #undef DISPATCH_GEMM_STRATEGY } diff --git a/dnn/src/aarch64/conv_bias/int8/strategy.cpp b/dnn/src/aarch64/conv_bias/int8/strategy.cpp index 03ff9119..6506912f 100644 --- a/dnn/src/aarch64/conv_bias/int8/strategy.cpp +++ b/dnn/src/aarch64/conv_bias/int8/strategy.cpp @@ -26,7 +26,7 @@ namespace impl { template struct KernCaller; -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT template struct KernCaller { static void run(const dt_int8* packA, const dt_int8* packB, size_t M, @@ -118,7 +118,7 @@ struct KernCaller { } }; -#else +#endif template struct KernCaller { @@ -196,10 +196,8 @@ struct KernCaller { } }; -#endif - } // namespace impl -#if !(__ARM_FEATURE_DOTPROD) + MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4_nobias_identity) void gemm_s8_4x4_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, @@ -227,7 +225,8 @@ void gemm_s8_4x4_nobias_identity::pack_B(dt_int8* out, const dt_int8* in, size_t gemm_s8_4x4_nobias_identity::get_workspace_size() const { return 4 * 4 * sizeof(dt_int32); } -#else + +#if MGB_ENABLE_DOT MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12_nobias_identity) void gemm_s8_8x12_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr, @@ -277,11 +276,10 @@ size_t gemm_s8_8x12_nobias_identity::get_workspace_size() const { #define DEFINE_OP(_Op) \ arm_common::_Op op(scale_A* scale_B, scale_C); -#if !(__ARM_FEATURE_DOTPROD) KERN(4, 4, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp) KERN(4, 4, nobias, BiasMode::NO_BIAS, relu, ReluOp) KERN(4, 4, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) -#else +#if MGB_ENABLE_DOT KERN(8, 12, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp) KERN(8, 12, nobias, BiasMode::NO_BIAS, relu, ReluOp) KERN(8, 12, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) @@ -291,12 +289,11 @@ KERN(8, 12, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) #define DEFINE_OP(_Op) \ arm_common::_Op op(scale_A* scale_B, \ scale_A* scale_B, scale_C); -#if !(__ARM_FEATURE_DOTPROD) KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) -#else +#if MGB_ENABLE_DOT KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, diff --git 
a/dnn/src/aarch64/conv_bias/int8/strategy.h b/dnn/src/aarch64/conv_bias/int8/strategy.h index e419749b..d4d0224d 100644 --- a/dnn/src/aarch64/conv_bias/int8/strategy.h +++ b/dnn/src/aarch64/conv_bias/int8/strategy.h @@ -15,7 +15,6 @@ namespace megdnn { namespace aarch64 { namespace matmul { -#if !(__ARM_FEATURE_DOTPROD) /** * \brief base strategy of gemm. * @@ -39,8 +38,7 @@ MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_relu, MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_hswish, gemm_s8_4x4_nobias_identity); - -#else +#if MGB_ENABLE_DOT MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 8, 12, 4, false, true, gemm_s8_8x12_nobias_identity); @@ -59,7 +57,6 @@ MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_relu, MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_hswish, gemm_s8_8x12_nobias_identity); - #endif } // namespace matmul diff --git a/dnn/src/aarch64/conv_bias/quint8/algos.cpp b/dnn/src/aarch64/conv_bias/quint8/algos.cpp index 5dbaea87..713f516a 100644 --- a/dnn/src/aarch64/conv_bias/quint8/algos.cpp +++ b/dnn/src/aarch64/conv_bias/quint8/algos.cpp @@ -69,6 +69,23 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( size_t K = IC * FH * FW; size_t N = OH * OW; +#if MGB_ENABLE_DOT +#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ + _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ + matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ + M, N, K, param.filter_type, param.src_type, param.dst_type); \ + part2 = megdnn::matmul::GemmInterleaved< \ + matmul::gemm_##_gemm##_##_bias##_##_nonline>( \ + M, N, K, false, false, strategy) \ + .get_workspace_size(); + + if (cpuinfo_has_arm_neon_dot()) { + DISPATCH_GEMM_BIAS(u8_8x8_dot, 1); + } else { + DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0); + } +#else #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ _bias_midout_enum, _nonline, \ _nonline_midout_enum) \ @@ -82,8 +99,8 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle( .get_workspace_size(); \ } \ MIDOUT_END() - - DISPATCH_GEMM_BIAS(u8_8x8, 0) + DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) +#endif #undef DISPATCH_GEMM_STRATEGY } return {nullptr, {part0, part1, part2}}; @@ -157,6 +174,23 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, size_t K = IC * FH * FW; size_t N = OH * OW; +#if MGB_ENABLE_DOT +#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ + _bias_midout_enum, _nonline, \ + _nonline_midout_enum) \ + matmul::gemm_##_gemm##_##_bias##_##_nonline strategy( \ + M, N, K, param.filter_type, param.src_type, param.dst_type); \ + megdnn::matmul::GemmInterleaved< \ + matmul::gemm_##_gemm##_##_bias##_##_nonline> \ + gemm_interleaved(M, N, K, false, false, strategy); \ + gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias); + + if (cpuinfo_has_arm_neon_dot()) { + DISPATCH_GEMM_BIAS(u8_8x8_dot, 1) + } else { + DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) + } +#else #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias, \ _bias_midout_enum, _nonline, \ _nonline_midout_enum) \ @@ -172,7 +206,9 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param, } \ MIDOUT_END() - DISPATCH_GEMM_BIAS(u8_8x8, 0) + DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0) + +#endif #undef DISPATCH_GEMM_STRATEGY } } diff --git a/dnn/src/aarch64/conv_bias/quint8/strategy.cpp b/dnn/src/aarch64/conv_bias/quint8/strategy.cpp index b1fa5f49..dda6badf 100644 --- a/dnn/src/aarch64/conv_bias/quint8/strategy.cpp +++ 
b/dnn/src/aarch64/conv_bias/quint8/strategy.cpp @@ -23,12 +23,12 @@ using namespace aarch64; using namespace aarch64::matmul; namespace impl { -template +template struct KernCaller; -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT template -struct KernCaller { +struct KernCaller { static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, size_t K, dt_uint8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, @@ -120,10 +120,10 @@ struct KernCaller { } }; -#else +#endif template -struct KernCaller { +struct KernCaller { static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, size_t K, dt_uint8* C, size_t LDC, bool is_first_k, Op op, const dt_int32* bias, @@ -215,13 +215,11 @@ struct KernCaller { } }; -#endif - } // namespace impl -#if __ARM_FEATURE_DOTPROD -MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nobias_identity) +#if MGB_ENABLE_DOT +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_dot_nobias_identity) -void gemm_u8_8x8_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr, +void gemm_u8_8x8_dot_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr, int ldin, int y0, int ymax, int k0, int kmax, bool transpose) const { if (transpose) { @@ -233,7 +231,7 @@ void gemm_u8_8x8_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr, } } -void gemm_u8_8x8_nobias_identity::pack_B(uint8_t* out, const uint8_t* in, +void gemm_u8_8x8_dot_nobias_identity::pack_B(uint8_t* out, const uint8_t* in, int ldin, int x0, int xmax, int k0, int kmax, bool transpose) const { if (transpose) { @@ -245,10 +243,13 @@ void gemm_u8_8x8_nobias_identity::pack_B(uint8_t* out, const uint8_t* in, } } -#else +size_t gemm_u8_8x8_dot_nobias_identity::get_workspace_size() const { + return 8 * 8 * sizeof(dt_int32); +} -MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nobias_identity) -void gemm_u8_8x8_nobias_identity::pack_A(dt_uint8* outptr, +#endif +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nodot_nobias_identity) +void gemm_u8_8x8_nodot_nobias_identity::pack_A(dt_uint8* outptr, const dt_uint8* inptr, int ldin, int y0, int ymax, int k0, int kmax, bool transpose) const { @@ -262,7 +263,7 @@ void gemm_u8_8x8_nobias_identity::pack_A(dt_uint8* outptr, } } -void gemm_u8_8x8_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in, +void gemm_u8_8x8_nodot_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in, int ldin, int x0, int xmax, int k0, int kmax, bool transpose) const { uint8_t zB = B_dtype.param().zero_point; @@ -275,43 +276,52 @@ void gemm_u8_8x8_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in, } } -#endif -size_t gemm_u8_8x8_nobias_identity::get_workspace_size() const { +size_t gemm_u8_8x8_nodot_nobias_identity::get_workspace_size() const { return 8 * 8 * sizeof(dt_int32); } -#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP) \ - void gemm_u8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern( \ - const dt_uint8* packA, const dt_uint8* packB, size_t M, size_t N, \ - size_t K, dt_uint8* C, size_t LDC, bool is_first_k, \ - const dt_int32* bias, dt_int32* workspace) const { \ - float scale_A = A_dtype.param().scale; \ - uint8_t zp_A = A_dtype.param().zero_point; \ - float scale_B = B_dtype.param().scale; \ - uint8_t zp_B = B_dtype.param().zero_point; \ - float scale_C = C_dtype.param().scale; \ - uint8_t zp_C = C_dtype.param().zero_point; \ - DEFINE_OP(_OP); \ - impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run( \ - packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ - workspace, zp_A, zp_B); \ +#define KERN(_block_m, 
_block_n, _dot, _suffix, _bias, _BIAS, _nonline, \ + _OP) \ + void gemm_u8_##_block_m##x##_block_n##_suffix##_##_bias##_##_nonline:: \ + kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, \ + size_t N, size_t K, dt_uint8* C, size_t LDC, bool is_first_k, \ + const dt_int32* bias, dt_int32* workspace) const { \ + float scale_A = A_dtype.param().scale; \ + uint8_t zp_A = A_dtype.param().zero_point; \ + float scale_B = B_dtype.param().scale; \ + uint8_t zp_B = B_dtype.param().zero_point; \ + float scale_C = C_dtype.param().scale; \ + uint8_t zp_C = C_dtype.param().zero_point; \ + DEFINE_OP(_OP); \ + impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n, _dot>::run( \ + packA, packB, M, N, K, C, LDC, is_first_k, op, bias, \ + workspace, zp_A, zp_B); \ } #define DEFINE_OP(_Op) \ arm_common::_Op op(scale_A* scale_B, scale_C, zp_C); -KERN(8, 8, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp) -KERN(8, 8, nobias, BiasMode::NO_BIAS, relu, ReluOp) -KERN(8, 8, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) +#if MGB_ENABLE_DOT +KERN(8, 8, true, _dot, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp) +KERN(8, 8, true, _dot, nobias, BiasMode::NO_BIAS, relu, ReluOp) +KERN(8, 8, true, _dot, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) +#endif +KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp) +KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, relu, ReluOp) +KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, hswish, HSwishOp) #undef DEFINE_OP #define DEFINE_OP(_Op) \ arm_common::_Op op(scale_A* scale_B, \ scale_A* scale_B, scale_C, zp_C); -KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) -KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) -KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, - FuseAddHSwishOp) +#if MGB_ENABLE_DOT +KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) +KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) +KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) +#endif +KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp) +KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp) +KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp) #undef DEFINE_OP #undef KERN diff --git a/dnn/src/aarch64/conv_bias/quint8/strategy.h b/dnn/src/aarch64/conv_bias/quint8/strategy.h index fe109f1b..4562c84b 100644 --- a/dnn/src/aarch64/conv_bias/quint8/strategy.h +++ b/dnn/src/aarch64/conv_bias/quint8/strategy.h @@ -15,30 +15,46 @@ namespace megdnn { namespace aarch64 { namespace matmul { -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 4, false, true, - gemm_u8_8x8_nobias_identity); -#else + gemm_u8_8x8_dot_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_nobias_relu, + gemm_u8_8x8_dot_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_nobias_hswish, + gemm_u8_8x8_dot_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_identity, + gemm_u8_8x8_dot_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_relu, + gemm_u8_8x8_dot_nobias_identity); + +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_hswish, + gemm_u8_8x8_dot_nobias_identity); + + +#endif 
MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 8, false, true, - gemm_u8_8x8_nobias_identity); -#endif + gemm_u8_8x8_nodot_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nobias_relu, - gemm_u8_8x8_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_nobias_relu, + gemm_u8_8x8_nodot_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nobias_hswish, - gemm_u8_8x8_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_nobias_hswish, + gemm_u8_8x8_nodot_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_identity, - gemm_u8_8x8_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_identity, + gemm_u8_8x8_nodot_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_relu, - gemm_u8_8x8_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_relu, + gemm_u8_8x8_nodot_nobias_identity); -MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_hswish, - gemm_u8_8x8_nobias_identity); +MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_hswish, + gemm_u8_8x8_nodot_nobias_identity); } // namespace matmul diff --git a/dnn/src/aarch64/matrix_mul/algos.cpp b/dnn/src/aarch64/matrix_mul/algos.cpp index dd712195..0b640d32 100644 --- a/dnn/src/aarch64/matrix_mul/algos.cpp +++ b/dnn/src/aarch64/matrix_mul/algos.cpp @@ -24,9 +24,6 @@ #include "src/common/utils.h" #include "src/fallback/matrix_mul/gemm_impl.h" -#if MGB_ENABLE_CPUINFO -#include "cpuinfo.h" -#endif #include "midout.h" MIDOUT_DECL(megdnn_aarch64_matmul_kern) @@ -394,7 +391,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_8x8::get_kern( #endif -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT /* ==================== Int8x8x32 K8x12x4 Dotprod algo ==================== */ namespace { void int8x8x32_k8x12x4_dotprod_kern( @@ -422,6 +419,9 @@ void int8x8x32_k8x12x4_dotprod_kern( bool MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::usable( const KernSizeParam& kern_size_param) const { + if (!cpuinfo_has_arm_neon_dot()){ + return false; + } return can_be_treated_as_int8x8x32(kern_size_param); } @@ -484,6 +484,11 @@ void int8x8x32_mk4_8x12x4_dotprod_kern( bool MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::usable( const KernSizeParam& kern_size_param) const { + + if (!cpuinfo_has_arm_neon_dot()){ + return false; + } + return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && (kern_size_param.A_type.enumv() == DTypeEnum::Int8 || kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) && @@ -527,7 +532,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x12x4DotProd, aarch64::matmul::gemm_mk4_s8_8x12, int8_t, int32_t, AlgoDataType::QINT8X8X32, MK4_DOT); -#else +#endif /* ===================== Int8x8x32 MK4 4x4x16 algo ===================== */ namespace { @@ -727,7 +732,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x8x8, aarch64::matmul::gemm_s8_8x8, int8_t, int32_t, AlgoDataType::QINT8X8X32, DEFAULT); -#endif /* ===================== Int8x8x16 K8x8x8 algo ===================== */ namespace { @@ -1151,7 +1155,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_kern( return kern_mk8_8x8; } -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT /* ==================== Quint8 K8x8x4 Dotprod algo ==================== */ namespace { void quint8_k8x8x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { @@ -1166,8 +1170,8 @@ void quint8_k8x8x4_dotprod_kern(const 
MatrixMulImpl::KernParam& kern_param) { Bptr = kern_param.B(); auto Cptr = kern_param.C(); - aarch64::matmul::gemm_u8_8x8 strategy(M, N, K, A_type, B_type, C_type); - megdnn::matmul::GemmInterleaved( + aarch64::matmul::gemm_u8_8x8_dot strategy(M, N, K, A_type, B_type, C_type); + megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC, kern_param.workspace_ptr); @@ -1178,6 +1182,9 @@ void quint8_k8x8x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { bool MatrixMulImpl::AlgoQuint8K8x8x4DotProd::usable( const KernSizeParam& kern_size_param) const { + if (!cpuinfo_has_arm_neon_dot()){ + return false; + } return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm && kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm && kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 && @@ -1195,8 +1202,8 @@ size_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_workspace( auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type, C_type = kern_size_param.C_type; - aarch64::matmul::gemm_u8_8x8 strategy(M, N, K, A_type, B_type, C_type); - return megdnn::matmul::GemmInterleaved( + aarch64::matmul::gemm_u8_8x8_dot strategy(M, N, K, A_type, B_type, C_type); + return megdnn::matmul::GemmInterleaved( M, N, K, trA, trB, strategy) .get_workspace_size(); } @@ -1212,7 +1219,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_kern( MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x4DotProd, megdnn_aarch64_matmul_kern, "AlgoQuint8K8x8x4DotProdImpl"_hash, - aarch64::matmul::gemm_u8_8x8, uint8_t, + aarch64::matmul::gemm_u8_8x8_dot, uint8_t, int32_t, AlgoDataType::QUINT8X8X32, DEFAULT); /* ===================== Quint8 Gemv DotProd algo ===================== */ @@ -1238,6 +1245,9 @@ void quint8_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { bool MatrixMulImpl::AlgoQuint8GemvDotProd::usable( const KernSizeParam& kern_size_param) const { + if (!cpuinfo_has_arm_neon_dot()){ + return false; + } return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm && kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm && kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 && @@ -1257,7 +1267,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8GemvDotProd::get_kern( const KernSizeParam&) const { return quint8_gemv_dotprod_kern; } -#else +#endif /* ===================== Quint8 K8x8x8 algo ===================== */ namespace { @@ -1322,7 +1332,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x8, aarch64::matmul::gemm_u8_8x8, uint8_t, int32_t, AlgoDataType::QUINT8X8X32, DEFAULT); -#endif /* ===================== Int8x8x16 K8x8x8 algo ===================== */ namespace { diff --git a/dnn/src/aarch64/matrix_mul/algos.h b/dnn/src/aarch64/matrix_mul/algos.h index 7b7b03cb..8d690189 100644 --- a/dnn/src/aarch64/matrix_mul/algos.h +++ b/dnn/src/aarch64/matrix_mul/algos.h @@ -111,7 +111,7 @@ public: #endif -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase { public: AlgoAttribute attribute() const override { @@ -141,7 +141,7 @@ public: MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD) }; -#else +#endif class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { public: @@ -187,7 +187,6 @@ public: MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K8X8X8) }; -#endif class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase { public: @@ -313,7 
+312,7 @@ public: MEGDNN_DECL_ALGO_TYPE(AARCH64_INT16X16X32_MK8_8X8) }; -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase { public: AlgoAttribute attribute() const override { @@ -328,7 +327,6 @@ public: MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_K8X8X4_DOTPROD) }; - class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { public: AlgoAttribute attribute() const override { @@ -344,8 +342,7 @@ public: MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QUINT8X8X32, DEFAULT) MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_GEMV_DOTPROD) }; -#else - +#endif class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase { public: AlgoAttribute attribute() const override { @@ -358,7 +355,6 @@ public: MEGDNN_REG_GEMM_FUNC_FOR_IM2COL(); MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_K8X8X8) }; -#endif } // namespace aarch64 } // namespace megdnn diff --git a/dnn/src/aarch64/matrix_mul/fp32/strategy.cpp b/dnn/src/aarch64/matrix_mul/fp32/strategy.cpp index 7f218358..14c1238b 100644 --- a/dnn/src/aarch64/matrix_mul/fp32/strategy.cpp +++ b/dnn/src/aarch64/matrix_mul/fp32/strategy.cpp @@ -20,9 +20,6 @@ #include "src/aarch64/matrix_mul/fp32/strategy.h" #include "src/common/utils.h" -#if MGB_ENABLE_CPUINFO -#include "cpuinfo.h" -#endif using namespace megdnn; using namespace aarch64; diff --git a/dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h b/dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h index 90816b55..de5c0771 100644 --- a/dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h +++ b/dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h @@ -9,7 +9,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#if !(__ARM_FEATURE_DOTPROD) #include "src/aarch64/matrix_mul/asm/common.h" #include "src/arm_common/simd_macro/marm_neon.h" @@ -851,6 +850,5 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin, } // namespace matmul_4x4x16 } // namespace aarch64 } // namespace megdnn -#endif // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h b/dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h index 6dfbbb31..0b0f5575 100644 --- a/dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h +++ b/dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h @@ -9,7 +9,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#if !(__ARM_FEATURE_DOTPROD) #include "src/aarch64/matrix_mul/asm/common.h" #include "src/arm_common/simd_macro/marm_neon.h" @@ -1372,4 +1371,3 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr, } // namespace megdnn // vim: syntax=cpp.doxygen -#endif diff --git a/dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h b/dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h index c4dbadf1..d4f91b94 100644 --- a/dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h +++ b/dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h @@ -10,8 +10,6 @@ * implied. 
*/ -#include -#if !(__ARM_FEATURE_DOTPROD) #include "src/aarch64/matrix_mul/asm/common.h" #include "src/arm_common/simd_macro/marm_neon.h" @@ -887,6 +885,5 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin, } // namespace matmul_4x4x16 } // namespace aarch64 } // namespace megdnn -#endif // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8/strategy.cpp b/dnn/src/aarch64/matrix_mul/int8/strategy.cpp index 41eb3b17..a5a37a49 100644 --- a/dnn/src/aarch64/matrix_mul/int8/strategy.cpp +++ b/dnn/src/aarch64/matrix_mul/int8/strategy.cpp @@ -9,7 +9,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#if !(__ARM_FEATURE_DOTPROD) #include "src/aarch64/matrix_mul/int8/strategy.h" #include "src/aarch64/matrix_mul/asm/common.h" #include "src/aarch64/matrix_mul/int8/kernel_4x4x16.h" @@ -105,7 +104,6 @@ void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M, packA += K4; } } - ///////////////////////// gemm_mk4_s8_4x4 //////////////////////////////////// MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_4x4); @@ -258,6 +256,5 @@ void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, packA += K4; } } -#endif // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8/strategy.h b/dnn/src/aarch64/matrix_mul/int8/strategy.h index b2909737..26b755e3 100644 --- a/dnn/src/aarch64/matrix_mul/int8/strategy.h +++ b/dnn/src/aarch64/matrix_mul/int8/strategy.h @@ -10,7 +10,6 @@ */ #pragma once -#if !(__ARM_FEATURE_DOTPROD) #include "src/fallback/matrix_mul/gemm_common.h" namespace megdnn { @@ -30,5 +29,4 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true, } // namespace aarch64 } // namespace megdnn -#endif // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h b/dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h index d59e3a60..61649fb2 100644 --- a/dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h +++ b/dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h @@ -9,8 +9,7 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#if __ARM_FEATURE_DOTPROD - +#if MGB_ENABLE_DOT #include "src/aarch64/matrix_mul/asm/common.h" #include "src/arm_common/simd_macro/marm_neon.h" @@ -50,7 +49,9 @@ namespace matmul_8x12x4 { * same, I test in kirin980 with small and big core, here i just keep both the * implementation. 
*/ + #if 1 +MEGDNN_ATTRIBUTE_TARGET("dotprod") static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, bool is_first_k) { K /= 4; @@ -408,6 +409,7 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, ); } #else +MEGDNN_ATTRIBUTE_TARGET("dotprod") static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, bool is_first_k) { K /= 4; @@ -650,7 +652,7 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, // +-------+-------+ - - - - +--------+--------+--------+ // // Accumulator - +MEGDNN_ATTRIBUTE_TARGET("dotprod") static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, bool is_first_k, int m_remain) { K /= 4; @@ -837,7 +839,7 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, // +-------+-------+ - - - - +---------+ // // Accumulator - +MEGDNN_ATTRIBUTE_TARGET("dotprod") static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, bool is_first_k, int n_remain) { K /= 4; @@ -1038,7 +1040,7 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, // +-------+-------+ - - - - +--------+ // // Accumulator - +MEGDNN_ATTRIBUTE_TARGET("dotprod") static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, bool is_first_k, int m_remain, int n_remain) { diff --git a/dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h b/dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h index 369c5e2e..79898861 100644 --- a/dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h +++ b/dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h @@ -10,8 +10,7 @@ * implied. */ -#if __ARM_FEATURE_DOTPROD - +#if MGB_ENABLE_DOT #include "src/aarch64/matrix_mul/asm/common.h" #include "src/arm_common/simd_macro/marm_neon.h" @@ -40,6 +39,7 @@ namespace matmul_mk4_8x12x4 { // // Accumulator +MEGDNN_ATTRIBUTE_TARGET("dotprod") static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, bool is_first_k) { K /= 4; @@ -60,7 +60,6 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, int32_t* outptr0 = output; int32_t* outptr1; - asm volatile ( // load accumulator C "add %[outptr1], %[outptr0], %x[LDC]\n" @@ -397,6 +396,7 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K, // // Accumulator +MEGDNN_ATTRIBUTE_TARGET("dotprod") static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, bool is_first_k) { K /= 4; @@ -543,6 +543,7 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K, // +--------+--------+ - - - - +------------+ // Accumulator +MEGDNN_ATTRIBUTE_TARGET("dotprod") static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, bool is_first_k, int n_remain) { K /= 4; @@ -718,6 +719,7 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, // +--------+--------+ - - - - +------------+ // Accumulator +MEGDNN_ATTRIBUTE_TARGET("dotprod") static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, bool is_first_k, int n_remain) { K /= 4; @@ -928,6 +930,5 @@ static void gemm_mk4_s8_8x12_pack_B(dt_int8* out, const dt_int8* in, int ldin, } // namespace matmul_mk4_8x12x4 } // namespace aarch64 } // namespace megdnn - #endif // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp 
b/dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp index ce72b693..fadf4215 100644 --- a/dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp +++ b/dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp @@ -10,13 +10,13 @@ */ #include "src/aarch64/matrix_mul/int8_dot/strategy.h" +#if MGB_ENABLE_DOT #include "src/aarch64/matrix_mul/asm/common.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/utils.h" #include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h" #include "src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h" -#if __ARM_FEATURE_DOTPROD using namespace megdnn; using namespace aarch64; using namespace aarch64::matmul; diff --git a/dnn/src/aarch64/matrix_mul/int8_dot/strategy.h b/dnn/src/aarch64/matrix_mul/int8_dot/strategy.h index 48382035..f413fed8 100644 --- a/dnn/src/aarch64/matrix_mul/int8_dot/strategy.h +++ b/dnn/src/aarch64/matrix_mul/int8_dot/strategy.h @@ -11,7 +11,7 @@ #pragma once #include "src/fallback/matrix_mul/gemm_common.h" -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT namespace megdnn { namespace aarch64 { namespace matmul { diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.cpp b/dnn/src/aarch64/matrix_mul/opr_impl.cpp index 54ddd8ec..470e5b5d 100644 --- a/dnn/src/aarch64/matrix_mul/opr_impl.cpp +++ b/dnn/src/aarch64/matrix_mul/opr_impl.cpp @@ -27,14 +27,13 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoF16K8x24x1 f16_k8x24x1; AlgoF16MK8_8x8 f16_mk8_8x8; #endif -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT AlgoInt8x8x32K8x12x4DotProd int8x8x32_k8x12x4_dotprod; AlgoInt8x8x32MK4_8x12x4DotProd int8x8x32_mk4_8x12x4_dotprod; -#else +#endif AlgoInt8x8x32MK4_4x4x16 int8x8x32_mk4_4x4x16; AlgoInt8x8x32K4x4x16 int8x8x32_k4x4x16; AlgoInt8x8x32K8x8x8 int8x8x32_k8x8x8; -#endif AlgoInt8x8x16K8x8x8 int8x8x16_k8x8x8; AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16; AlgoInt8x8x16MK4_16x12x4 int8x8x16_mk4_16x12x4; @@ -44,12 +43,11 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoInt16x16x32K12x8x1 int16x16x32_k12x8x1; AlgoInt16x16x32MK8_8x8 int16x16x32_mk8_8x8; -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT AlgoQuint8K8x8x4DotProd quint8_k8x8x4_dotprod; AlgoQuint8GemvDotProd quint8_gemv_dotprod; -#else - AlgoQuint8K8x8x8 quint8_k8x8x8; #endif + AlgoQuint8K8x8x8 quint8_k8x8x8; AlgoInt4x4x16K8x8x8 int4x4x16_k8x8x8; SmallVector m_all_algos; @@ -66,14 +64,13 @@ public: m_all_algos.emplace_back(&f16_k8x24x1); m_all_algos.emplace_back(&f16_mk8_8x8); #endif -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT m_all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod); m_all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod); -#else +#endif m_all_algos.emplace_back(&int8x8x32_k4x4x16); m_all_algos.emplace_back(&int8x8x32_k8x8x8); m_all_algos.emplace_back(&int8x8x32_mk4_4x4x16); -#endif m_all_algos.emplace_back(&int8x8x16_k4x4x16); m_all_algos.emplace_back(&int8x8x16_k8x8x8); m_all_algos.emplace_back(&int8x8x16_mk4_k8x8x8); @@ -82,12 +79,11 @@ public: m_all_algos.emplace_back(&int16x16x32_k12x8x1); m_all_algos.emplace_back(&int16x16x32_mk8_8x8); -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT m_all_algos.emplace_back(&quint8_gemv_dotprod); m_all_algos.emplace_back(&quint8_k8x8x4_dotprod); -#else - m_all_algos.emplace_back(&quint8_k8x8x8); #endif + m_all_algos.emplace_back(&quint8_k8x8x8); m_all_algos.emplace_back(&int4x4x16_k8x8x8); for (auto&& algo : m_all_algos) { diff --git a/dnn/src/aarch64/matrix_mul/opr_impl.h b/dnn/src/aarch64/matrix_mul/opr_impl.h index b9738a3b..39557bb3 100644 --- a/dnn/src/aarch64/matrix_mul/opr_impl.h +++ b/dnn/src/aarch64/matrix_mul/opr_impl.h @@ -41,16 
+41,15 @@ private: class AlgoF16MK8_8x8; // Aarch64 F16 Format MK8 block 16x8 #endif -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT class AlgoInt8x8x32K8x12x4DotProd; // Aarch64 Int8x8x32 Kernel // 8x12x4 DotProduct class AlgoInt8x8x32MK4_8x12x4DotProd; // Aarch64 nchw44 Int8x8x32 Kernel // 8x12x4 DotProduct -#else +#endif class AlgoInt8x8x32MK4_4x4x16; // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16 class AlgoInt8x8x32K4x4x16; // Aarch64 Int8x8x32 Kernel 4x4x16 class AlgoInt8x8x32K8x8x8; // Aarch64 Int8x8x32 Kernel 8x8x8 -#endif class AlgoInt8x8x16K8x8x8; // Aarch64 Int8x8x16 Kernel 8x8x8 class AlgoInt8x8x16K4x4x16; // Aarch64 Int8x8x16 Kernel 4x4x16 class AlgoInt8x8x16MK4_16x12x4; // Aarch64 Int8x8x16 Kernel 16x12x16 @@ -59,13 +58,12 @@ private: class AlgoInt16x16x32K12x8x1; // Aarch64 Int16x16x32 Kernel 12x8x1 class AlgoInt16x16x32MK8_8x8; // Aarch64 Int16x16x32 Format MK8 block 8x8 -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT class AlgoQuint8K8x8x4DotProd; // Aarch64 Quint8 Kernel // 8x8x4 DotProduct class AlgoQuint8GemvDotProd; // Aarch64 Quint8 Gemv DotProduct -#else - class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8 #endif + class AlgoQuint8K8x8x8; // Aarch64 Quint8 Kernel 8x8x8 class AlgoInt8x8x16MK4_K8x8x8; // Aarch64 Int8x8x16 Kernel 4x4x16 class AlgoInt4x4x16K8x8x8; // Aarch64 Int4x4x16 Kernel 4x4x16 class AlgoPack; diff --git a/dnn/src/aarch64/matrix_mul/quint8/kernel_8x8x8.h b/dnn/src/aarch64/matrix_mul/quint8/kernel_8x8x8.h index d11df457..1c217ca6 100644 --- a/dnn/src/aarch64/matrix_mul/quint8/kernel_8x8x8.h +++ b/dnn/src/aarch64/matrix_mul/quint8/kernel_8x8x8.h @@ -9,7 +9,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#if !(__ARM_FEATURE_DOTPROD) #include "src/aarch64/matrix_mul/asm/common.h" #include "src/arm_common/simd_macro/marm_neon.h" @@ -1395,4 +1394,3 @@ static void gemm_u8_8x8_transpose_pack_B_n(dt_uint8* outptr, } // namespace megdnn // vim: syntax=cpp.doxygen -#endif diff --git a/dnn/src/aarch64/matrix_mul/quint8/strategy.cpp b/dnn/src/aarch64/matrix_mul/quint8/strategy.cpp index e560b773..0e37f986 100644 --- a/dnn/src/aarch64/matrix_mul/quint8/strategy.cpp +++ b/dnn/src/aarch64/matrix_mul/quint8/strategy.cpp @@ -9,7 +9,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ -#if !(__ARM_FEATURE_DOTPROD) #include "src/aarch64/matrix_mul/quint8/strategy.h" #include "src/aarch64/matrix_mul/asm/common.h" #include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h" @@ -108,6 +107,5 @@ void gemm_u8_8x8::kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, packA += K4; } } -#endif // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/quint8/strategy.h b/dnn/src/aarch64/matrix_mul/quint8/strategy.h index 93214ef4..67d67f78 100644 --- a/dnn/src/aarch64/matrix_mul/quint8/strategy.h +++ b/dnn/src/aarch64/matrix_mul/quint8/strategy.h @@ -10,7 +10,6 @@ */ #pragma once -#if !(__ARM_FEATURE_DOTPROD) #include "src/fallback/matrix_mul/gemm_common.h" namespace megdnn { @@ -23,6 +22,5 @@ MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 8, 8, 8, false, true, } // namespace matmul } // namespace aarch64 } // namespace megdnn -#endif // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp b/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp index b3523572..207c65eb 100644 --- a/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp +++ b/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp @@ -10,15 +10,13 @@ */ #include "src/aarch64/matrix_mul/quint8_dot/gemv.h" -#include +#if MGB_ENABLE_DOT #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/utils.h" #include "src/common/unroll_macro.h" -#if __ARM_FEATURE_DOTPROD - namespace { - +MEGDNN_ATTRIBUTE_TARGET("dotprod") void gemv_naive_n(const uint8_t* __restrict A, const uint8_t* __restrict B, int32_t* __restrict C, size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride, @@ -146,7 +144,6 @@ void gemv_naive_n(const uint8_t* __restrict A, const uint8_t* __restrict B, acc[0] + acc[1] + acc[2] + acc[3] + zAB - acc_zA - acc_zB; } } - } // namespace bool megdnn::aarch64::matmul::is_gemv_like_preferred_quint8( @@ -171,7 +168,5 @@ void megdnn::aarch64::matmul::gemv_like_quint8( return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride, zero_point_A, zero_point_B); } - #endif - // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h b/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h index ba633ffa..3c2ef596 100644 --- a/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h +++ b/dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h @@ -10,10 +10,9 @@ */ #pragma once -#include -#include +#include "src/common/utils.h" -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT namespace megdnn { namespace aarch64 { namespace matmul { diff --git a/dnn/src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h b/dnn/src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h index 928b8d1f..6e46a97f 100644 --- a/dnn/src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h +++ b/dnn/src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h @@ -9,8 +9,7 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ -#if __ARM_FEATURE_DOTPROD - +#if MGB_ENABLE_DOT #include "src/aarch64/matrix_mul/asm/common.h" #include "src/arm_common/simd_macro/marm_neon.h" @@ -56,7 +55,7 @@ namespace matmul_8x8x4 { // C = sum((A - zA) * (B - zB)) = sum(A * B) - sum(A) * zB - sum(B) * zA + zA * // zB * k // A -> v27, v28 | B -> v29, v30 | zA * zB * k -> v26 - +MEGDNN_ATTRIBUTE_TARGET("dotprod") static void kern_8x8(const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC, bool is_first_k, uint8_t zero_point_A, uint8_t zero_point_B, uint32_t zAB) { @@ -293,6 +292,7 @@ static void kern_8x8(const uint8_t* packA, const uint8_t* packB, int K, // zB * k // A -> v28 | B -> v29, v30 | zA * zB * k -> v26 +MEGDNN_ATTRIBUTE_TARGET("dotprod") static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC, bool is_first_k, int m_remain, uint8_t zero_point_A, uint8_t zero_point_B, uint32_t zAB) { @@ -495,6 +495,7 @@ static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, // zB * k // A -> v27, v28 | B -> v29 | zA * zB * k -> v26 +MEGDNN_ATTRIBUTE_TARGET("dotprod") static void kern_8x4(const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC, bool is_first_k, int n_remain, uint8_t zero_point_A, uint8_t zero_point_B, uint32_t zAB) { @@ -733,6 +734,7 @@ static void kern_8x4(const uint8_t* packA, const uint8_t* packB, int K, // zB * k // A -> v28 | B -> v29 | zA * zB * k -> v26 +MEGDNN_ATTRIBUTE_TARGET("dotprod") static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC, bool is_first_k, int m_remain, int n_remain, uint8_t zero_point_A, uint8_t zero_point_B, diff --git a/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp b/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp index 33e44c32..e8edf8ea 100644 --- a/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp +++ b/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp @@ -16,14 +16,14 @@ #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/utils.h" -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT using namespace megdnn; using namespace aarch64; using namespace aarch64::matmul; -MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8); +MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_dot); -void gemm_u8_8x8::pack_A(uint8_t* outptr, const uint8_t* inptr, int ldin, +void gemm_u8_8x8_dot::pack_A(uint8_t* outptr, const uint8_t* inptr, int ldin, int y0, int ymax, int k0, int kmax, bool transpose) const { if (transpose) { @@ -35,7 +35,7 @@ void gemm_u8_8x8::pack_A(uint8_t* outptr, const uint8_t* inptr, int ldin, } } -void gemm_u8_8x8::pack_B(uint8_t* out, const uint8_t* in, int ldin, int x0, +void gemm_u8_8x8_dot::pack_B(uint8_t* out, const uint8_t* in, int ldin, int x0, int xmax, int k0, int kmax, bool transpose) const { if (transpose) { matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(out, in, ldin, x0, @@ -46,7 +46,7 @@ void gemm_u8_8x8::pack_B(uint8_t* out, const uint8_t* in, int ldin, int x0, } } -void gemm_u8_8x8::kern(const uint8_t* packA, const uint8_t* packB, size_t M, +void gemm_u8_8x8_dot::kern(const uint8_t* packA, const uint8_t* packB, size_t M, size_t N, size_t K, dt_int32* C, size_t LDC, bool is_first_k, const dt_int32*, dt_int32*) const { megdnn_assert(A_dtype.enumv() == B_dtype.enumv() && diff --git a/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h b/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h index 5d32405e..1ed9c474 100644 --- a/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h +++ b/dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h @@ -11,13 
+11,13 @@ #pragma once #include "src/fallback/matrix_mul/gemm_common.h" -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT namespace megdnn { namespace aarch64 { namespace matmul { MEGDNN_REG_GEMM_STRATEGY(uint8_t, int32_t, int32_t, 8, 8, 4, false, true, - gemm_u8_8x8); + gemm_u8_8x8_dot); } // namespace aarch64 } // namespace matmul diff --git a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h index 9f0c1c66..9184f278 100644 --- a/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h +++ b/dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h @@ -23,9 +23,6 @@ #include "src/armv7/matrix_mul/asm/common.h" #endif -#if MGB_ENABLE_CPUINFO -#include "cpuinfo.h" -#endif using namespace megdnn; using namespace arm_common; diff --git a/dnn/src/arm_common/conv_bias/int8/algos.cpp b/dnn/src/arm_common/conv_bias/int8/algos.cpp index fe1116cd..d4be9237 100644 --- a/dnn/src/arm_common/conv_bias/int8/algos.cpp +++ b/dnn/src/arm_common/conv_bias/int8/algos.cpp @@ -161,10 +161,13 @@ ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns( return {}; } -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT /* ===================== dot stride1 algo ======================== */ bool ConvBiasImpl::AlgoDotS8DirectStride1::usable(const NCBKernSizeParam& param, AlgoSelectionStrategy) const { + if (!cpuinfo_has_arm_neon_dot()) { + return false; + } return direct_dotprod_int8_stride1::can_conv_direct_stride1_int8(param); } @@ -195,6 +198,9 @@ ConvBiasImpl::AlgoDotS8DirectStride1::dispatch_kerns( /* ===================== dot stride2 algo ======================== */ bool ConvBiasImpl::AlgoDotS8DirectStride2::usable(const NCBKernSizeParam& param, AlgoSelectionStrategy) const { + if (!cpuinfo_has_arm_neon_dot()){ + return false; + } return direct_dotprod_int8_stride2::can_conv_direct_stride2_int8(param); } diff --git a/dnn/src/arm_common/conv_bias/int8/algos.h b/dnn/src/arm_common/conv_bias/int8/algos.h index 773f9c7d..3addf2a0 100644 --- a/dnn/src/arm_common/conv_bias/int8/algos.h +++ b/dnn/src/arm_common/conv_bias/int8/algos.h @@ -129,7 +129,7 @@ public: MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHANWISE_STRD2_NCHW44_S8) }; -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT class ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44 final : public AlgoBase { public: diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp b/dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp index 3ecb74fd..d467272f 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp @@ -9,8 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ -#if __ARM_FEATURE_DOTPROD #include "src/arm_common/conv_bias/int8/direct_dotprod.h" +#if MGB_ENABLE_DOT #include "src/arm_common/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/utils.h" @@ -90,6 +90,7 @@ inline int8x16_t vqtbl1q_s8_v7(int8x16_t a, uint8x16_t index) { _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem); template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride1_2x2_int8_dot( const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, @@ -325,6 +326,7 @@ void conv_bias::conv_direct_stride1_2x2_int8_dot( } template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride1_3x3_int8_dot( const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, @@ -560,6 +562,7 @@ void conv_bias::conv_direct_stride1_3x3_int8_dot( } template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride2_2x2_int8_dot( const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, @@ -655,6 +658,7 @@ void conv_bias::conv_direct_stride2_2x2_int8_dot( } template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride2_3x3_int8_dot( const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, @@ -810,6 +814,7 @@ void conv_bias::conv_direct_stride2_3x3_int8_dot( _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem); template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride2_5x5_int8_dot( const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, @@ -1108,6 +1113,7 @@ void conv_bias::conv_direct_stride2_5x5_int8_dot( } template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride2_7x7_int8_dot( const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, @@ -1470,6 +1476,7 @@ void conv_bias::conv_direct_stride2_7x7_int8_dot( } template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride1_5x5_int8_dot( const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, @@ -1770,6 +1777,7 @@ void conv_bias::conv_direct_stride1_5x5_int8_dot( } template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride1_7x7_int8_dot( const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, int8_t* dst, const size_t IH, const size_t IW, @@ -2115,6 +2123,7 @@ void conv_bias::conv_direct_stride1_7x7_int8_dot( #undef ST1_S32X4 #undef ST2_S32X4X2 + #define INSTANTIATION(stride, i, first_ic, last_ic, bias, Op) \ template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_dot< \ first_ic, last_ic, bias, Op>( \ diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod.h b/dnn/src/arm_common/conv_bias/int8/direct_dotprod.h index 60d47b14..a19bae6d 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_dotprod.h +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod.h @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ -#if __ARM_FEATURE_DOTPROD #include "src/arm_common/conv_bias/opr_impl.h" +#if MGB_ENABLE_DOT #include "src/fallback/conv_bias/common.h" namespace megdnn { diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp index 54ffff58..6c20d4a6 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp @@ -10,9 +10,8 @@ * implied. */ -#ifdef __ARM_FEATURE_DOTPROD - #include "src/arm_common/elemwise_helper/kimpl/typecvt.h" +#if MGB_ENABLE_DOT #include "src/common/unroll_macro.h" #include "src/common/utils.h" #include "src/fallback/conv_bias/common.h" diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h index e5ebb849..e1cde5f4 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h @@ -10,11 +10,10 @@ * implied. */ -#if __ARM_FEATURE_DOTPROD - #pragma once #include "src/arm_common/conv_bias/opr_impl.h" +#if MGB_ENABLE_DOT namespace megdnn { namespace arm_common { @@ -78,4 +77,4 @@ void copy_packed_src_int8_nchw44(int8_t* dst, const int dst_step, #endif -// vim: syntax=cpp.doxygen \ No newline at end of file +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp index 609177e7..3909ed9f 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp @@ -10,9 +10,8 @@ * implied. */ -#if __ARM_FEATURE_DOTPROD - #include "src/arm_common/conv_bias/block_helper.h" +#if MGB_ENABLE_DOT #include "src/arm_common/conv_bias/int8/algos.h" #include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h" #include "src/arm_common/elemwise_op.h" @@ -159,6 +158,9 @@ static void conv_kern(const WorkspaceBundle& bundle, bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy algo_selection_strategy) const { + if (!cpuinfo_has_arm_neon_dot()){ + return false; + } MEGDNN_MARK_USED_VAR(algo_selection_strategy); auto&& fm = param.filter_meta; auto FH = fm.spatial[0]; diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h index 9c6babbe..3504b4c4 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h @@ -11,9 +11,9 @@ * implied. 
*/ #pragma once -#if __ARM_FEATURE_DOTPROD #include "megdnn/arch.h" #include "src/arm_common/conv_bias/intrinsic_helper.h" +#if MGB_ENABLE_DOT #include "src/arm_common/elemwise_op.h" #include "src/arm_common/intrinsic_helper.h" #include "src/arm_common/neon_struct.h" @@ -208,6 +208,7 @@ MEGDNN_ALWAYS_INLINE void store_ocx_owx_remain_static(int32x4_t res[][8], template struct ShiftCalHelper { + MEGDNN_ATTRIBUTE_TARGET("dotprod") static MEGDNN_ALWAYS_INLINE void impl(T& res, T2& src, T3& weight) { #define cb(step) \ res[res_row][step] = \ @@ -221,6 +222,7 @@ struct ShiftCalHelper { template +MEGDNN_ATTRIBUTE_TARGET("dotprod") MEGDNN_ALWAYS_INLINE void cal_helper(T& res, T2& src, T3& weight) { ShiftCalHelper::impl(res, src, weight); @@ -242,4 +244,4 @@ struct KernNeonSdotNCHW44 { } // namespace arm_common } // namespace megdnn #endif -// vim: syntax=cpp.doxygen \ No newline at end of file +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp index 20e67dcc..1ae736b8 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp @@ -10,8 +10,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#if __ARM_FEATURE_DOTPROD #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h" +#if MGB_ENABLE_DOT namespace megdnn { namespace arm_common { @@ -20,6 +20,7 @@ template struct KernNeonSdotNCHW44 { + MEGDNN_ATTRIBUTE_TARGET("dotprod") static void impl(dst_type* dst, const int dst_step, const int8_t* src, const int ih, const int iw, const int8_t* filter, const int32_t* bias, const int ic, const Op& op) { @@ -109,6 +110,7 @@ struct KernNeonSdotNCHW44 +MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, const int8_t* src, const int ih, const int iw, const int8_t* filter, const int32_t* bias, @@ -317,4 +319,4 @@ FOR_FILTER(1) } // namespace arm_common } // namespace megdnn #endif -// vim: syntax=cpp.doxygen \ No newline at end of file +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp index a7683fe9..36b37778 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp @@ -10,9 +10,9 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. 
*/ -#if __ARM_FEATURE_DOTPROD #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h" +#if MGB_ENABLE_DOT namespace megdnn { namespace arm_common { namespace direct_dotprod_nchw44 { @@ -20,6 +20,7 @@ template struct KernNeonSdotNCHW44 { + MEGDNN_ATTRIBUTE_TARGET("dotprod") static void impl(dst_type* dst, const int dst_step, const int8_t* src, const int ih, const int iw, const int8_t* filter, const int32_t* bias, const int ic, const Op& op) { @@ -110,6 +111,7 @@ struct KernNeonSdotNCHW44 +MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow, const int8_t* src, const int ih, const int iw, const int8_t* filter, const int32_t* bias, @@ -319,4 +321,4 @@ FOR_FILTER(2) } // namespace megdnn #endif -// vim: syntax=cpp.doxygen \ No newline at end of file +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp index 57377ae9..170de66b 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp @@ -11,8 +11,8 @@ * implied. */ -#if __ARM_FEATURE_DOTPROD #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" +#if MGB_ENABLE_DOT namespace megdnn { namespace arm_common { namespace dot_direct_nchw_nchw44 { @@ -20,6 +20,7 @@ namespace dot_direct_nchw_nchw44 { template struct ShiftCalHelper { + MEGDNN_ATTRIBUTE_TARGET("dotprod") static void impl(T& c, T2& src, T3& weight) { #define cb(step) \ c[0][step] = Func::template impl<(src_idx + step) % 4>( \ @@ -35,6 +36,7 @@ struct ShiftCalHelper { template struct ShiftCalHelper { + MEGDNN_ATTRIBUTE_TARGET("dotprod") static void impl(T& c, T2& src, T3& weight) { #define cb(step) \ c[0][step] = Func::template impl<(src_idx + step) % 4>( \ @@ -49,6 +51,7 @@ template struct KerNeonDotXXs2Nchw44Int8 { + MEGDNN_ATTRIBUTE_TARGET("dotprod") static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { @@ -97,6 +100,7 @@ template struct KerNeonDotXXs2Nchw44Int8 { + MEGDNN_ATTRIBUTE_TARGET("dotprod") static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { @@ -151,6 +155,7 @@ template struct KerNeonDotXXs2Nchw44Int8 { + MEGDNN_ATTRIBUTE_TARGET("dotprod") static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { @@ -200,6 +205,7 @@ template struct KerNeonDotXXs2Nchw44Int8 { + MEGDNN_ATTRIBUTE_TARGET("dotprod") static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { @@ -302,6 +308,7 @@ void pack_src_int8_nchw_nchw44_dot<1>(int8_t* sptr_base, } template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, int8_t* dst, const int oc, const int ic, @@ -445,4 +452,4 @@ DISPATCH_CONV_KERN(1); } // namespace arm_common } // namespace megdnn #endif -// vim: syntax=cpp.doxygen \ No newline at end of file +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp 
b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp index 069f3c43..f5b392c1 100644 --- a/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp +++ b/dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp @@ -10,8 +10,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ -#if __ARM_FEATURE_DOTPROD #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" +#if MGB_ENABLE_DOT namespace megdnn { namespace arm_common { namespace dot_direct_nchw_nchw44 { @@ -19,6 +19,7 @@ namespace dot_direct_nchw_nchw44 { template struct ShiftCalHelper { + MEGDNN_ATTRIBUTE_TARGET("dotprod") static void impl(T& c, T2& src, T3& weight) { #define cb(step) \ c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \ @@ -42,6 +43,7 @@ struct ShiftCalHelper { template struct ShiftCalHelper { + MEGDNN_ATTRIBUTE_TARGET("dotprod") static void impl(T& c, T2& src, T3& weight) { #define cb(step) \ c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \ @@ -60,6 +62,7 @@ template struct KerNeonDotXXs2Nchw44Int8 { + MEGDNN_ATTRIBUTE_TARGET("dotprod") static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { @@ -111,6 +114,7 @@ template struct KerNeonDotXXs2Nchw44Int8 { + MEGDNN_ATTRIBUTE_TARGET("dotprod") static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { @@ -169,6 +173,7 @@ template struct KerNeonDotXXs2Nchw44Int8 { + MEGDNN_ATTRIBUTE_TARGET("dotprod") static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { @@ -224,6 +229,7 @@ template struct KerNeonDotXXs2Nchw44Int8 { + MEGDNN_ATTRIBUTE_TARGET("dotprod") static void impl(const int8_t* src_ptr, const int8_t* weight_ptr, const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih, int iw, int ld_dst_oc, const Op& op) { @@ -289,6 +295,7 @@ void pack_src_int8_nchw_nchw44_dot<2>( } template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter, const int32_t* bias, int32_t* temp, int8_t* dst, const int oc, const int ic, @@ -434,4 +441,4 @@ DISPATCH_CONV_KERN(2); } // namespace megdnn #endif -// vim: syntax=cpp.doxygen \ No newline at end of file +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp index 2c8480ae..7c579e9a 100644 --- a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp +++ b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp @@ -10,8 +10,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express * or implied. 
*/ -#if __ARM_FEATURE_DOTPROD #include "megdnn/oprs.h" +#if MGB_ENABLE_DOT #include "src/arm_common/conv_bias/block_helper.h" #include "src/arm_common/conv_bias/int8/algos.h" #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h" @@ -175,6 +175,9 @@ static void do_conv_kern(const WorkspaceBundle& bundle, bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable( const NCBKernSizeParam& param, AlgoSelectionStrategy) const { + if (!cpuinfo_has_arm_neon_dot()){ + return false; + } return nchw_nchwxx_valid( param.src_type.enumv(), param.filter_type.enumv(), param.dst_type.enumv(), param.filter_meta, param.bias_mode, diff --git a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h index 6aba98d9..bf930f9d 100644 --- a/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h +++ b/dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h @@ -11,9 +11,9 @@ * implied. */ #pragma once -#if __ARM_FEATURE_DOTPROD #include "src/arm_common/conv_bias/intrinsic_helper.h" +#if MGB_ENABLE_DOT #include "src/arm_common/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/unroll_macro.h" diff --git a/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp b/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp index e34e2b09..6ef8eb42 100644 --- a/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp +++ b/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp @@ -8,9 +8,9 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#if __ARM_FEATURE_DOTPROD #include "src/arm_common/conv_bias/int8/stride1_dotprod.h" +#if MGB_ENABLE_DOT #include "megdnn/oprs.h" #include "src/arm_common/conv_bias/int8/direct_dotprod.h" #include "src/arm_common/conv_bias/int8/strategy.h" diff --git a/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.h b/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.h index 5ae1be5c..443a2192 100644 --- a/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.h +++ b/dnn/src/arm_common/conv_bias/int8/stride1_dotprod.h @@ -8,10 +8,10 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#if __ARM_FEATURE_DOTPROD #pragma once #include "src/arm_common/conv_bias/opr_impl.h" +#if MGB_ENABLE_DOT namespace megdnn { namespace arm_common { namespace direct_dotprod_int8_stride1 { diff --git a/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp b/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp index 0ca1dad2..ebe541b1 100644 --- a/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp +++ b/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp @@ -9,8 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#if __ARM_FEATURE_DOTPROD #include "src/arm_common/conv_bias/int8/stride2_dotprod.h" +#if MGB_ENABLE_DOT #include "megdnn/oprs.h" #include "src/arm_common/conv_bias/int8/direct_dotprod.h" #include "src/arm_common/conv_bias/int8/strategy.h" diff --git a/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.h b/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.h index c23cc62b..8cb2cd70 100644 --- a/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.h +++ b/dnn/src/arm_common/conv_bias/int8/stride2_dotprod.h @@ -9,9 +9,9 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
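Every dot algorithm touched by this patch gains the same prologue in usable(): because the kernels are now always compiled whenever MGB_ENABLE_DOT is on, each algorithm must disqualify itself at runtime on CPUs that lack the dotprod extension. A condensed sketch of the pattern (AlgoSomeDot is a placeholder; real implementations keep their existing shape/dtype legality checks, here borrowed from the quint8 stride-1 hunk):

    bool ConvBiasImpl::AlgoSomeDot::usable(const NCBKernSizeParam& param,
                                           AlgoSelectionStrategy) const {
        if (!cpuinfo_has_arm_neon_dot()) {  // runtime CPU feature probe;
            return false;                   // replaces the old compile-time #if
        }
        return direct_dotprod_quint8_stride1::can_conv_direct_stride1_quint8(
                param);  // unchanged legality checks
    }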
*/ -#if __ARM_FEATURE_DOTPROD #pragma once #include "src/arm_common/conv_bias/opr_impl.h" +#if MGB_ENABLE_DOT namespace megdnn { namespace arm_common { diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index 0d2f0a55..fc4fe356 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -60,7 +60,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { AlgoS8x8x16ChanWiseStride1Stride2NCHW44 s8x8x16_channel_wise_stride1_stride2_nchw44; -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT AlgoDotS8DirectStride1 ds8_direct_stride1; AlgoDotS8DirectStride2 ds8_direct_stride2; AlgoDotU8DirectStride1 du8_direct_stride1; @@ -94,7 +94,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj { public: AlgoPack() { -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT m_direct_algos.emplace_back(&ds8_direct_stride1); m_direct_algos.emplace_back(&ds8_direct_stride2); m_direct_algos.emplace_back(&du8_direct_stride1); diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h index 28fb3d76..d145fe57 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.h +++ b/dnn/src/arm_common/conv_bias/opr_impl.h @@ -70,7 +70,7 @@ private: class AlgoFP16WinogradF63; class AlgoFP16WinogradF23_8x8; #endif -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT class AlgoDotS8DirectNCHWNCHW44; class AlgoDotS8DirectStride1; class AlgoDotS8DirectStride2; diff --git a/dnn/src/arm_common/conv_bias/quint8/algos.cpp b/dnn/src/arm_common/conv_bias/quint8/algos.cpp index e1c39593..afd41fb0 100644 --- a/dnn/src/arm_common/conv_bias/quint8/algos.cpp +++ b/dnn/src/arm_common/conv_bias/quint8/algos.cpp @@ -11,7 +11,6 @@ */ #include "src/arm_common/conv_bias/quint8/algos.h" -#include "midout.h" #include "src/arm_common/conv_bias/quint8/stride1.h" #include "src/arm_common/conv_bias/quint8/stride1_dotprod.h" #include "src/arm_common/conv_bias/quint8/stride2.h" @@ -19,6 +18,8 @@ #include "src/arm_common/elemwise_op.h" #include "src/fallback/conv_bias/common.h" +#include "midout.h" + MIDOUT_DECL(megdnn_arm_common_conv_bias_quint8) using namespace megdnn; @@ -84,10 +85,13 @@ ConvBiasImpl::AlgoQU8DirectStride2::dispatch_kerns( MIDOUT_END(); return {}; } -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT /* ===================== stride1 algo ===================== */ bool ConvBiasImpl::AlgoDotU8DirectStride1::usable(const NCBKernSizeParam& param, AlgoSelectionStrategy) const { + if (!cpuinfo_has_arm_neon_dot()){ + return false; + } return direct_dotprod_quint8_stride1::can_conv_direct_stride1_quint8(param); } @@ -118,6 +122,9 @@ ConvBiasImpl::AlgoDotU8DirectStride1::dispatch_kerns( /* ===================== stride2 algo ===================== */ bool ConvBiasImpl::AlgoDotU8DirectStride2::usable(const NCBKernSizeParam& param, AlgoSelectionStrategy) const { + if (!cpuinfo_has_arm_neon_dot()){ + return false; + } return direct_dotprod_quint8_stride2::can_conv_direct_stride2_quint8(param); } diff --git a/dnn/src/arm_common/conv_bias/quint8/algos.h b/dnn/src/arm_common/conv_bias/quint8/algos.h index 1924034d..ae188c0f 100644 --- a/dnn/src/arm_common/conv_bias/quint8/algos.h +++ b/dnn/src/arm_common/conv_bias/quint8/algos.h @@ -55,7 +55,7 @@ public: } MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_QU8) }; -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase { public: diff --git a/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp b/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp index 
20d3bbf2..92e172ed 100644 --- a/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp +++ b/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp @@ -9,8 +9,8 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#if __ARM_FEATURE_DOTPROD #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" +#if MGB_ENABLE_DOT #include "src/arm_common/elemwise_op.h" #include "src/arm_common/simd_macro/marm_neon.h" #include "src/common/utils.h" @@ -120,6 +120,7 @@ inline int8x16_t vqtbl1q_s8_v7(int8x16_t a, uint8x16_t index){ template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride1_2x2_quint8_dot( const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, @@ -452,6 +453,7 @@ void conv_bias::conv_direct_stride1_2x2_quint8_dot( template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride1_3x3_quint8_dot( const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, @@ -691,6 +693,7 @@ void conv_bias::conv_direct_stride1_3x3_quint8_dot( template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride2_2x2_quint8_dot( const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, @@ -801,6 +804,7 @@ void conv_bias::conv_direct_stride2_2x2_quint8_dot( template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride2_3x3_quint8_dot( const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, @@ -1135,6 +1139,7 @@ void conv_bias::conv_direct_stride2_3x3_quint8_dot( template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride1_5x5_quint8_dot( const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, @@ -1443,6 +1448,7 @@ void conv_bias::conv_direct_stride1_5x5_quint8_dot( template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride1_7x7_quint8_dot( const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, @@ -1785,6 +1791,7 @@ void conv_bias::conv_direct_stride1_7x7_quint8_dot( template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride2_5x5_quint8_dot( const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, @@ -2090,6 +2097,7 @@ void conv_bias::conv_direct_stride2_5x5_quint8_dot( template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void conv_bias::conv_direct_stride2_7x7_quint8_dot( const uint8_t* src, const uint8_t* filter, const int32_t* bias, int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW, diff --git a/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.h b/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.h index caadc5e6..6e0f2e82 100644 --- a/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.h +++ b/dnn/src/arm_common/conv_bias/quint8/direct_dotprod.h @@ -8,9 +8,9 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ -#if __ARM_FEATURE_DOTPROD #include "src/arm_common/conv_bias/opr_impl.h" +#if MGB_ENABLE_DOT #include "src/fallback/conv_bias/common.h" namespace megdnn { diff --git a/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp b/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp index 5a005a4c..bf503a86 100644 --- a/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp +++ b/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#if __ARM_FEATURE_DOTPROD #include "src/arm_common/conv_bias/quint8/stride1_dotprod.h" +#if MGB_ENABLE_DOT #include "megdnn/oprs.h" #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" #include "src/arm_common/elemwise_op.h" diff --git a/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.h b/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.h index dc702f3a..85c14fc2 100644 --- a/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.h +++ b/dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.h @@ -8,10 +8,10 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#if __ARM_FEATURE_DOTPROD #pragma once #include "src/arm_common/conv_bias/opr_impl.h" +#if MGB_ENABLE_DOT namespace megdnn { namespace arm_common { diff --git a/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp b/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp index b06ba166..818dfd97 100644 --- a/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp +++ b/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp @@ -8,8 +8,8 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#if __ARM_FEATURE_DOTPROD #include "src/arm_common/conv_bias/quint8/stride2_dotprod.h" +#if MGB_ENABLE_DOT #include "megdnn/oprs.h" #include "src/arm_common/conv_bias/quint8/direct_dotprod.h" #include "src/arm_common/elemwise_op.h" diff --git a/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.h b/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.h index 5191cf9a..7cd9e64a 100644 --- a/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.h +++ b/dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.h @@ -8,10 +8,10 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ -#if __ARM_FEATURE_DOTPROD #pragma once #include "src/arm_common/conv_bias/opr_impl.h" +#if MGB_ENABLE_DOT namespace megdnn { namespace arm_common { diff --git a/dnn/src/arm_common/convolution/int8x8x32/algos.cpp b/dnn/src/arm_common/convolution/int8x8x32/algos.cpp index 12cb2979..ea9a2a76 100644 --- a/dnn/src/arm_common/convolution/int8x8x32/algos.cpp +++ b/dnn/src/arm_common/convolution/int8x8x32/algos.cpp @@ -13,21 +13,24 @@ #include "src/arm_common/convolution/int8x8x32/algos.h" #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h" #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h" +#include "src/common/opr_delegate.h" #include "midout.h" -#include "src/common/opr_delegate.h" MIDOUT_DECL(megdnn_arm_conv_int8832_kimpl) using namespace megdnn; using namespace arm_common; -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT /* ===================== ConvolutionBackwardData ===================== */ /* ===================== direct stride 1 algo ===================== */ bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::usable( fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + if (!cpuinfo_has_arm_neon_dot()){ + return false; + } return deconv::can_stride1_int8x8x32_dot(param); } @@ -57,6 +60,9 @@ ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::dispatch_kern( bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::usable( fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + if (!cpuinfo_has_arm_neon_dot()){ + return false; + } return deconv::can_stride2_int8x8x32_dot(param); } diff --git a/dnn/src/arm_common/convolution/int8x8x32/algos.h b/dnn/src/arm_common/convolution/int8x8x32/algos.h index 21f18d63..7a71a9de 100644 --- a/dnn/src/arm_common/convolution/int8x8x32/algos.h +++ b/dnn/src/arm_common/convolution/int8x8x32/algos.h @@ -17,7 +17,7 @@ namespace megdnn { namespace arm_common { -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT /* ===================== ConvolutionBackwardData ===================== */ class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1 final diff --git a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp index ccb0ae20..c03576be 100644 --- a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp +++ b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp @@ -9,11 +9,9 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ -#if __ARM_FEATURE_DOTPROD #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h" +#if MGB_ENABLE_DOT #include "src/common/utils.h" - -#include #include "src/arm_common/simd_macro/marm_neon.h" using namespace megdnn; @@ -94,6 +92,7 @@ inline int8x16_t vqtbl1q_s8_common(int8x16_t a, uint8x16_t index) { _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k1_idx, _elem); \ _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem); +MEGDNN_ATTRIBUTE_TARGET("dotprod") void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { MEGDNN_MARK_USED_VAR(IH); @@ -328,6 +327,7 @@ void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, } } +MEGDNN_ATTRIBUTE_TARGET("dotprod") void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { MEGDNN_MARK_USED_VAR(IH); @@ -530,6 +530,7 @@ void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst, _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k01_idx, _elem); \ _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem); +MEGDNN_ATTRIBUTE_TARGET("dotprod") void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { MEGDNN_MARK_USED_VAR(IH); @@ -777,6 +778,7 @@ void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst, } } +MEGDNN_ATTRIBUTE_TARGET("dotprod") void deconv_direct_7x7(const int8_t* src, const int8_t* filter, int32_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { MEGDNN_MARK_USED_VAR(IH); @@ -1070,6 +1072,7 @@ void deconv_direct_7x7(const int8_t* src, const int8_t* filter, int32_t* dst, } // anonymous namespace + size_t deconv::get_workspace_in_bytes_stride1_int8x8x32_dot( const NCBKernSizeParam& param) { return get_bundle(param).total_size_in_bytes(); diff --git a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h index 4b97420e..6c661639 100644 --- a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h +++ b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h @@ -10,8 +10,8 @@ */ #pragma once -#if __ARM_FEATURE_DOTPROD #include "src/arm_common/convolution/opr_impl.h" +#if MGB_ENABLE_DOT #include #include diff --git a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp index e98b7f44..8fbb0b85 100644 --- a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp +++ b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp @@ -9,11 +9,9 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
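For readers unfamiliar with the instruction these macros lean on: vdotq_s32 folds a 4-element signed int8 dot product into each of the four int32 lanes of the accumulator, which is exactly the shape of the inner loop in these deconv kernels. A self-contained sketch (the helper name is illustrative):

    #include <arm_neon.h>
    #include "megdnn/arch.h"  // for MEGDNN_ATTRIBUTE_TARGET

    MEGDNN_ATTRIBUTE_TARGET("dotprod")
    static inline int32x4_t dot_acc(int32x4_t acc, int8x16_t a, int8x16_t b) {
        // acc[i] += a[4i]*b[4i] + a[4i+1]*b[4i+1] + a[4i+2]*b[4i+2] + a[4i+3]*b[4i+3]
        return vdotq_s32(acc, a, b);
    }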
*/ -#if __ARM_FEATURE_DOTPROD #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h" +#if MGB_ENABLE_DOT #include "src/common/utils.h" - -#include #include "src/arm_common/simd_macro/marm_neon.h" using namespace megdnn; @@ -83,6 +81,7 @@ inline int8x16_t vqtbl1q_s8_common(int8x16_t a, uint8x16_t index) { _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem); template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { MEGDNN_MARK_USED_VAR(IH); @@ -334,6 +333,7 @@ void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst, } template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { MEGDNN_MARK_USED_VAR(IH); @@ -558,6 +558,7 @@ void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst, _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem); template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { MEGDNN_MARK_USED_VAR(IH); @@ -835,6 +836,7 @@ void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst, } template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void deconv_direct_7x7(const int8_t* src, const int8_t* filter, int32_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) { MEGDNN_MARK_USED_VAR(IH); diff --git a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h index 24d91411..b04b2ed2 100644 --- a/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h +++ b/dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h @@ -10,8 +10,8 @@ */ #pragma once -#if __ARM_FEATURE_DOTPROD #include "src/arm_common/convolution/opr_impl.h" +#if MGB_ENABLE_DOT #include #include diff --git a/dnn/src/arm_common/convolution/opr_impl.cpp b/dnn/src/arm_common/convolution/opr_impl.cpp index 9c5a9f95..0131037b 100644 --- a/dnn/src/arm_common/convolution/opr_impl.cpp +++ b/dnn/src/arm_common/convolution/opr_impl.cpp @@ -24,7 +24,7 @@ using namespace arm_common; /* ===================== ConvolutionBackwardData ===================== */ class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT AlgoSdot8DirectStride1 i8x8x32_direct_stride1_sdot; AlgoSdot8DirectStride2 i8x8x32_direct_stride2_sdot; AlgoUdot8DirectStride1 quint8_direct_stride1_udot; @@ -37,7 +37,7 @@ class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj { public: AlgoPack() { -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT m_all_algos.emplace_back(&i8x8x32_direct_stride1_sdot); m_all_algos.emplace_back(&i8x8x32_direct_stride2_sdot); m_all_algos.emplace_back(&quint8_direct_stride1_udot); diff --git a/dnn/src/arm_common/convolution/opr_impl.h b/dnn/src/arm_common/convolution/opr_impl.h index 9d088fda..d0e66122 100644 --- a/dnn/src/arm_common/convolution/opr_impl.h +++ b/dnn/src/arm_common/convolution/opr_impl.h @@ -56,7 +56,7 @@ public: MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl); private: -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT class AlgoSdot8DirectStride1; class AlgoSdot8DirectStride2; class AlgoUdot8DirectStride1; diff --git a/dnn/src/arm_common/convolution/quint8/algos.cpp b/dnn/src/arm_common/convolution/quint8/algos.cpp index 
c6ececa2..01a191be 100644 --- a/dnn/src/arm_common/convolution/quint8/algos.cpp +++ b/dnn/src/arm_common/convolution/quint8/algos.cpp @@ -14,6 +14,7 @@ #include "src/arm_common/convolution/quint8/conv_backdata_stride1.h" #include "src/arm_common/convolution/quint8/conv_backdata_stride2.h" #include "src/common/opr_delegate.h" + #include "midout.h" MIDOUT_DECL(megdnn_arm_conv_quint8_kimpl) @@ -21,7 +22,7 @@ MIDOUT_DECL(megdnn_arm_conv_quint8_kimpl) using namespace megdnn; using namespace arm_common; -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT /* ===================== ConvolutionBackwardData ===================== */ @@ -29,6 +30,10 @@ using namespace arm_common; bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::usable( fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + + if (!cpuinfo_has_arm_neon_dot()){ + return false; + } return deconv::can_stride1_quint8_dot(param); } @@ -58,6 +63,9 @@ ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::dispatch_kern( bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::usable( fallback::ConvolutionBackwardDataImpl*, const NCBKernSizeParam& param) const { + if (!cpuinfo_has_arm_neon_dot()){ + return false; + } return deconv::can_stride2_quint8_dot(param); } diff --git a/dnn/src/arm_common/convolution/quint8/algos.h b/dnn/src/arm_common/convolution/quint8/algos.h index e48d15b9..44b7d6c4 100644 --- a/dnn/src/arm_common/convolution/quint8/algos.h +++ b/dnn/src/arm_common/convolution/quint8/algos.h @@ -17,7 +17,7 @@ namespace megdnn { namespace arm_common { -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT /* ===================== ConvolutionBackwardData ===================== */ class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1 final : public AlgoBase { diff --git a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp index ab5ccbb7..b936bd74 100644 --- a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp +++ b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp @@ -9,11 +9,9 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
*/ -#if __ARM_FEATURE_DOTPROD #include "src/arm_common/convolution/quint8/conv_backdata_stride1.h" +#if MGB_ENABLE_DOT #include "src/common/utils.h" - -#include #include "src/arm_common/simd_macro/marm_neon.h" using namespace megdnn; @@ -109,6 +107,7 @@ inline uint8x16_t vqtbl1q_u8_common(uint8x16_t a, uint8x16_t index) { _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, uint8_t src_zp, uint8_t filter_zp, @@ -385,6 +384,7 @@ void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, } template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, uint8_t src_zp, uint8_t filter_zp, @@ -636,6 +636,7 @@ void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst, _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2); template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, uint8_t src_zp, uint8_t filter_zp, @@ -907,6 +908,7 @@ void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst, } template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, uint8_t src_zp, uint8_t filter_zp, @@ -1220,6 +1222,7 @@ void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst, } // anonymous namespace + size_t deconv::get_workspace_in_bytes_stride1_quint8_dot( const NCBKernSizeParam& param) { return get_bundle(param).total_size_in_bytes(); diff --git a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.h b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.h index 6e7c9ab7..ae642501 100644 --- a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.h +++ b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.h @@ -10,11 +10,8 @@ */ #pragma once -#if __ARM_FEATURE_DOTPROD #include "src/arm_common/convolution/opr_impl.h" - -#include -#include +#if MGB_ENABLE_DOT namespace megdnn { namespace arm_common { diff --git a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp index d0f032ad..77da3c20 100644 --- a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp +++ b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp @@ -9,11 +9,9 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
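The vsubq_u32(..., vdotq2_u32(_filter_zp, _elem)) steps in the quint8 macros above implement the zero-point correction for asymmetric quantization. With source zero point zA and filter zero point zB,

    sum_k (a_k - zA) * (b_k - zB)
        = sum_k a_k * b_k  -  zB * sum_k a_k  -  zA * sum_k b_k  +  K * zA * zB

so the kernels accumulate raw uint8 dot products and subtract the zero-point terms, which are themselves computed with dot instructions against a vector of replicated zero points. One correction step, isolated (vdotq2_u32 is the zero-accumulator helper this patch adds to marm_neon.h; the wrapper name is illustrative):

    MEGDNN_ATTRIBUTE_TARGET("dotprod")
    static inline uint32x4_t sub_filter_zp(uint32x4_t acc, uint8x16_t filter_zp,
                                           uint8x16_t elem) {
        // per lane i: acc[i] -= zB * (elem[4i] + elem[4i+1] + elem[4i+2] + elem[4i+3])
        return vsubq_u32(acc, vdotq2_u32(filter_zp, elem));
    }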
*/ -#if __ARM_FEATURE_DOTPROD #include "src/arm_common/convolution/quint8/conv_backdata_stride2.h" +#if MGB_ENABLE_DOT #include "src/common/utils.h" - -#include #include "src/arm_common/simd_macro/marm_neon.h" using namespace megdnn; @@ -110,6 +108,7 @@ inline uint8x16_t vqtbx1q_u8_common(uint8x16_t a, uint8x16_t t, _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem)); template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, uint8_t src_zp, uint8_t filter_zp, @@ -402,6 +401,7 @@ void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst, } template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, uint8_t src_zp, uint8_t filter_zp, @@ -673,6 +673,7 @@ void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst, _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2); template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, uint8_t src_zp, uint8_t filter_zp, @@ -972,6 +973,7 @@ void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst, } template +MEGDNN_ATTRIBUTE_TARGET("dotprod") void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst, size_t IH, size_t IW, size_t OH, size_t OW, size_t IC, uint8_t src_zp, uint8_t filter_zp, diff --git a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.h b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.h index 12068b1e..3822c14d 100644 --- a/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.h +++ b/dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.h @@ -10,11 +10,8 @@ */ #pragma once -#if __ARM_FEATURE_DOTPROD #include "src/arm_common/convolution/opr_impl.h" - -#include -#include +#if MGB_ENABLE_DOT namespace megdnn { namespace arm_common { diff --git a/dnn/src/arm_common/matrix_mul/algos.cpp b/dnn/src/arm_common/matrix_mul/algos.cpp index ef2e66ae..e2a2a055 100644 --- a/dnn/src/arm_common/matrix_mul/algos.cpp +++ b/dnn/src/arm_common/matrix_mul/algos.cpp @@ -14,8 +14,10 @@ #include "src/arm_common/matrix_mul/fp16/hgemv.h" #include "src/arm_common/matrix_mul/fp32/exec_sgemv.h" #include "src/arm_common/matrix_mul/int8/gemv.h" + #include "midout.h" + MIDOUT_DECL(megdnn_arm_hgemv) MIDOUT_DECL(megdnn_arm_exec_int8816) MIDOUT_DECL(megdnn_arm_exec_int8832) @@ -158,7 +160,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4::get_kern( return int8x8x32_gemv_mk4_kern; } -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT /* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */ namespace { void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) { @@ -176,6 +178,10 @@ void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) { bool MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::usable( const KernSizeParam& kern_size_param) const { + + if (!cpuinfo_has_arm_neon_dot()){ + return false; + } auto M = kern_size_param.M; auto N = kern_size_param.N; auto K = kern_size_param.K; diff --git a/dnn/src/arm_common/matrix_mul/algos.h b/dnn/src/arm_common/matrix_mul/algos.h index 89705b70..fb55f5fe 100644 --- a/dnn/src/arm_common/matrix_mul/algos.h +++ b/dnn/src/arm_common/matrix_mul/algos.h @@ -63,7 +63,7 @@ public: 
MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4) }; -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase { public: AlgoAttribute attribute() const override { diff --git a/dnn/src/arm_common/matrix_mul/int8/gemv.cpp b/dnn/src/arm_common/matrix_mul/int8/gemv.cpp index 361f2ba2..d2b81d42 100644 --- a/dnn/src/arm_common/matrix_mul/int8/gemv.cpp +++ b/dnn/src/arm_common/matrix_mul/int8/gemv.cpp @@ -9,7 +9,6 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#include #include "src/arm_common/simd_macro/marm_neon.h" #include "src/arm_common/matrix_mul/int8/gemv.h" #include "src/common/utils.h" @@ -21,7 +20,6 @@ MIDOUT_DECL(megdnn_arm_common_int8_gemv) using namespace megdnn; using namespace arm_common; -#if !__ARM_FEATURE_DOTPROD namespace { @@ -170,12 +168,11 @@ void gemv_naive_n_mk4(const int8_t* __restrict A, const int8_t* __restrict B, } } // namespace -#endif -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT namespace { - -void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, +MEGDNN_ATTRIBUTE_TARGET("dotprod") +void gemv_naive_n_dot(const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { megdnn_assert(N == 1 && Bstride == 1); @@ -244,7 +241,8 @@ void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B, } } -void gemv_naive_n_mk4(const int8_t* __restrict A, const int8_t* __restrict B, +MEGDNN_ATTRIBUTE_TARGET("dotprod") +void gemv_naive_n_mk4_dotprod(const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride) { constexpr size_t PACK_SIZE = 4; @@ -323,6 +321,7 @@ void gemv_naive_n_mk4(const int8_t* __restrict A, const int8_t* __restrict B, } } +MEGDNN_ATTRIBUTE_TARGET("dotprod") void gemv_naive_n_mk4_dot(const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, size_t M, size_t N, size_t K, size_t Astride, @@ -403,7 +402,16 @@ void arm_common::gemv_like(const int8_t* __restrict A, megdnn_assert(N == 1); MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, midout_iv("INT8_gemv_like"_hash)) { +#if MGB_ENABLE_DOT + if (cpuinfo_has_arm_neon_dot()) { + return gemv_naive_n_dot(A, B, C, M, N, K, Astride, Bstride, + Cstride); + } else { + return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride); + } +#else return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride); +#endif } MIDOUT_END(); } @@ -416,12 +424,22 @@ void arm_common::gemv_like_mk4(const int8_t* __restrict A, megdnn_assert(N == 1); MIDOUT_BEGIN(megdnn_arm_common_int8_gemv, midout_iv("INT8_gemv_like_mk4"_hash)) { +#if MGB_ENABLE_DOT + if (cpuinfo_has_arm_neon_dot()) { + return gemv_naive_n_mk4_dotprod(A, B, C, M, N, K, Astride, Bstride, + Cstride); + } else { + return gemv_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, + Cstride); + } +#else return gemv_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride); +#endif } MIDOUT_END(); } -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT void arm_common::gemv_like_mk4_dot(const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, size_t M, size_t N, @@ -437,4 +455,5 @@ void arm_common::gemv_like_mk4_dot(const int8_t* __restrict A, } #endif + // vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/matrix_mul/int8/gemv.h b/dnn/src/arm_common/matrix_mul/int8/gemv.h index 54c63d0f..13ff27b7 100644 --- 
a/dnn/src/arm_common/matrix_mul/int8/gemv.h +++ b/dnn/src/arm_common/matrix_mul/int8/gemv.h @@ -28,7 +28,7 @@ void gemv_like_mk4(const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride); -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT void gemv_like_mk4_dot(const int8_t* __restrict A, const int8_t* __restrict B, int32_t* __restrict C, size_t M, size_t N, size_t K, size_t Astride, size_t Bstride, size_t Cstride); diff --git a/dnn/src/arm_common/matrix_mul/opr_impl.cpp b/dnn/src/arm_common/matrix_mul/opr_impl.cpp index 771a11c0..9d09eed4 100644 --- a/dnn/src/arm_common/matrix_mul/opr_impl.cpp +++ b/dnn/src/arm_common/matrix_mul/opr_impl.cpp @@ -22,7 +22,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { #endif AlgoInt8x8x32Gemv int8x8x32_gemv; AlgoInt8x8x32GemvMK4 int8x8x32_gemv_mk4; -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot; #endif AlgoGevm gevm; @@ -37,7 +37,7 @@ public: #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC m_all_algos.emplace_back(&f16gemv); #endif -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT m_all_algos.emplace_back(&int8x8x32_gemv_mk4_dot); #endif m_all_algos.emplace_back(&int8x8x32_gemv); diff --git a/dnn/src/arm_common/matrix_mul/opr_impl.h b/dnn/src/arm_common/matrix_mul/opr_impl.h index 00deb601..cbdb120a 100644 --- a/dnn/src/arm_common/matrix_mul/opr_impl.h +++ b/dnn/src/arm_common/matrix_mul/opr_impl.h @@ -42,7 +42,7 @@ protected: #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC class AlgoF16Gemv; #endif -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT class AlgoInt8x8x32GemvMK4Dot;// Arm_common Int8x8x32 Gemv NCHW44_DOT #endif class AlgoInt8x8x16; // Arm_common Int 8x8x16 diff --git a/dnn/src/arm_common/neon_struct.h b/dnn/src/arm_common/neon_struct.h index a70f4542..43a4eace 100644 --- a/dnn/src/arm_common/neon_struct.h +++ b/dnn/src/arm_common/neon_struct.h @@ -69,9 +69,10 @@ struct Vfmaq_laneq_f32 { return vfmaq_laneq_f32(a, b, v, lane); } }; -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT struct Vdotq_laneq_s32 { template + MEGDNN_ATTRIBUTE_TARGET("dotprod") static __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { return vdotq_laneq_s32(a, b, v, lane); } @@ -82,4 +83,4 @@ struct Vdotq_laneq_s32 { } // namespace megdnn #undef __ai -// vim: syntax=cpp.doxygen \ No newline at end of file +// vim: syntax=cpp.doxygen diff --git a/dnn/src/arm_common/simd_macro/marm_neon.h b/dnn/src/arm_common/simd_macro/marm_neon.h index 8439c5fe..dcd8f9aa 100644 --- a/dnn/src/arm_common/simd_macro/marm_neon.h +++ b/dnn/src/arm_common/simd_macro/marm_neon.h @@ -10,7 +10,12 @@ * implied. 
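With the compile-time fork removed, the int8 gemv above keeps both kernel families in one binary and picks per call. A simplified view of the dispatch in gemv_like (midout bookkeeping and __restrict qualifiers omitted; names and signature follow the patch):

    void gemv_like(const int8_t* A, const int8_t* B, int32_t* C, size_t M,
                   size_t N, size_t K, size_t Astride, size_t Bstride,
                   size_t Cstride) {
    #if MGB_ENABLE_DOT
        if (cpuinfo_has_arm_neon_dot()) {  // take the SDOT kernel when available
            return gemv_naive_n_dot(A, B, C, M, N, K, Astride, Bstride, Cstride);
        }
    #endif
        return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
    }

The same shape applies to gemv_like_mk4, dispatching between gemv_naive_n_mk4_dotprod and gemv_naive_n_mk4.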
*/ #pragma once - +#if MGB_ENABLE_DOT +#if defined(__ARM_FEATURE_DOTPROD) +#undef __ARM_FEATURE_DOTPROD +#endif +#define __ARM_FEATURE_DOTPROD 1 +#endif #include #include "megdnn/arch.h" #include "src/common/unroll_macro.h" @@ -249,13 +254,14 @@ __ai float16x8_t vdupq_n_f16(__fp16 a) { #endif // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -#if __ARM_FEATURE_DOTPROD - +#if MGB_ENABLE_DOT +MEGDNN_ATTRIBUTE_TARGET("dotprod") __ai int32x4_t vdotq2_s32(int8x16_t a, int8x16_t b) { int32x4_t c = vdupq_n_s32(0); return vdotq_s32(c, a, b); } +MEGDNN_ATTRIBUTE_TARGET("dotprod") __ai uint32x4_t vdotq2_u32(uint8x16_t a, uint8x16_t b) { uint32x4_t c = vdupq_n_u32(0); return vdotq_u32(c, a, b); @@ -275,11 +281,13 @@ __ai uint32x4_t vdotq2_u32(uint8x16_t a, uint8x16_t b) { c; \ }) +MEGDNN_ATTRIBUTE_TARGET("dotprod") __ai int32x2_t vdot2_s32(int8x8_t a, int8x8_t b) { int32x2_t c = vdup_n_s32(0); return vdot_s32(c, a, b); } +MEGDNN_ATTRIBUTE_TARGET("dotprod") __ai uint32x2_t vdot2_u8(uint8x8_t a, uint8x8_t b) { uint32x2_t c = vdup_n_u32(0); return vdot_u32(c, a, b); @@ -298,8 +306,7 @@ __ai uint32x2_t vdot2_u8(uint8x8_t a, uint8x8_t b) { c = vdot_lane_u32(c, a, b, lane); \ c; \ }) - -#endif // __ARM_FEATURE_DOTPROD +#endif // MGB_ENABLE_DOT #if __GNUC__ < 8 #undef vld1q_f32_x2 @@ -575,7 +582,7 @@ struct Vfmsq_laneq_f32_armv7<3> { #define vfmsq_laneq_f32(a, b, v, lane) \ Vfmsq_laneq_f32_armv7::impl(a, b, v) -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT namespace { template struct Vdotq_laneq_s32_armv7 { @@ -583,24 +590,28 @@ struct Vdotq_laneq_s32_armv7 { }; template <> struct Vdotq_laneq_s32_armv7<0> { + MEGDNN_ATTRIBUTE_TARGET("dotprod") __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { return vdotq_lane_s32(a, b, vget_low_s32(v), 0); } }; template <> struct Vdotq_laneq_s32_armv7<1> { + MEGDNN_ATTRIBUTE_TARGET("dotprod") __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { return vdotq_lane_s32(a, b, vget_low_s32(v), 1); } }; template <> struct Vdotq_laneq_s32_armv7<2> { + MEGDNN_ATTRIBUTE_TARGET("dotprod") __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { return vdotq_lane_s32(a, b, vget_high_s32(v), 0); } }; template <> struct Vdotq_laneq_s32_armv7<3> { + MEGDNN_ATTRIBUTE_TARGET("dotprod") __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) { return vdotq_lane_s32(a, b, vget_high_f32(v), 1); } @@ -765,7 +776,9 @@ __ai float32x4_t Vfmsq_f32(float32x4_t& a, float32x4_t& b, float32x4_t& v) { :); return a; } - +#if MGB_ENABLE_DOT +#undef __ARM_FEATURE_DOTPROD +#endif #undef __ai #pragma GCC diagnostic pop diff --git a/dnn/src/armv7/matrix_mul/algos.cpp b/dnn/src/armv7/matrix_mul/algos.cpp index f0834a88..8d4f6397 100644 --- a/dnn/src/armv7/matrix_mul/algos.cpp +++ b/dnn/src/armv7/matrix_mul/algos.cpp @@ -19,6 +19,9 @@ #include "src/armv7/matrix_mul/quint8/strategy.h" #include "src/common/utils.h" #include "src/fallback/matrix_mul/gemm_impl.h" +#if MGB_ENABLE_CPUINFO +#include "cpuinfo.h" +#endif #include "midout.h" @@ -744,7 +747,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x4x1, armv7::matmul::gemm_s16x16x32_12x4, int16_t, int32_t, AlgoDataType::INT16X16X32, DEFAULT); -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT /* ===================== Int8 K6x8x4 algo ===================== */ namespace { void int8_k6x8x4_kern(const MatrixMulImpl::KernParam& kern_param) { @@ -769,6 +772,9 @@ void int8_k6x8x4_kern(const MatrixMulImpl::KernParam& kern_param) { bool MatrixMulImpl::AlgoInt8x8x32K6x8x4::usable( const KernSizeParam& kern_size_param) const { + if 
(!cpuinfo_has_arm_neon_dot()){ + return false; + } return can_be_treated_as_int8x8x32(kern_size_param); } @@ -827,6 +833,9 @@ void quint8_dot_k4x8x4_kern(const MatrixMulImpl::KernParam& kern_param) { bool MatrixMulImpl::AlgoQuint8DotK4x8x4::usable( const KernSizeParam& kern_size_param) const { + if (!cpuinfo_has_arm_neon_dot()){ + return false; + } return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm && kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm && kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 && @@ -891,6 +900,9 @@ void int8_mk4_8x4x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { bool MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::usable( const KernSizeParam& kern_size_param) const { + if (!cpuinfo_has_arm_neon_dot()){ + return false; + } return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() && (kern_size_param.A_type.enumv() == DTypeEnum::Int8 || kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) && diff --git a/dnn/src/armv7/matrix_mul/algos.h b/dnn/src/armv7/matrix_mul/algos.h index 7c4ca669..dcd55463 100644 --- a/dnn/src/armv7/matrix_mul/algos.h +++ b/dnn/src/armv7/matrix_mul/algos.h @@ -86,7 +86,7 @@ public: MEGDNN_DECL_ALGO_TYPE(ARMV7_F16_MK8_4X8) }; #endif -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT class MatrixMulImpl::AlgoInt8x8x32K6x8x4 final : public AlgoBase { public: AlgoAttribute attribute() const override { diff --git a/dnn/src/armv7/matrix_mul/asm/common.h b/dnn/src/armv7/matrix_mul/asm/common.h index 2dd75734..1b665a56 100644 --- a/dnn/src/armv7/matrix_mul/asm/common.h +++ b/dnn/src/armv7/matrix_mul/asm/common.h @@ -10,7 +10,6 @@ * implied. */ #pragma once -#include #include #include #include diff --git a/dnn/src/armv7/matrix_mul/fp32/strategy_4x12.cpp b/dnn/src/armv7/matrix_mul/fp32/strategy_4x12.cpp index 502a78e6..c8c00643 100644 --- a/dnn/src/armv7/matrix_mul/fp32/strategy_4x12.cpp +++ b/dnn/src/armv7/matrix_mul/fp32/strategy_4x12.cpp @@ -10,7 +10,6 @@ * implied. */ -#include "src/arm_common/simd_macro/marm_neon.h" #include "src/armv7/matrix_mul/asm/common.h" #include "src/armv7/matrix_mul/fp32/strategy.h" #include "src/common/utils.h" diff --git a/dnn/src/armv7/matrix_mul/int8/kernel_6x8x4.h b/dnn/src/armv7/matrix_mul/int8/kernel_6x8x4.h index ac0c7efb..0b7de2e7 100644 --- a/dnn/src/armv7/matrix_mul/int8/kernel_6x8x4.h +++ b/dnn/src/armv7/matrix_mul/int8/kernel_6x8x4.h @@ -9,7 +9,7 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */ -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT #include "src/arm_common/simd_macro/marm_neon.h" #include "src/armv7/matrix_mul/asm/common.h" @@ -43,6 +43,7 @@ namespace matmul_dot_6x8x4 { // // Accumulator +MEGDNN_ATTRIBUTE_TARGET("dotprod") static void kern_6x8(const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, bool is_first_k, size_t m_remain = 6) { @@ -274,6 +275,7 @@ static void kern_6x8(const int8_t* packA, const int8_t* packB, int K, // // Accumulator +MEGDNN_ATTRIBUTE_TARGET("dotprod") static void kern_6x4(const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, bool is_first_k, size_t n_remain = 8, size_t m_remain = 6) { diff --git a/dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h b/dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h index 84626797..d28ec559 100644 --- a/dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h +++ b/dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h @@ -10,7 +10,7 @@ * implied. 
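The marm_neon.h hunk above is the piece that makes per-function targeting practical: toolchains typically declare the vdot* intrinsics in arm_neon.h only when __ARM_FEATURE_DOTPROD is defined, so the header forces the macro on before the include and retracts it at the end of the file. Condensed view (this assumes that gating; the actual header carries many more helpers in between):

    #if MGB_ENABLE_DOT
    #if defined(__ARM_FEATURE_DOTPROD)
    #undef __ARM_FEATURE_DOTPROD
    #endif
    #define __ARM_FEATURE_DOTPROD 1   // make the vdot* prototypes visible
    #endif
    #include <arm_neon.h>

    // ... dot helpers such as vdotq2_s32/vdotq2_u32, each tagged
    // MEGDNN_ATTRIBUTE_TARGET("dotprod") ...

    #if MGB_ENABLE_DOT
    #undef __ARM_FEATURE_DOTPROD      // do not leak the forced define to users
    #endif

Every caller still carries its own "dotprod" target attribute, so dot instructions stay confined to functions that are only invoked after the runtime check.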
*/ -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT #include "src/arm_common/simd_macro/marm_neon.h" #include "src/armv7/matrix_mul/asm/common.h" @@ -42,7 +42,7 @@ namespace matmul_mk4_dot_8x4x4 { // |q14[0-4]| // +--------+ // Accumulator - +MEGDNN_ATTRIBUTE_TARGET("dotprod") static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, bool is_first_k, int n_remain) { K /= 4; @@ -211,6 +211,7 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K, // +--------+ // Accumulator +MEGDNN_ATTRIBUTE_TARGET("dotprod") static void kern_4x4(const int8_t* packA, const int8_t* packB, int K, int32_t* output, int LDC, bool is_first_k, int n_remain) { K /= 4; diff --git a/dnn/src/armv7/matrix_mul/int8/strategy.cpp b/dnn/src/armv7/matrix_mul/int8/strategy.cpp index 57d05559..69be5248 100644 --- a/dnn/src/armv7/matrix_mul/int8/strategy.cpp +++ b/dnn/src/armv7/matrix_mul/int8/strategy.cpp @@ -175,7 +175,7 @@ void gemm_s8_4x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M, } } -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT // ===========================gemm_s8_6x8====================================== MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_dots8_6x8); void gemm_dots8_6x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0, diff --git a/dnn/src/armv7/matrix_mul/int8/strategy.h b/dnn/src/armv7/matrix_mul/int8/strategy.h index 2f7302b2..66ddcdd3 100644 --- a/dnn/src/armv7/matrix_mul/int8/strategy.h +++ b/dnn/src/armv7/matrix_mul/int8/strategy.h @@ -23,7 +23,7 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 8, 8, false, true, MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 2, 16, false, false, gemm_mk4_s8_4x2); -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 6, 8, 4, false, false, gemm_dots8_6x8); diff --git a/dnn/src/armv7/matrix_mul/opr_impl.cpp b/dnn/src/armv7/matrix_mul/opr_impl.cpp index c0cc21e9..a24ea6d0 100644 --- a/dnn/src/armv7/matrix_mul/opr_impl.cpp +++ b/dnn/src/armv7/matrix_mul/opr_impl.cpp @@ -27,7 +27,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoF16K4x16x1 f16_k4x16x1; AlgoF16MK8_4x8 f16_mk8_4x8; #endif -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT AlgoInt8x8x32K6x8x4 int8_k6x8x4; AlgoQuint8DotK4x8x4 quint8_k4x8x4; AlgoInt8x8x32MK4_8x4x4DotProd int8x8x32_mk4_8x4x4_dotprod; @@ -57,7 +57,7 @@ public: m_all_algos.emplace_back(&f16_k4x16x1); m_all_algos.emplace_back(&f16_mk8_4x8); #endif -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT m_all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod); m_all_algos.emplace_back(&int8_k6x8x4); m_all_algos.emplace_back(&quint8_k4x8x4); diff --git a/dnn/src/armv7/matrix_mul/opr_impl.h b/dnn/src/armv7/matrix_mul/opr_impl.h index 085bc8c2..39b60346 100644 --- a/dnn/src/armv7/matrix_mul/opr_impl.h +++ b/dnn/src/armv7/matrix_mul/opr_impl.h @@ -49,7 +49,7 @@ private: class AlgoF16K4x16x1; // Armv7 F16 Kernel 4x16x1 class AlgoF16MK8_4x8; // Armv7 F16 MK8 Format block 4x8 #endif -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT class AlgoInt8x8x32K6x8x4; // Armv7 Int8 Kernel 6x8x4 class AlgoQuint8DotK4x8x4; // Armv7 Quint8 Kernel 6x8x4 class AlgoInt8x8x32MK4_8x4x4DotProd; // Armv7 nchw44 Int8x8x32 Kernel 8x4x4 diff --git a/dnn/src/armv7/matrix_mul/quint8/kernel_dot_4x8x4.h b/dnn/src/armv7/matrix_mul/quint8/kernel_dot_4x8x4.h index c99eb877..35439e14 100644 --- a/dnn/src/armv7/matrix_mul/quint8/kernel_dot_4x8x4.h +++ b/dnn/src/armv7/matrix_mul/quint8/kernel_dot_4x8x4.h @@ -9,7 +9,7 @@ * "AS IS" BASIS, WITHOUT ARRANTIES OR 
CONDITIONS OF ANY KIND, either express or implied. */ -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT #include "src/arm_common/simd_macro/marm_neon.h" #include "src/armv7/matrix_mul/asm/common.h" @@ -41,7 +41,7 @@ namespace matmul_dot_4x8x4 { // +-------+-------+ - - - - +--------+--------+--------+ // // Accumulator - +MEGDNN_ATTRIBUTE_TARGET("dotprod") static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC, bool is_first_k, uint8_t zA, uint8_t zB, uint32_t zAB, size_t m_remain = 4) { @@ -257,6 +257,7 @@ static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K, // +-------+-------+ - - - - +--------+--------+--------+ // // Accumulator +MEGDNN_ATTRIBUTE_TARGET("dotprod") static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K, int32_t* output, int LDC, bool is_first_k, uint8_t zA, uint8_t zB, uint32_t zAB, size_t m_remain = 4, diff --git a/dnn/src/armv7/matrix_mul/quint8/strategy.cpp b/dnn/src/armv7/matrix_mul/quint8/strategy.cpp index 87a93890..d32f186f 100644 --- a/dnn/src/armv7/matrix_mul/quint8/strategy.cpp +++ b/dnn/src/armv7/matrix_mul/quint8/strategy.cpp @@ -88,7 +88,7 @@ void gemm_u8_4x8::kern(const dt_uint8* packA, const dt_uint8* packB, size_t M, } } -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT // ===========================gemm_dot_quint8_4x8====================================== MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_dot_quint8_4x8); void gemm_dot_quint8_4x8::pack_A(dt_uint8* out, const dt_uint8* in, int ldin, diff --git a/dnn/src/armv7/matrix_mul/quint8/strategy.h b/dnn/src/armv7/matrix_mul/quint8/strategy.h index 170a1ff0..b8833698 100644 --- a/dnn/src/armv7/matrix_mul/quint8/strategy.h +++ b/dnn/src/armv7/matrix_mul/quint8/strategy.h @@ -17,7 +17,7 @@ namespace matmul { MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 4, 8, 8, false, true, gemm_u8_4x8); -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 4, 8, 4, false, false, gemm_dot_quint8_4x8); #endif diff --git a/dnn/src/common/utils.h b/dnn/src/common/utils.h index 2937534c..3af0ccdf 100644 --- a/dnn/src/common/utils.h +++ b/dnn/src/common/utils.h @@ -60,6 +60,13 @@ #include #endif + +#if MEGDNN_AARCH64 || MEGDNN_ARMV7 +#if MGB_ENABLE_CPUINFO +#include "cpuinfo.h" +#endif +#endif + #if __cplusplus >= 201703L || __clang_major__ >= 4 #define MEGDNN_FALLTHRU [[fallthrough]]; #elif __GNUC__ >= 7 diff --git a/dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp b/dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp index a3592472..36976d76 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp +++ b/dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp @@ -148,7 +148,7 @@ struct GemvLike { } }; -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT template struct GemvLike { inline static void do_gemv(const stype* A, const stype* B, btype* C, diff --git a/dnn/test/aarch64/matrix_mul.cpp b/dnn/test/aarch64/matrix_mul.cpp index 0bef6a0a..139ceb1e 100644 --- a/dnn/test/aarch64/matrix_mul.cpp +++ b/dnn/test/aarch64/matrix_mul.cpp @@ -87,7 +87,7 @@ TEST_F(AARCH64, MATRIX_MUL_F16_MK8) { } #endif -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT TEST_F(AARCH64, MATRIX_MUL_INT8X8X32_K8X12X4_DOTPROD) { matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{}, handle(), "AARCH64_INT8X8X32_K8X12X4_DOTPROD"); @@ -690,7 +690,7 @@ TEST_F(AARCH64, BENCHMARK_GEMV) { run(M, K, N); } -#if __ARM_FEATURE_DOTPROD +#if MGB_ENABLE_DOT TEST_F(AARCH64, 
diff --git a/dnn/test/aarch64/matrix_mul.cpp b/dnn/test/aarch64/matrix_mul.cpp
index 0bef6a0a..139ceb1e 100644
--- a/dnn/test/aarch64/matrix_mul.cpp
+++ b/dnn/test/aarch64/matrix_mul.cpp
@@ -87,7 +87,7 @@ TEST_F(AARCH64, MATRIX_MUL_F16_MK8) {
 }
 #endif
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(AARCH64, MATRIX_MUL_INT8X8X32_K8X12X4_DOTPROD) {
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
                                  handle(), "AARCH64_INT8X8X32_K8X12X4_DOTPROD");
@@ -690,7 +690,7 @@ TEST_F(AARCH64, BENCHMARK_GEMV) {
         run(M, K, N);
 }
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(AARCH64, BENCHMARK_TRANSPOSED_MATRIX_MUL_INT_8X8X32) {
     constexpr size_t RUNS = 50;
     param::MatrixMul param;
@@ -803,7 +803,7 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT8X8X32_MK4_8X12X4) {
         std::cout << std::endl;
     }
 }
-#endif  // __ARM_FEATURE_DOTPROD
+#endif  // MGB_ENABLE_DOT
 
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_F16_MK8) {
diff --git a/dnn/test/arm_common/conv_bias.cpp b/dnn/test/arm_common/conv_bias.cpp
index 08d8add6..3cf3414c 100644
--- a/dnn/test/arm_common/conv_bias.cpp
+++ b/dnn/test/arm_common/conv_bias.cpp
@@ -166,7 +166,7 @@ static void benchmark_convbias(Handle* handle, std::string int_name,
                 .set_display(false);
     }
     auto nchw44_algo_regx = ".*(DIRECT|NCHW_NCHW44).*";
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     if (!is_fp32) {
         nchw44_algo_regx = ".*DOT.*";
     }
@@ -1852,7 +1852,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QINT8_STRIDE1_NCHW44) {
 
 #endif
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 #if MEGDNN_WITH_BENCHMARK
 TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) {
     // have to remove preferred restrict in usable func before run the benchmark
@@ -2440,7 +2440,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_QUANTIZEDSYM) {
     dtype::QuantizedS8 stype(2.5f);
     dtype::QuantizedS32 dtype(6.25f);
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     benchmark_conv1x1("AARCH64_INT8X8X32_K8X12X4_DOTPROD", handle(), stype,
                       dtype, dtype, dtype);
 #else
@@ -2460,7 +2460,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_QUANTIZEDASYM) {
     dtype::QuantizedS32 dtype(1.2 * 1.2);
 
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     benchmark_conv1x1("AARCH64_QUINT8_K8X8X4_DOTPROD", handle(), stype, dtype,
                       dtype, dtype);
 #else
@@ -2565,7 +2565,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_GEMV_FP32) {
     }
 }
 
-#ifndef __ARM_FEATURE_DOTPROD
+//! the non-dot algos are always enabled now
 TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_1X1_S1_NCHW_VS_NCHW44_INT8x8x32) {
     std::vector conv_bias_1x1_args_nchw44 =
             get_conv_bias_1x1_benchmark_args(4);
@@ -2634,7 +2634,6 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_1X1_S1_NCHW_VS_NCHW44_INT8x8x32) {
                    computations / conv1x1_nchw44,
                    conv1x1_nchw / conv1x1_nchw44);
     }
 }
-#endif
 
 TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_WINOGRAD_VS_IM2COL_INT8) {
     auto&& args = get_winograd_benchmark_args(3, 8);
diff --git a/dnn/test/arm_common/conv_bias_multi_thread.cpp b/dnn/test/arm_common/conv_bias_multi_thread.cpp
index fe7deb97..e76ef63c 100644
--- a/dnn/test/arm_common/conv_bias_multi_thread.cpp
+++ b/dnn/test/arm_common/conv_bias_multi_thread.cpp
@@ -500,7 +500,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) {
 }
 
 /****************************dot qint8 direct*************************/
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
     auto args = get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE,
                                           BR_AND_NO_BIASMODE, 2, false, true);
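The kernel files earlier in the patch annotate each dot kernel with `MEGDNN_ATTRIBUTE_TARGET("dotprod")`, which is what lets those functions be compiled into a binary whose global `-march` does not enable the extension. The macro presumably wraps the GCC/Clang function-level target attribute; a sketch under that assumption (the actual definition lives in the patched megdnn common headers, and a toolchain where NEON intrinsics honor per-function target attributes is assumed):

    #include <arm_neon.h>

    // presumed shape of the helper macro on GCC/Clang; expands to nothing
    // on compilers without function-level target attributes
    #if defined(__GNUC__) || defined(__clang__)
    #define MEGDNN_ATTRIBUTE_TARGET(simd) __attribute__((target(simd)))
    #else
    #define MEGDNN_ATTRIBUTE_TARGET(simd)
    #endif

    // only this function is compiled with dotprod codegen enabled; the rest
    // of the translation unit keeps the baseline architecture level
    MEGDNN_ATTRIBUTE_TARGET("dotprod")
    int32x4_t dot_acc(int32x4_t acc, int8x16_t a, int8x16_t b) {
        // SDOT: each int32 lane accumulates a 4-way int8 dot product
        return vdotq_s32(acc, a, b);
    }

This per-function targeting is also why the guards throughout the diff can change from `__ARM_FEATURE_DOTPROD`, which the compiler defines only when the global `-march` enables the extension, to the build-level `MGB_ENABLE_DOT` switch.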
diff --git a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp
index 157dda11..3d4d4ff6 100644
--- a/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp
+++ b/dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp
@@ -655,7 +655,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44) {
     bench_case(1, 512, 256, 28, 28, 3, 4, 1, 2);
 }
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT) {
     constexpr size_t RUNS = 40;
     std::vector data_type = {
@@ -892,7 +892,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
     benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}},
                    {1, {4}}, data_type);
 }
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
        BENCHMARK_CONVBIAS_INT8_INT8_INT8_STRIDE1_WITHDOTPROD) {
     constexpr size_t RUNS = 50;
@@ -1157,7 +1157,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
     benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}},
                    {1, {4}}, data_type);
 }
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
       BENCHMARK_CONVBIAS_QUINT8_QUINT8_QUINT8_STRIDE1_WITHDOTPROD) {
     constexpr size_t RUNS = 50;
@@ -1977,7 +1977,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
     dtype::QuantizedS32 btype(0.04f);
     dtype::Quantized8Asymm dtype(1.4f, 110);
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     conv1x1_multithread_benchmark("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:8",
                                   stype, ftype, btype, dtype);
 #else
diff --git a/dnn/test/arm_common/conv_bias_multi_thread_conv1x1.cpp b/dnn/test/arm_common/conv_bias_multi_thread_conv1x1.cpp
index 74e6c1dc..e8a37690 100644
--- a/dnn/test/arm_common/conv_bias_multi_thread_conv1x1.cpp
+++ b/dnn/test/arm_common/conv_bias_multi_thread_conv1x1.cpp
@@ -20,7 +20,7 @@ using namespace megdnn;
 using namespace test;
 using namespace conv_bias;
 
-#ifdef __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CONV1x1_QUANTIZEDSYM_MK4_DOT) {
     UniformIntRNG rng{-50, 50};
 
@@ -138,7 +138,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM) {
             dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),            \
             dtype::QuantizedS8(60.25f), name);
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:24");
 #else
     cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
@@ -174,7 +174,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) {
             name);
     float epsilon = 0.001;
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:48");
 #else
     cb("CONV1x1:AARCH64_QUINT8_K8X8X8:24");
@@ -210,13 +210,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32) {
             dtype::QuantizedS32(1.2 * 1.3), {}, name);
 
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:24");
 #else
     cb("CONV1x1:AARCH64_QUINT8_K8X8X8:48");
 #endif
 #elif MEGDNN_ARMV7
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH32_QUINT8_K4X8X4:48");
 #endif
     cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24");
@@ -287,14 +287,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) {
 #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
 
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:48");
 #else
     cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
     cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:24");
 #endif
 #elif MEGDNN_ARMV7
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH32_INT8_K6X8X4:48");
 #endif
     cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:24");
@@ -312,8 +312,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) {
     }
     checker_conv_bias_mul_int8x8x32(gemv_args, handle(), "CONV1x1_GEMV");
 }
-
-#ifndef __ARM_FEATURE_DOTPROD
+//! the non-dot algos are always enabled now
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) {
     using namespace conv_bias;
     std::vector args =
@@ -345,7 +344,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) {
 #endif
 #undef cb
 }
-#endif
 
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44) {
     using namespace conv_bias;
@@ -364,7 +362,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44) {
             "CONV1x1_GEMV");
 }
 
-#ifdef __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44_DOT) {
     using namespace conv_bias;
     std::vector args = get_nchw44_conv_bias_args(
diff --git a/dnn/test/arm_common/conv_bias_multi_thread_im2col.cpp b/dnn/test/arm_common/conv_bias_multi_thread_im2col.cpp
index debfb1f7..9cbe7879 100644
--- a/dnn/test/arm_common/conv_bias_multi_thread_im2col.cpp
+++ b/dnn/test/arm_common/conv_bias_multi_thread_im2col.cpp
@@ -135,7 +135,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) {
     float epsilon = 0.001;
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD");
 #else
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8");
@@ -148,7 +148,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) {
 #undef cb
 }
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT) {
     UniformIntRNG rng{-50, 50};
 
@@ -173,6 +173,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT) {
 #if MEGDNN_AARCH64
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
 #elif MEGDNN_ARMV7
+    epsilon = 1;
     cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X4X4_DOTPROD:96");
 #endif
 #undef cb
@@ -194,6 +195,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS,
 #if MEGDNN_AARCH64
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
 #elif MEGDNN_ARMV7
+    epsilon = 1;
     cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X4X4_DOTPROD:96");
 #endif
 #undef cb
@@ -273,7 +275,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) {
             dtype::Quantized8Asymm(50.3f, (uint8_t)120), name);
     float epsilon = 0.001;
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD");
 #else
     cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8");
@@ -305,13 +307,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) {
             dtype::QuantizedS32(1.2 * 1.3), {}, name);
 
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD");
 #else
     cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8");
 #endif
 #elif MEGDNN_ARMV7
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH32_QUINT8_K4X8X4");
 #endif
     cb("IM2COLMATMUL:ARMV7_QUINT8_K4X8X8");
@@ -392,7 +394,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP16) {
 #endif
 
 #if MEGDNN_AARCH64 || MEGDNN_ARMV7
-#if !__ARM_FEATURE_DOTPROD
+//! the non-dot algos are always enabled now
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) {
     using namespace conv_bias;
     std::vector args = get_nchw44_conv_bias_args(
@@ -483,10 +485,9 @@ TEST_F(ARM_COMMON_MULTI_THREADS,
 #endif
 #endif
 
-#endif
 
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44DOT_FUSE) {
     UniformIntRNG rng{-50, 50};
@@ -516,14 +517,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
 #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);
 
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD");
 #else
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8");
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16");
 #endif
 #elif MEGDNN_ARMV7
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH32_INT8_K6X8X4");
 #endif
     cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8");
diff --git a/dnn/test/arm_common/conv_bias_multi_thread_weight_preprocess.cpp b/dnn/test/arm_common/conv_bias_multi_thread_weight_preprocess.cpp
index 13f23fac..eb44453e 100644
--- a/dnn/test/arm_common/conv_bias_multi_thread_weight_preprocess.cpp
+++ b/dnn/test/arm_common/conv_bias_multi_thread_weight_preprocess.cpp
@@ -480,7 +480,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_PREPROCESS)
     float epsilon = 0.001;
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD");
 #else
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8");
@@ -494,7 +494,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_PREPROCESS)
 }
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT_PREPROCESS) {
 
@@ -520,6 +520,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS,
 #if MEGDNN_AARCH64
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
 #elif MEGDNN_ARMV7
+    epsilon = 1;
     cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X4X4_DOTPROD:96");
 #endif
 #undef cb
@@ -604,7 +605,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS,
             dtype::Quantized8Asymm(50.3f, (uint8_t)120), name);
     float epsilon = 0.001;
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD");
 #else
     cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8");
@@ -639,13 +640,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS,
             dtype::QuantizedS32(1.2 * 1.3), {}, name);
 
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD");
 #else
     cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8");
 #endif
 #elif MEGDNN_ARMV7
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH32_QUINT8_K4X8X4");
 #endif
     cb("IM2COLMATMUL:ARMV7_QUINT8_K4X8X8");
@@ -732,9 +733,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP16_FILTERPREPROCESS) {
 #endif
 
 #if MEGDNN_AARCH64 || MEGDNN_ARMV7
-#if !__ARM_FEATURE_DOTPROD
-
-
+//! the non-dot algos are always enabled now
 TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2_PREPROCESS) {
     using namespace conv_bias;
@@ -830,10 +829,9 @@ TEST_F(ARM_COMMON_MULTI_THREADS,
 #endif
 #endif
 
-#endif
 
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_MULTI_THREADS,
       CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44DOT_FUSE_PREPROCESS) {
 
@@ -867,14 +865,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS,
 #define cb(name) checker_conv_bias_int8x8x32_preprocess(args, handle(), name);
 
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD");
 #else
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8");
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16");
 #endif
 #elif MEGDNN_ARMV7
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH32_INT8_K6X8X4");
 #endif
     cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8");
@@ -987,7 +985,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM_PREPROCESS) {
             dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),            \
             dtype::QuantizedS8(60.25f), name);
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:24");
 #else
     cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
@@ -1014,7 +1012,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM_PREPROCESS) {
             name);
     float epsilon = 0.001;
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:48");
 #else
     cb("CONV1x1:AARCH64_QUINT8_K8X8X8:24");
@@ -1040,13 +1038,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32_PREPROCESS) {
             dtype::QuantizedS32(1.2 * 1.3), {}, name);
 
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:24");
 #else
     cb("CONV1x1:AARCH64_QUINT8_K8X8X8:48");
 #endif
 #elif MEGDNN_ARMV7
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH32_QUINT8_K4X8X4:48");
 #endif
     cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24");
@@ -1083,14 +1081,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_PREPROCESS) {
 #define cb(name) checker_conv_bias_int8x8x32_preprocess(args, handle(), name);
 
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:48");
 #else
     cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
     cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:24");
 #endif
 #elif MEGDNN_ARMV7
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH32_INT8_K6X8X4:48");
 #endif
     cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:24");
@@ -1102,7 +1100,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_PREPROCESS) {
 #undef cb
 }
 
-#ifndef __ARM_FEATURE_DOTPROD
+//! the non-dot algos are always enabled now
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4_PREPROCESS) {
     using namespace conv_bias;
     std::vector args = get_nchw44_conv_bias_args(
@@ -1135,6 +1133,5 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4_PREPROCESS) {
 #undef cb
 }
 
-#endif
 
 // vim: syntax=cpp.doxygen
diff --git a/dnn/test/arm_common/convolution.cpp b/dnn/test/arm_common/convolution.cpp
index d0872291..0357b822 100644
--- a/dnn/test/arm_common/convolution.cpp
+++ b/dnn/test/arm_common/convolution.cpp
@@ -20,7 +20,7 @@ using namespace test;
 
 using Param = param::Convolution;
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON, CONVOLUTION_BACKWARD_DATA_INT8_INT8_INT32) {
     Checker checker(handle());
     using Param = ConvolutionBackwardData::Param;
@@ -144,7 +144,7 @@ TEST_F(ARM_COMMON, CONVOLUTION_BACKWARD_DATA_QUINT8) {
 #endif
 
 #if MEGDNN_WITH_BENCHMARK
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON, BENCHMARK_CONVOLUTION_STRIDE1_I8x8x32_WITHDOTPROD) {
     using namespace convolution;
     using Param = param::Convolution;
diff --git a/dnn/test/arm_common/matrix_mul.cpp b/dnn/test/arm_common/matrix_mul.cpp
index 0f15e8f7..1208c154 100644
--- a/dnn/test/arm_common/matrix_mul.cpp
+++ b/dnn/test/arm_common/matrix_mul.cpp
@@ -16,6 +16,10 @@
 #include "test/common/matrix_mul.h"
 #include "test/common/rng.h"
 
+#if MGB_ENABLE_CPUINFO
+#include "cpuinfo.h"
+#endif
+
 using namespace megdnn;
 using namespace test;
 
@@ -196,8 +200,9 @@ TEST_F(ARM_COMMON, QINT8x8x32_GEMV_MK4) {
         run(M, K, 1);
 }
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON, QINT8x8x32_GEMV_MK4_DOT) {
+
     Checker checker(handle());
     using Param = MatrixMul::Param;
 
diff --git a/dnn/test/armv7/matrix_mul.cpp b/dnn/test/armv7/matrix_mul.cpp
index 7d014ee9..53943b7d 100644
--- a/dnn/test/armv7/matrix_mul.cpp
+++ b/dnn/test/armv7/matrix_mul.cpp
@@ -88,7 +88,7 @@ TEST_F(ARMV7, MATRIX_MUL_F16_MK8) {
 }
 #endif
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARMV7, MATRIX_MUL_SDOT) {
     matrix_mul::check_matrix_mul(dtype::Int8(), dtype::Int8(), dtype::Int32(),
                                  handle(), "AARCH32_INT8_K6X8X4");
@@ -298,7 +298,7 @@ void run_16x16x32_benchmark(const char* algo, Handle* handle) {
     }
 }
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 void run_8x8x32_benchmark(const char* algo, Handle* handle) {
     constexpr size_t RUNS = 50;
     param::MatrixMul param;
@@ -387,7 +387,7 @@ void run_8x8x32_quint_benchmark(Handle* handle) {
 #endif
 }  // namespace
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARMV7, BENCHMARK_MATRIX_MUL_INT8x8x32_K6x8x4) {
     run_8x8x32_benchmark("AARCH32_INT8_K6X8X4", handle());
 }
diff --git a/src/megbrain_build_config.h.in b/src/megbrain_build_config.h.in
index c5060346..08c25954 100644
--- a/src/megbrain_build_config.h.in
+++ b/src/megbrain_build_config.h.in
@@ -21,6 +21,7 @@
 #cmakedefine01 MGB_ENABLE_LOGGING
 #cmakedefine01 MGB_ENABLE_GRAD
 #cmakedefine01 MGB_ENABLE_CPUINFO
+#cmakedefine01 MGB_ENABLE_DOT
 #cmakedefine01 MGB_VERBOSE_TYPEINFO_NAME
 #cmakedefine01 MGB_BUILD_SLIM_SERVING
 #cmakedefine01 MGB_ENABLE_EXCEPTION
@@ -97,6 +98,27 @@
 #define MGB_ENABLE_CPUINFO 0
 #endif
 
+//! use a single macro to indicate that arm dotprod is enabled
+#if __ARM_FEATURE_DOTPROD
+#ifdef MGB_ENABLE_DOT
+#undef MGB_ENABLE_DOT
+#endif
+#define MGB_ENABLE_DOT 1
+#endif
+
+
+//! MGB_ENABLE_DOT requires MGB_ENABLE_CPUINFO for the runtime dotprod check
+#if MGB_ENABLE_DOT
+#if !defined(MGB_ENABLE_CPUINFO) || !MGB_ENABLE_CPUINFO
+#ifdef MGB_ENABLE_CPUINFO
+#undef MGB_ENABLE_CPUINFO
+#endif
+#define MGB_ENABLE_CPUINFO 1
+#endif
+#endif
+
+
+
 // whether to include actual class name in mgb::Typeinfo object; if this is
 // disabled, mgb::serialization::OprRegistry::find_opr_by_name would not work.
 #ifndef MGB_VERBOSE_TYPEINFO_NAME
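Read together, the config block above makes `MGB_ENABLE_DOT` the single switch for the whole tree: it can be set by CMake through `#cmakedefine01`, it is force-enabled whenever the compiler already defines `__ARM_FEATURE_DOTPROD`, and it in turn force-enables `MGB_ENABLE_CPUINFO` so the runtime capability query is always available wherever dot kernels exist. A hedged end-to-end sketch of how a consumer is expected to combine the pieces; the kernel names are illustrative, not from this patch:

    #include <cpuinfo.h>

    #if MGB_ENABLE_DOT
    MEGDNN_ATTRIBUTE_TARGET("dotprod")
    void conv_kernel_dot();    // illustrative: built with dotprod codegen
    #endif
    void conv_kernel_plain();  // illustrative: baseline NEON kernel, always built

    void run_conv() {
    #if MGB_ENABLE_DOT
        // safe to call: the config header forces MGB_ENABLE_CPUINFO on
        if (cpuinfo_has_arm_neon_dot()) {
            conv_kernel_dot();
            return;
        }
    #endif
        conv_kernel_plain();
    }

Under this scheme, shipping one binary that exploits SDOT/UDOT on capable cores while falling back cleanly elsewhere no longer requires building the entire project with `-march=armv8.2-a+dotprod`.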