
chore(dotprod): add arm dotprod attribute for easy use

Author: Megvii Engine Team (4 years ago)
Tag: tags/v1.3.0
Parent commit: f2b42bf09e
GitOrigin-RevId: 78c3e72218

100 changed files with 503 additions and 318 deletions
  1. +34  -8   dnn/src/aarch64/conv_bias/int8/algos.cpp
  2. +7   -10  dnn/src/aarch64/conv_bias/int8/strategy.cpp
  3. +1   -4   dnn/src/aarch64/conv_bias/int8/strategy.h
  4. +39  -3   dnn/src/aarch64/conv_bias/quint8/algos.cpp
  5. +49  -39  dnn/src/aarch64/conv_bias/quint8/strategy.cpp
  6. +31  -15  dnn/src/aarch64/conv_bias/quint8/strategy.h
  7. +23  -14  dnn/src/aarch64/matrix_mul/algos.cpp
  8. +4   -8   dnn/src/aarch64/matrix_mul/algos.h
  9. +0   -3   dnn/src/aarch64/matrix_mul/fp32/strategy.cpp
 10. +0   -2   dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h
 11. +0   -2   dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h
 12. +0   -3   dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h
 13. +0   -3   dnn/src/aarch64/matrix_mul/int8/strategy.cpp
 14. +0   -2   dnn/src/aarch64/matrix_mul/int8/strategy.h
 15. +7   -5   dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h
 16. +5   -4   dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h
 17. +1   -1   dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp
 18. +1   -1   dnn/src/aarch64/matrix_mul/int8_dot/strategy.h
 19. +8   -12  dnn/src/aarch64/matrix_mul/opr_impl.cpp
 20. +4   -6   dnn/src/aarch64/matrix_mul/opr_impl.h
 21. +0   -2   dnn/src/aarch64/matrix_mul/quint8/kernel_8x8x8.h
 22. +0   -2   dnn/src/aarch64/matrix_mul/quint8/strategy.cpp
 23. +0   -2   dnn/src/aarch64/matrix_mul/quint8/strategy.h
 24. +2   -7   dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp
 25. +2   -3   dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h
 26. +5   -3   dnn/src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h
 27. +5   -5   dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp
 28. +2   -2   dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h
 29. +0   -3   dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h
 30. +7   -1   dnn/src/arm_common/conv_bias/int8/algos.cpp
 31. +1   -1   dnn/src/arm_common/conv_bias/int8/algos.h
 32. +10  -1   dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp
 33. +1   -1   dnn/src/arm_common/conv_bias/int8/direct_dotprod.h
 34. +1   -2   dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp
 35. +2   -3   dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h
 36. +4   -2   dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp
 37. +4   -2   dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h
 38. +4   -2   dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp
 39. +4   -2   dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp
 40. +9   -2   dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp
 41. +9   -2   dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp
 42. +4   -1   dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp
 43. +1   -1   dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h
 44. +1   -1   dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp
 45. +1   -1   dnn/src/arm_common/conv_bias/int8/stride1_dotprod.h
 46. +1   -1   dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp
 47. +1   -1   dnn/src/arm_common/conv_bias/int8/stride2_dotprod.h
 48. +2   -2   dnn/src/arm_common/conv_bias/opr_impl.cpp
 49. +1   -1   dnn/src/arm_common/conv_bias/opr_impl.h
 50. +9   -2   dnn/src/arm_common/conv_bias/quint8/algos.cpp
 51. +1   -1   dnn/src/arm_common/conv_bias/quint8/algos.h
 52. +9   -1   dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp
 53. +1   -1   dnn/src/arm_common/conv_bias/quint8/direct_dotprod.h
 54. +1   -1   dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp
 55. +1   -1   dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.h
 56. +1   -1   dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp
 57. +1   -1   dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.h
 58. +8   -2   dnn/src/arm_common/convolution/int8x8x32/algos.cpp
 59. +1   -1   dnn/src/arm_common/convolution/int8x8x32/algos.h
 60. +6   -3   dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp
 61. +1   -1   dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h
 62. +5   -3   dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp
 63. +1   -1   dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h
 64. +2   -2   dnn/src/arm_common/convolution/opr_impl.cpp
 65. +1   -1   dnn/src/arm_common/convolution/opr_impl.h
 66. +9   -1   dnn/src/arm_common/convolution/quint8/algos.cpp
 67. +1   -1   dnn/src/arm_common/convolution/quint8/algos.h
 68. +6   -3   dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp
 69. +1   -4   dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.h
 70. +5   -3   dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp
 71. +1   -4   dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.h
 72. +7   -1   dnn/src/arm_common/matrix_mul/algos.cpp
 73. +1   -1   dnn/src/arm_common/matrix_mul/algos.h
 74. +27  -8   dnn/src/arm_common/matrix_mul/int8/gemv.cpp
 75. +1   -1   dnn/src/arm_common/matrix_mul/int8/gemv.h
 76. +2   -2   dnn/src/arm_common/matrix_mul/opr_impl.cpp
 77. +1   -1   dnn/src/arm_common/matrix_mul/opr_impl.h
 78. +3   -2   dnn/src/arm_common/neon_struct.h
 79. +20  -7   dnn/src/arm_common/simd_macro/marm_neon.h
 80. +13  -1   dnn/src/armv7/matrix_mul/algos.cpp
 81. +1   -1   dnn/src/armv7/matrix_mul/algos.h
 82. +0   -1   dnn/src/armv7/matrix_mul/asm/common.h
 83. +0   -1   dnn/src/armv7/matrix_mul/fp32/strategy_4x12.cpp
 84. +3   -1   dnn/src/armv7/matrix_mul/int8/kernel_6x8x4.h
 85. +3   -2   dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h
 86. +1   -1   dnn/src/armv7/matrix_mul/int8/strategy.cpp
 87. +1   -1   dnn/src/armv7/matrix_mul/int8/strategy.h
 88. +2   -2   dnn/src/armv7/matrix_mul/opr_impl.cpp
 89. +1   -1   dnn/src/armv7/matrix_mul/opr_impl.h
 90. +3   -2   dnn/src/armv7/matrix_mul/quint8/kernel_dot_4x8x4.h
 91. +1   -1   dnn/src/armv7/matrix_mul/quint8/strategy.cpp
 92. +1   -1   dnn/src/armv7/matrix_mul/quint8/strategy.h
 93. +7   -0   dnn/src/common/utils.h
 94. +1   -1   dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp
 95. +3   -3   dnn/test/aarch64/matrix_mul.cpp
 96. +5   -6   dnn/test/arm_common/conv_bias.cpp
 97. +1   -1   dnn/test/arm_common/conv_bias_multi_thread.cpp
 98. +4   -4   dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp
 99. +9   -11  dnn/test/arm_common/conv_bias_multi_thread_conv1x1.cpp
100. +11  -10  dnn/test/arm_common/conv_bias_multi_thread_im2col.cpp
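
A note on the pattern running through the diffs below: dot-product GEMM kernels used to be selected at compile time with #if __ARM_FEATURE_DOTPROD / #else, so a binary contained either the sdot/udot kernels or the plain NEON ones, never both. After this commit both families compile whenever MGB_ENABLE_DOT is set, each dot kernel carries MEGDNN_ATTRIBUTE_TARGET("dotprod"), and the choice moves to runtime through cpuinfo_has_arm_neon_dot(). A minimal self-contained sketch of the idea (illustrative names only, not the real MegDNN code; the real macro lives in dnn/src/common/utils.h and the CPU check comes from the bundled cpuinfo library):

#include <arm_neon.h>
#include <cstdint>

// Stand-in for MEGDNN_ATTRIBUTE_TARGET("dotprod"): a per-function target
// attribute, so the translation unit itself can be built without
// -march=armv8.2-a+dotprod. (GCC spells the aarch64 extension "+dotprod".)
#define ATTR_DOTPROD __attribute__((target("dotprod")))

ATTR_DOTPROD
int32_t dot16(const int8_t* a, const int8_t* b) {
    int8x16_t va = vld1q_s8(a), vb = vld1q_s8(b);
    int32x4_t acc = vdupq_n_s32(0);
    acc = vdotq_s32(acc, va, vb);  // sdot: four int8 MACs per 32-bit lane
    return vaddvq_s32(acc);
}

int32_t dot16_fallback(const int8_t* a, const int8_t* b) {
    int32_t s = 0;
    for (int i = 0; i < 16; ++i) s += int32_t(a[i]) * int32_t(b[i]);
    return s;
}

int32_t dot16_dispatch(const int8_t* a, const int8_t* b, bool has_dot) {
    // has_dot would be cpuinfo_has_arm_neon_dot() in MegDNN.
    return has_dot ? dot16(a, b) : dot16_fallback(a, b);
}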

+34 -8  dnn/src/aarch64/conv_bias/int8/algos.cpp

@@ -67,6 +67,23 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle(
     size_t K = IC * FH * FW;
     size_t N = OH * OW;
 
+#if MGB_ENABLE_DOT
+#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias,           \
+                               _bias_midout_enum, _nonline,               \
+                               _nonline_midout_enum)                      \
+    matmul::gemm_##_gemm##_##_bias##_##_nonline strategy(                 \
+            M, N, K, param.filter_type, param.src_type, param.dst_type);  \
+    part2 = megdnn::matmul::GemmInterleaved<                              \
+                    matmul::gemm_##_gemm##_##_bias##_##_nonline>(         \
+                    M, N, K, false, false, strategy)                      \
+                    .get_workspace_size();
+
+    if (cpuinfo_has_arm_neon_dot()) {
+        DISPATCH_GEMM_BIAS(s8_8x12, 1)
+    } else {
+        DISPATCH_GEMM_BIAS(s8_4x4, 0)
+    }
+#else
 #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias,           \
                                _bias_midout_enum, _nonline,               \
                                _nonline_midout_enum)                      \
@@ -80,11 +97,7 @@ WorkspaceBundle ConvBiasImpl::AlgoS8MatrixMul::get_bundle(
                     .get_workspace_size();                                \
     }                                                                     \
     MIDOUT_END()
-
-#if !(__ARM_FEATURE_DOTPROD)
     DISPATCH_GEMM_BIAS(s8_4x4, 0)
-#else
-    DISPATCH_GEMM_BIAS(s8_8x12, 1)
 #endif
 #undef DISPATCH_GEMM_STRATEGY
     }
@@ -158,6 +171,23 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param,
     size_t K = IC * FH * FW;
     size_t N = OH * OW;
 
+#if MGB_ENABLE_DOT
+#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias,           \
+                               _bias_midout_enum, _nonline,               \
+                               _nonline_midout_enum)                      \
+    matmul::gemm_##_gemm##_##_bias##_##_nonline strategy(                 \
+            M, N, K, param.filter_type, param.src_type, param.dst_type);  \
+    megdnn::matmul::GemmInterleaved<                                      \
+            matmul::gemm_##_gemm##_##_bias##_##_nonline>                  \
+            gemm_interleaved(M, N, K, false, false, strategy);            \
+    gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias);
+
+    if (cpuinfo_has_arm_neon_dot()) {
+        DISPATCH_GEMM_BIAS(s8_8x12, 1)
+    } else {
+        DISPATCH_GEMM_BIAS(s8_4x4, 0)
+    }
+#else
 #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias,           \
                                _bias_midout_enum, _nonline,               \
                                _nonline_midout_enum)                      \
@@ -172,11 +202,7 @@ void ConvBiasImpl::AlgoS8MatrixMul::kimpl(const NCBKernParam& param,
                            bias);                                         \
     }                                                                     \
     MIDOUT_END()
-
-#if !(__ARM_FEATURE_DOTPROD)
     DISPATCH_GEMM_BIAS(s8_4x4, 0)
-#else
-    DISPATCH_GEMM_BIAS(s8_8x12, 1)
 #endif
 #undef DISPATCH_GEMM_STRATEGY
 }


+7 -10  dnn/src/aarch64/conv_bias/int8/strategy.cpp

@@ -26,7 +26,7 @@ namespace impl {
 template <BiasMode bmode, typename Op, int block_m, int block_n>
 struct KernCaller;
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 template <BiasMode bmode, typename Op>
 struct KernCaller<bmode, Op, 8, 12> {
     static void run(const dt_int8* packA, const dt_int8* packB, size_t M,
@@ -118,7 +118,7 @@ struct KernCaller<bmode, Op, 8, 12> {
     }
 };
 
-#else
+#endif
 
 template <BiasMode bmode, typename Op>
 struct KernCaller<bmode, Op, 4, 4> {
@@ -196,10 +196,8 @@ struct KernCaller<bmode, Op, 4, 4> {
     }
 };
 
-#endif
-
 } // namespace impl
-#if !(__ARM_FEATURE_DOTPROD)
 
 MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_4x4_nobias_identity)
 
 void gemm_s8_4x4_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr,
@@ -227,7 +225,8 @@ void gemm_s8_4x4_nobias_identity::pack_B(dt_int8* out, const dt_int8* in,
 size_t gemm_s8_4x4_nobias_identity::get_workspace_size() const {
     return 4 * 4 * sizeof(dt_int32);
 }
-#else
+
+#if MGB_ENABLE_DOT
 MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_s8_8x12_nobias_identity)
 
 void gemm_s8_8x12_nobias_identity::pack_A(dt_int8* outptr, const dt_int8* inptr,
@@ -277,11 +276,10 @@ size_t gemm_s8_8x12_nobias_identity::get_workspace_size() const {
 #define DEFINE_OP(_Op) \
     arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B, scale_C);
 
-#if !(__ARM_FEATURE_DOTPROD)
 KERN(4, 4, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp)
 KERN(4, 4, nobias, BiasMode::NO_BIAS, relu, ReluOp)
 KERN(4, 4, nobias, BiasMode::NO_BIAS, hswish, HSwishOp)
-#else
+#if MGB_ENABLE_DOT
 KERN(8, 12, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp)
 KERN(8, 12, nobias, BiasMode::NO_BIAS, relu, ReluOp)
 KERN(8, 12, nobias, BiasMode::NO_BIAS, hswish, HSwishOp)
@@ -291,12 +289,11 @@ KERN(8, 12, nobias, BiasMode::NO_BIAS, hswish, HSwishOp)
 #define DEFINE_OP(_Op)                                         \
     arm_common::_Op<dt_qint32, dt_qint8> op(scale_A* scale_B,  \
                                             scale_A* scale_B, scale_C);
-#if !(__ARM_FEATURE_DOTPROD)
 KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
 KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
 KERN(4, 4, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish,
      FuseAddHSwishOp)
-#else
+#if MGB_ENABLE_DOT
 KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
 KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
 KERN(8, 12, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish,

+1 -4  dnn/src/aarch64/conv_bias/int8/strategy.h

@@ -15,7 +15,6 @@ namespace megdnn {
 namespace aarch64 {
 namespace matmul {
 
-#if !(__ARM_FEATURE_DOTPROD)
 /**
  * \brief base strategy of gemm.
  *
@@ -39,8 +38,7 @@ MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_relu,
 
 MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_4x4_bias_channel_hswish,
                                     gemm_s8_4x4_nobias_identity);
-
-#else
+#if MGB_ENABLE_DOT
 MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_int8, dt_int8, dt_int32, 8, 12, 4,
                                         false, true,
                                         gemm_s8_8x12_nobias_identity);
@@ -59,7 +57,6 @@ MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_relu,
 
 MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_s8_8x12_bias_channel_hswish,
                                     gemm_s8_8x12_nobias_identity);
-
 #endif
 
 } // namespace matmul

+39 -3  dnn/src/aarch64/conv_bias/quint8/algos.cpp

@@ -69,6 +69,23 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle(
     size_t K = IC * FH * FW;
     size_t N = OH * OW;
 
+#if MGB_ENABLE_DOT
+#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias,           \
+                               _bias_midout_enum, _nonline,               \
+                               _nonline_midout_enum)                      \
+    matmul::gemm_##_gemm##_##_bias##_##_nonline strategy(                 \
+            M, N, K, param.filter_type, param.src_type, param.dst_type);  \
+    part2 = megdnn::matmul::GemmInterleaved<                              \
+                    matmul::gemm_##_gemm##_##_bias##_##_nonline>(         \
+                    M, N, K, false, false, strategy)                      \
+                    .get_workspace_size();
+
+    if (cpuinfo_has_arm_neon_dot()) {
+        DISPATCH_GEMM_BIAS(u8_8x8_dot, 1);
+    } else {
+        DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0);
+    }
+#else
 #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias,           \
                                _bias_midout_enum, _nonline,               \
                                _nonline_midout_enum)                      \
@@ -82,8 +99,8 @@ WorkspaceBundle ConvBiasImpl::AlgoQU8MatrixMul::get_bundle(
                     .get_workspace_size();                                \
     }                                                                     \
     MIDOUT_END()
-
-    DISPATCH_GEMM_BIAS(u8_8x8, 0)
+    DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0)
+#endif
 #undef DISPATCH_GEMM_STRATEGY
     }
     return {nullptr, {part0, part1, part2}};
@@ -157,6 +174,23 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param,
     size_t K = IC * FH * FW;
     size_t N = OH * OW;
 
+#if MGB_ENABLE_DOT
+#define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias,           \
+                               _bias_midout_enum, _nonline,               \
+                               _nonline_midout_enum)                      \
+    matmul::gemm_##_gemm##_##_bias##_##_nonline strategy(                 \
+            M, N, K, param.filter_type, param.src_type, param.dst_type);  \
+    megdnn::matmul::GemmInterleaved<                                      \
+            matmul::gemm_##_gemm##_##_bias##_##_nonline>                  \
+            gemm_interleaved(M, N, K, false, false, strategy);            \
+    gemm_interleaved.execute(filter, K, B, N, dst, N, workspace.raw_ptr, bias);
+
+    if (cpuinfo_has_arm_neon_dot()) {
+        DISPATCH_GEMM_BIAS(u8_8x8_dot, 1)
+    } else {
+        DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0)
+    }
+#else
 #define DISPATCH_GEMM_STRATEGY(_gemm, _gemm_midout_enum, _bias,           \
                                _bias_midout_enum, _nonline,               \
                                _nonline_midout_enum)                      \
@@ -172,7 +206,9 @@ void ConvBiasImpl::AlgoQU8MatrixMul::kimpl(const NCBKernParam& param,
     }                                                                     \
     MIDOUT_END()
 
-    DISPATCH_GEMM_BIAS(u8_8x8, 0)
+    DISPATCH_GEMM_BIAS(u8_8x8_nodot, 0)
+
+#endif
 #undef DISPATCH_GEMM_STRATEGY
     }
 }


+49 -39  dnn/src/aarch64/conv_bias/quint8/strategy.cpp

@@ -23,12 +23,12 @@ using namespace aarch64;
 using namespace aarch64::matmul;
 
 namespace impl {
-template <BiasMode bmode, typename Op, int block_m, int block_n>
+template <BiasMode bmode, typename Op, int block_m, int block_n, bool dot>
 struct KernCaller;
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 template <BiasMode bmode, typename Op>
-struct KernCaller<bmode, Op, 8, 8> {
+struct KernCaller<bmode, Op, 8, 8, true> {
     static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M,
                     size_t N, size_t K, dt_uint8* C, size_t LDC,
                     bool is_first_k, Op op, const dt_int32* bias,
@@ -120,10 +120,10 @@ struct KernCaller<bmode, Op, 8, 8> {
     }
 };
 
-#else
+#endif
 
 template <BiasMode bmode, typename Op>
-struct KernCaller<bmode, Op, 8, 8> {
+struct KernCaller<bmode, Op, 8, 8, false> {
     static void run(const dt_uint8* packA, const dt_uint8* packB, size_t M,
                     size_t N, size_t K, dt_uint8* C, size_t LDC,
                     bool is_first_k, Op op, const dt_int32* bias,
@@ -215,13 +215,11 @@ struct KernCaller<bmode, Op, 8, 8> {
     }
 };
 
-#endif
-
 } // namespace impl
-#if __ARM_FEATURE_DOTPROD
-MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nobias_identity)
+#if MGB_ENABLE_DOT
+MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_dot_nobias_identity)
 
-void gemm_u8_8x8_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr,
+void gemm_u8_8x8_dot_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr,
                                          int ldin, int y0, int ymax, int k0,
                                          int kmax, bool transpose) const {
     if (transpose) {
@@ -233,7 +231,7 @@ void gemm_u8_8x8_nobias_identity::pack_A(uint8_t* outptr, const uint8_t* inptr,
     }
 }
 
-void gemm_u8_8x8_nobias_identity::pack_B(uint8_t* out, const uint8_t* in,
+void gemm_u8_8x8_dot_nobias_identity::pack_B(uint8_t* out, const uint8_t* in,
                                          int ldin, int x0, int xmax, int k0,
                                          int kmax, bool transpose) const {
     if (transpose) {
@@ -245,10 +243,13 @@ void gemm_u8_8x8_nobias_identity::pack_B(uint8_t* out, const uint8_t* in,
     }
 }
 
-#else
+size_t gemm_u8_8x8_dot_nobias_identity::get_workspace_size() const {
+    return 8 * 8 * sizeof(dt_int32);
+}
 
-MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nobias_identity)
-void gemm_u8_8x8_nobias_identity::pack_A(dt_uint8* outptr,
+#endif
+MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_nodot_nobias_identity)
+void gemm_u8_8x8_nodot_nobias_identity::pack_A(dt_uint8* outptr,
                                          const dt_uint8* inptr, int ldin,
                                          int y0, int ymax, int k0, int kmax,
                                          bool transpose) const {
@@ -262,7 +263,7 @@ void gemm_u8_8x8_nobias_identity::pack_A(dt_uint8* outptr,
     }
 }
 
-void gemm_u8_8x8_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in,
+void gemm_u8_8x8_nodot_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in,
                                          int ldin, int x0, int xmax, int k0,
                                          int kmax, bool transpose) const {
     uint8_t zB = B_dtype.param<dtype::Quantized8Asymm>().zero_point;
@@ -275,43 +276,52 @@ void gemm_u8_8x8_nobias_identity::pack_B(dt_uint8* out, const dt_uint8* in,
     }
 }
 
-#endif
-size_t gemm_u8_8x8_nobias_identity::get_workspace_size() const {
+size_t gemm_u8_8x8_nodot_nobias_identity::get_workspace_size() const {
     return 8 * 8 * sizeof(dt_int32);
 }
 
-#define KERN(_block_m, _block_n, _bias, _BIAS, _nonline, _OP)               \
-    void gemm_u8_##_block_m##x##_block_n##_##_bias##_##_nonline::kern(      \
-            const dt_uint8* packA, const dt_uint8* packB, size_t M,         \
-            size_t N, size_t K, dt_uint8* C, size_t LDC, bool is_first_k,   \
-            const dt_int32* bias, dt_int32* workspace) const {              \
-        float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale;      \
-        uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point;  \
-        float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale;      \
-        uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point;  \
-        float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale;      \
-        uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point;  \
-        DEFINE_OP(_OP);                                                     \
-        impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n>::run(     \
-                packA, packB, M, N, K, C, LDC, is_first_k, op, bias,        \
-                workspace, zp_A, zp_B);                                     \
+#define KERN(_block_m, _block_n, _dot, _suffix, _bias, _BIAS, _nonline,     \
+             _OP)                                                           \
+    void gemm_u8_##_block_m##x##_block_n##_suffix##_##_bias##_##_nonline::  \
+            kern(const dt_uint8* packA, const dt_uint8* packB, size_t M,    \
+                 size_t N, size_t K, dt_uint8* C, size_t LDC, bool is_first_k, \
+                 const dt_int32* bias, dt_int32* workspace) const {         \
+        float scale_A = A_dtype.param<dtype::Quantized8Asymm>().scale;      \
+        uint8_t zp_A = A_dtype.param<dtype::Quantized8Asymm>().zero_point;  \
+        float scale_B = B_dtype.param<dtype::Quantized8Asymm>().scale;      \
+        uint8_t zp_B = B_dtype.param<dtype::Quantized8Asymm>().zero_point;  \
+        float scale_C = C_dtype.param<dtype::Quantized8Asymm>().scale;      \
+        uint8_t zp_C = C_dtype.param<dtype::Quantized8Asymm>().zero_point;  \
+        DEFINE_OP(_OP);                                                     \
+        impl::KernCaller<_BIAS, decltype(op), _block_m, _block_n, _dot>::run( \
+                packA, packB, M, N, K, C, LDC, is_first_k, op, bias,        \
+                workspace, zp_A, zp_B);                                     \
     }
 
 #define DEFINE_OP(_Op) \
     arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B, scale_C, zp_C);
 
-KERN(8, 8, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp)
-KERN(8, 8, nobias, BiasMode::NO_BIAS, relu, ReluOp)
-KERN(8, 8, nobias, BiasMode::NO_BIAS, hswish, HSwishOp)
+#if MGB_ENABLE_DOT
+KERN(8, 8, true, _dot, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp)
+KERN(8, 8, true, _dot, nobias, BiasMode::NO_BIAS, relu, ReluOp)
+KERN(8, 8, true, _dot, nobias, BiasMode::NO_BIAS, hswish, HSwishOp)
+#endif
+KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, identity, TypeCvtOp)
+KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, relu, ReluOp)
+KERN(8, 8, false, _nodot, nobias, BiasMode::NO_BIAS, hswish, HSwishOp)
 #undef DEFINE_OP
 
 #define DEFINE_OP(_Op)                                          \
     arm_common::_Op<dt_qint32, dt_quint8> op(scale_A* scale_B,  \
                                              scale_A* scale_B, scale_C, zp_C);
-KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
-KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
-KERN(8, 8, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish,
-     FuseAddHSwishOp)
+#if MGB_ENABLE_DOT
+KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
+KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
+KERN(8, 8, true, _dot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp)
+#endif
+KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, identity, AddOp)
+KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, relu, FuseAddReluOp)
+KERN(8, 8, false, _nodot, bias_channel, BiasMode::BROADCAST_CHANNEL_BIAS, hswish, FuseAddHSwishOp)
 #undef DEFINE_OP
 
 #undef KERN
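
One detail in the strategy.cpp diff above: the dotprod and plain NEON quint8 kernels both use an 8x8 block, so once the two are compiled together the old KernCaller<bmode, Op, 8, 8> specializations would collide; the new trailing bool template parameter keeps them apart. A reduced sketch of the trick (hypothetical, simplified types):

#include <cstdio>

template <int block_m, int block_n, bool dot>
struct KernCaller;

template <int block_m, int block_n>
struct KernCaller<block_m, block_n, true> {   // sdot/udot kernel (MGB_ENABLE_DOT)
    static const char* name() { return "8x8 dot"; }
};

template <int block_m, int block_n>
struct KernCaller<block_m, block_n, false> {  // plain NEON kernel, always built
    static const char* name() { return "8x8 nodot"; }
};

int main() {
    // Both specializations coexist in one binary; the bool tag keeps them distinct.
    std::printf("%s / %s\n", KernCaller<8, 8, true>::name(),
                KernCaller<8, 8, false>::name());
}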


+31 -15  dnn/src/aarch64/conv_bias/quint8/strategy.h

@@ -15,30 +15,46 @@ namespace megdnn {
 namespace aarch64 {
 namespace matmul {
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 4,
                                         false, true,
-                                        gemm_u8_8x8_nobias_identity);
-#else
+                                        gemm_u8_8x8_dot_nobias_identity);
+
+MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_nobias_relu,
+                                    gemm_u8_8x8_dot_nobias_identity);
+
+MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_nobias_hswish,
+                                    gemm_u8_8x8_dot_nobias_identity);
+
+MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_identity,
+                                    gemm_u8_8x8_dot_nobias_identity);
+
+MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_relu,
+                                    gemm_u8_8x8_dot_nobias_identity);
+
+MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_dot_bias_channel_hswish,
+                                    gemm_u8_8x8_dot_nobias_identity);
+
+#endif
 MEGDNN_REG_GEMM_STRATEGY_WITH_WRITEBACK(dt_uint8, dt_uint8, dt_int32, 8, 8, 8,
                                         false, true,
-                                        gemm_u8_8x8_nobias_identity);
-#endif
+                                        gemm_u8_8x8_nodot_nobias_identity);
 
-MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nobias_relu,
-                                    gemm_u8_8x8_nobias_identity);
+MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_nobias_relu,
+                                    gemm_u8_8x8_nodot_nobias_identity);
 
-MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nobias_hswish,
-                                    gemm_u8_8x8_nobias_identity);
+MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_nobias_hswish,
+                                    gemm_u8_8x8_nodot_nobias_identity);
 
-MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_identity,
-                                    gemm_u8_8x8_nobias_identity);
+MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_identity,
+                                    gemm_u8_8x8_nodot_nobias_identity);
 
-MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_relu,
-                                    gemm_u8_8x8_nobias_identity);
+MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_relu,
+                                    gemm_u8_8x8_nodot_nobias_identity);
 
-MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_bias_channel_hswish,
-                                    gemm_u8_8x8_nobias_identity);
+MEGDNN_REG_GEMM_STRATEGY_WITH_SUPER(gemm_u8_8x8_nodot_bias_channel_hswish,
+                                    gemm_u8_8x8_nodot_nobias_identity);
 
 } // namespace matmul


+23 -14  dnn/src/aarch64/matrix_mul/algos.cpp

@@ -24,9 +24,6 @@
 #include "src/common/utils.h"
 #include "src/fallback/matrix_mul/gemm_impl.h"
 
-#if MGB_ENABLE_CPUINFO
-#include "cpuinfo.h"
-#endif
 #include "midout.h"
 
 MIDOUT_DECL(megdnn_aarch64_matmul_kern)
@@ -394,7 +391,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16MK8_8x8::get_kern(
 
 #endif
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 /* ==================== Int8x8x32 K8x12x4 Dotprod algo ==================== */
 namespace {
 void int8x8x32_k8x12x4_dotprod_kern(
@@ -422,6 +419,9 @@ void int8x8x32_k8x12x4_dotprod_kern(
 
 bool MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd::usable(
         const KernSizeParam& kern_size_param) const {
+    if (!cpuinfo_has_arm_neon_dot()){
+        return false;
+    }
     return can_be_treated_as_int8x8x32(kern_size_param);
 }
 
@@ -484,6 +484,11 @@ void int8x8x32_mk4_8x12x4_dotprod_kern(
 
 bool MatrixMulImpl::AlgoInt8x8x32MK4_8x12x4DotProd::usable(
         const KernSizeParam& kern_size_param) const {
+
+    if (!cpuinfo_has_arm_neon_dot()){
+        return false;
+    }
+
     return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
            (kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
             kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
@@ -527,7 +532,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x12x4DotProd,
                                      aarch64::matmul::gemm_mk4_s8_8x12, int8_t,
                                      int32_t, AlgoDataType::QINT8X8X32,
                                      MK4_DOT);
-#else
+#endif
 
 /* ===================== Int8x8x32 MK4 4x4x16 algo ===================== */
 namespace {
@@ -727,7 +732,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x8x8,
                                      aarch64::matmul::gemm_s8_8x8, int8_t,
                                      int32_t, AlgoDataType::QINT8X8X32,
                                      DEFAULT);
-#endif
 
 /* ===================== Int8x8x16 K8x8x8 algo ===================== */
 namespace {
@@ -1151,7 +1155,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt16x16x32MK8_8x8::get_kern(
     return kern_mk8_8x8;
 }
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 /* ==================== Quint8 K8x8x4 Dotprod algo ==================== */
 namespace {
 void quint8_k8x8x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
@@ -1166,8 +1170,8 @@ void quint8_k8x8x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
          Bptr = kern_param.B<dt_uint8>();
     auto Cptr = kern_param.C<dt_int32>();
 
-    aarch64::matmul::gemm_u8_8x8 strategy(M, N, K, A_type, B_type, C_type);
-    megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_u8_8x8>(
+    aarch64::matmul::gemm_u8_8x8_dot strategy(M, N, K, A_type, B_type, C_type);
+    megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_u8_8x8_dot>(
             M, N, K, trA, trB, strategy)
             .execute(Aptr, LDA, Bptr, LDB, Cptr, LDC,
                      kern_param.workspace_ptr);
@@ -1178,6 +1182,9 @@ void quint8_k8x8x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
 
 bool MatrixMulImpl::AlgoQuint8K8x8x4DotProd::usable(
         const KernSizeParam& kern_size_param) const {
+    if (!cpuinfo_has_arm_neon_dot()){
+        return false;
+    }
     return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm &&
            kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm &&
            kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 &&
@@ -1195,8 +1202,8 @@ size_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_workspace(
     auto A_type = kern_size_param.A_type, B_type = kern_size_param.B_type,
          C_type = kern_size_param.C_type;
 
-    aarch64::matmul::gemm_u8_8x8 strategy(M, N, K, A_type, B_type, C_type);
-    return megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_u8_8x8>(
+    aarch64::matmul::gemm_u8_8x8_dot strategy(M, N, K, A_type, B_type, C_type);
+    return megdnn::matmul::GemmInterleaved<aarch64::matmul::gemm_u8_8x8_dot>(
                    M, N, K, trA, trB, strategy)
             .get_workspace_size();
 }
@@ -1212,7 +1219,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K8x8x4DotProd::get_kern(
 MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x4DotProd,
                                      megdnn_aarch64_matmul_kern,
                                      "AlgoQuint8K8x8x4DotProdImpl"_hash,
-                                     aarch64::matmul::gemm_u8_8x8, uint8_t,
+                                     aarch64::matmul::gemm_u8_8x8_dot, uint8_t,
                                      int32_t, AlgoDataType::QUINT8X8X32,
                                      DEFAULT);
 /* ===================== Quint8 Gemv DotProd algo ===================== */
@@ -1238,6 +1245,9 @@ void quint8_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
 
 bool MatrixMulImpl::AlgoQuint8GemvDotProd::usable(
         const KernSizeParam& kern_size_param) const {
+    if (!cpuinfo_has_arm_neon_dot()){
+        return false;
+    }
     return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm &&
            kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm &&
            kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 &&
@@ -1257,7 +1267,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8GemvDotProd::get_kern(
         const KernSizeParam&) const {
     return quint8_gemv_dotprod_kern;
 }
-#else
+#endif
 
 /* ===================== Quint8 K8x8x8 algo ===================== */
 namespace {
@@ -1322,7 +1332,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x8,
                                      aarch64::matmul::gemm_u8_8x8, uint8_t,
                                      int32_t, AlgoDataType::QUINT8X8X32,
                                      DEFAULT);
-#endif
 
 /* ===================== Int8x8x16 K8x8x8 algo ===================== */
 namespace {
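
With both kernel families always compiled, the feature test moves out of the preprocessor and into each algorithm's usable() method, as the hunks above show: a dotprod algo reports itself unusable on CPUs without sdot/udot, and the dispatcher falls through to the next registered algorithm. A minimal sketch of that selection model (assumed names, not the actual fallback::MatrixMulImpl interface):

#include <vector>

struct Algo {
    virtual bool usable() const = 0;
    virtual ~Algo() = default;
};

struct DotAlgo : Algo {
    bool has_dot;  // stands in for cpuinfo_has_arm_neon_dot()
    explicit DotAlgo(bool d) : has_dot(d) {}
    bool usable() const override { return has_dot; }  // refuse on non-dot CPUs
};

struct NodotAlgo : Algo {
    bool usable() const override { return true; }  // generic fallback
};

const Algo* pick(const std::vector<const Algo*>& pack) {
    for (const Algo* a : pack)
        if (a->usable()) return a;  // first usable algorithm wins
    return nullptr;
}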


+4 -8  dnn/src/aarch64/matrix_mul/algos.h

@@ -111,7 +111,7 @@ public:
 
 #endif
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 class MatrixMulImpl::AlgoInt8x8x32K8x12x4DotProd final : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
@@ -141,7 +141,7 @@ public:
     MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
     MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD)
 };
-#else
+#endif
 
 class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase {
 public:
@@ -187,7 +187,6 @@ public:
     MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
     MEGDNN_DECL_ALGO_TYPE(AARCH64_INT8X8X32_K8X8X8)
 };
-#endif
 
 class MatrixMulImpl::AlgoInt8x8x16K8x8x8 final : public AlgoBase {
 public:
@@ -313,7 +312,7 @@ public:
     MEGDNN_DECL_ALGO_TYPE(AARCH64_INT16X16X32_MK8_8X8)
 };
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 class MatrixMulImpl::AlgoQuint8K8x8x4DotProd final : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
@@ -328,7 +327,6 @@ public:
     MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
     MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_K8X8X4_DOTPROD)
 };
-
 class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
@@ -344,8 +342,7 @@ public:
     MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QUINT8X8X32, DEFAULT)
     MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_GEMV_DOTPROD)
 };
-#else
-
+#endif
 class MatrixMulImpl::AlgoQuint8K8x8x8 final : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {
@@ -358,7 +355,6 @@ public:
     MEGDNN_REG_GEMM_FUNC_FOR_IM2COL();
     MEGDNN_DECL_ALGO_TYPE(AARCH64_QUINT8_K8X8X8)
 };
-#endif
 
 } // namespace aarch64
 } // namespace megdnn


+0 -3  dnn/src/aarch64/matrix_mul/fp32/strategy.cpp

@@ -20,9 +20,6 @@
 #include "src/aarch64/matrix_mul/fp32/strategy.h"
 #include "src/common/utils.h"
 
-#if MGB_ENABLE_CPUINFO
-#include "cpuinfo.h"
-#endif
 
 using namespace megdnn;
 using namespace aarch64;


+0 -2  dnn/src/aarch64/matrix_mul/int8/kernel_4x4x16.h

@@ -9,7 +9,6 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
 
-#if !(__ARM_FEATURE_DOTPROD)
 #include "src/aarch64/matrix_mul/asm/common.h"
 #include "src/arm_common/simd_macro/marm_neon.h"
 
@@ -851,6 +850,5 @@ static void gemm_s8_4x4_pack_B_n(dt_int8* out, const dt_int8* in, int ldin,
 } // namespace matmul_4x4x16
 } // namespace aarch64
 } // namespace megdnn
-#endif
 
 // vim: syntax=cpp.doxygen

+0 -2  dnn/src/aarch64/matrix_mul/int8/kernel_8x8x8.h

@@ -9,7 +9,6 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
 
-#if !(__ARM_FEATURE_DOTPROD)
 #include "src/aarch64/matrix_mul/asm/common.h"
 #include "src/arm_common/simd_macro/marm_neon.h"
 
@@ -1372,4 +1371,3 @@ static void gemm_s8_8x8_transpose_pack_B_n(int8_t* outptr, const int8_t* inptr,
 } // namespace megdnn
 
 // vim: syntax=cpp.doxygen
-#endif

+0 -3  dnn/src/aarch64/matrix_mul/int8/kernel_mk4_4x4x16.h

@@ -10,8 +10,6 @@
  * implied.
  */
 
-#include <cstring>
-#if !(__ARM_FEATURE_DOTPROD)
 #include "src/aarch64/matrix_mul/asm/common.h"
 #include "src/arm_common/simd_macro/marm_neon.h"
 
@@ -887,6 +885,5 @@ static void gemm_mk4_s8_4x4_pack_B(dt_int8* out, const dt_int8* in, int ldin,
 } // namespace matmul_4x4x16
 } // namespace aarch64
 } // namespace megdnn
-#endif
 
 // vim: syntax=cpp.doxygen

+0 -3  dnn/src/aarch64/matrix_mul/int8/strategy.cpp

@@ -9,7 +9,6 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
 
-#if !(__ARM_FEATURE_DOTPROD)
 #include "src/aarch64/matrix_mul/int8/strategy.h"
 #include "src/aarch64/matrix_mul/asm/common.h"
 #include "src/aarch64/matrix_mul/int8/kernel_4x4x16.h"
@@ -105,7 +104,6 @@ void gemm_s8_4x4::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
         packA += K4;
     }
 }
-
 ///////////////////////// gemm_mk4_s8_4x4 ////////////////////////////////////
 MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_mk4_s8_4x4);
 
@@ -258,6 +256,5 @@ void gemm_s8_8x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
         packA += K4;
     }
 }
-#endif
 
 // vim: syntax=cpp.doxygen

+0 -2  dnn/src/aarch64/matrix_mul/int8/strategy.h

@@ -10,7 +10,6 @@
  */
 #pragma once
 
-#if !(__ARM_FEATURE_DOTPROD)
 #include "src/fallback/matrix_mul/gemm_common.h"
 
 namespace megdnn {
@@ -30,5 +29,4 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 8, 8, 8, false, true,
 } // namespace aarch64
 } // namespace megdnn
 
-#endif
 // vim: syntax=cpp.doxygen

+7 -5  dnn/src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h

@@ -9,8 +9,7 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
 
-#if __ARM_FEATURE_DOTPROD
-
+#if MGB_ENABLE_DOT
 #include "src/aarch64/matrix_mul/asm/common.h"
 #include "src/arm_common/simd_macro/marm_neon.h"
 
@@ -50,7 +49,9 @@ namespace matmul_8x12x4 {
 * same, I test in kirin980 with small and big core, here i just keep both the
 * implementation.
 */
+
 #if 1
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
                       int32_t* output, int LDC, bool is_first_k) {
     K /= 4;
@@ -408,6 +409,7 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
     );
 }
 #else
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
                       int32_t* output, int LDC, bool is_first_k) {
     K /= 4;
@@ -650,7 +652,7 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
 // +-------+-------+ - - - - +--------+--------+--------+
 //
 // Accumulator
-
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_4x12(const int8_t* packA, const int8_t* packB, int K,
                       int32_t* output, int LDC, bool is_first_k, int m_remain) {
     K /= 4;
@@ -837,7 +839,7 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K,
 // +-------+-------+ - - - - +---------+
 //
 // Accumulator
-
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
                      int32_t* output, int LDC, bool is_first_k, int n_remain) {
     K /= 4;
@@ -1038,7 +1040,7 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
 // +-------+-------+ - - - - +--------+
 //
 // Accumulator
-
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
                      int32_t* output, int LDC, bool is_first_k, int m_remain,
                      int n_remain) {
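
MEGDNN_ATTRIBUTE_TARGET("dotprod"), now applied to each assembly kernel above, is the helper this commit adds in dnn/src/common/utils.h (that file's +7 hunk is not shown on this page). A plausible shape for such a macro, hedged because the exact definition sits outside this excerpt:

// Assumed illustration only. Clang accepts target("dotprod") on AArch64;
// GCC wants the extension spelled with a leading '+'.
#if defined(__clang__)
#define MEGDNN_ATTRIBUTE_TARGET(simd) __attribute__((target(simd)))
#else
#define MEGDNN_ATTRIBUTE_TARGET(simd) __attribute__((target("+" simd)))
#endif

// Usage: the attribute enables the dotprod extension for this one function,
// so a file full of sdot/udot kernels can still be built with a baseline -march.
MEGDNN_ATTRIBUTE_TARGET("dotprod")
static inline void kernel_using_dotprod() { /* sdot/udot code here */ }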


+5 -4  dnn/src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h

@@ -10,8 +10,7 @@
  * implied.
  */
 
-#if __ARM_FEATURE_DOTPROD
-
+#if MGB_ENABLE_DOT
 #include "src/aarch64/matrix_mul/asm/common.h"
 #include "src/arm_common/simd_macro/marm_neon.h"
 
@@ -40,6 +39,7 @@ namespace matmul_mk4_8x12x4 {
 //
 // Accumulator
 
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
                       int32_t* output, int LDC, bool is_first_k) {
     K /= 4;
@@ -60,7 +60,6 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
 
     int32_t* outptr0 = output;
     int32_t* outptr1;
-
     asm volatile (
             // load accumulator C
             "add %[outptr1], %[outptr0], %x[LDC]\n"
@@ -397,6 +396,7 @@ static void kern_8x12(const int8_t* packA, const int8_t* packB, int K,
 //
 // Accumulator
 
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_4x12(const int8_t* packA, const int8_t* packB, int K,
                       int32_t* output, int LDC, bool is_first_k) {
     K /= 4;
@@ -543,6 +543,7 @@ static void kern_4x12(const int8_t* packA, const int8_t* packB, int K,
 // +--------+--------+ - - - - +------------+
 // Accumulator
 
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
                      int32_t* output, int LDC, bool is_first_k, int n_remain) {
     K /= 4;
@@ -718,6 +719,7 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
 // +--------+--------+ - - - - +------------+
 // Accumulator
 
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
                      int32_t* output, int LDC, bool is_first_k, int n_remain) {
     K /= 4;
@@ -928,6 +930,5 @@ static void gemm_mk4_s8_8x12_pack_B(dt_int8* out, const dt_int8* in, int ldin,
 } // namespace matmul_mk4_8x12x4
 } // namespace aarch64
 } // namespace megdnn
-
 #endif
 // vim: syntax=cpp.doxygen

+1 -1  dnn/src/aarch64/matrix_mul/int8_dot/strategy.cpp

@@ -10,13 +10,13 @@
  */
 
 #include "src/aarch64/matrix_mul/int8_dot/strategy.h"
+#if MGB_ENABLE_DOT
 #include "src/aarch64/matrix_mul/asm/common.h"
 #include "src/arm_common/simd_macro/marm_neon.h"
 #include "src/common/utils.h"
 #include "src/aarch64/matrix_mul/int8_dot/kernel_8x12x4.h"
 #include "src/aarch64/matrix_mul/int8_dot/kernel_mk4_8x12x4.h"
 
-#if __ARM_FEATURE_DOTPROD
 using namespace megdnn;
 using namespace aarch64;
 using namespace aarch64::matmul;


+1 -1  dnn/src/aarch64/matrix_mul/int8_dot/strategy.h

@@ -11,7 +11,7 @@
 #pragma once
 #include "src/fallback/matrix_mul/gemm_common.h"
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 namespace megdnn {
 namespace aarch64 {
 namespace matmul {


+8 -12  dnn/src/aarch64/matrix_mul/opr_impl.cpp

@@ -27,14 +27,13 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
     AlgoF16K8x24x1 f16_k8x24x1;
     AlgoF16MK8_8x8 f16_mk8_8x8;
 #endif
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     AlgoInt8x8x32K8x12x4DotProd int8x8x32_k8x12x4_dotprod;
     AlgoInt8x8x32MK4_8x12x4DotProd int8x8x32_mk4_8x12x4_dotprod;
-#else
+#endif
     AlgoInt8x8x32MK4_4x4x16 int8x8x32_mk4_4x4x16;
     AlgoInt8x8x32K4x4x16 int8x8x32_k4x4x16;
     AlgoInt8x8x32K8x8x8 int8x8x32_k8x8x8;
-#endif
     AlgoInt8x8x16K8x8x8 int8x8x16_k8x8x8;
     AlgoInt8x8x16K4x4x16 int8x8x16_k4x4x16;
     AlgoInt8x8x16MK4_16x12x4 int8x8x16_mk4_16x12x4;
@@ -44,12 +43,11 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
     AlgoInt16x16x32K12x8x1 int16x16x32_k12x8x1;
     AlgoInt16x16x32MK8_8x8 int16x16x32_mk8_8x8;
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     AlgoQuint8K8x8x4DotProd quint8_k8x8x4_dotprod;
     AlgoQuint8GemvDotProd quint8_gemv_dotprod;
-#else
-    AlgoQuint8K8x8x8 quint8_k8x8x8;
 #endif
+    AlgoQuint8K8x8x8 quint8_k8x8x8;
     AlgoInt4x4x16K8x8x8 int4x4x16_k8x8x8;
 
     SmallVector<fallback::MatrixMulImpl::AlgoBase*> m_all_algos;
@@ -66,14 +64,13 @@ public:
         m_all_algos.emplace_back(&f16_k8x24x1);
         m_all_algos.emplace_back(&f16_mk8_8x8);
 #endif
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
         m_all_algos.emplace_back(&int8x8x32_k8x12x4_dotprod);
        m_all_algos.emplace_back(&int8x8x32_mk4_8x12x4_dotprod);
-#else
+#endif
         m_all_algos.emplace_back(&int8x8x32_k4x4x16);
         m_all_algos.emplace_back(&int8x8x32_k8x8x8);
         m_all_algos.emplace_back(&int8x8x32_mk4_4x4x16);
-#endif
         m_all_algos.emplace_back(&int8x8x16_k4x4x16);
         m_all_algos.emplace_back(&int8x8x16_k8x8x8);
         m_all_algos.emplace_back(&int8x8x16_mk4_k8x8x8);
@@ -82,12 +79,11 @@ public:
 
         m_all_algos.emplace_back(&int16x16x32_k12x8x1);
         m_all_algos.emplace_back(&int16x16x32_mk8_8x8);
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
         m_all_algos.emplace_back(&quint8_gemv_dotprod);
         m_all_algos.emplace_back(&quint8_k8x8x4_dotprod);
-#else
-        m_all_algos.emplace_back(&quint8_k8x8x8);
 #endif
+        m_all_algos.emplace_back(&quint8_k8x8x8);
         m_all_algos.emplace_back(&int4x4x16_k8x8x8);
 
         for (auto&& algo : m_all_algos) {

+4 -6  dnn/src/aarch64/matrix_mul/opr_impl.h

@@ -41,16 +41,15 @@ private:
     class AlgoF16MK8_8x8;  // Aarch64 F16 Format MK8 block 16x8
 #endif
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     class AlgoInt8x8x32K8x12x4DotProd;     // Aarch64 Int8x8x32 Kernel
                                            // 8x12x4 DotProduct
     class AlgoInt8x8x32MK4_8x12x4DotProd;  // Aarch64 nchw44 Int8x8x32 Kernel
                                            // 8x12x4 DotProduct
-#else
+#endif
     class AlgoInt8x8x32MK4_4x4x16;  // Aarch64 nchw44 Int8x8x32 Kernel 4x4x16
     class AlgoInt8x8x32K4x4x16;     // Aarch64 Int8x8x32 Kernel 4x4x16
     class AlgoInt8x8x32K8x8x8;      // Aarch64 Int8x8x32 Kernel 8x8x8
-#endif
     class AlgoInt8x8x16K8x8x8;      // Aarch64 Int8x8x16 Kernel 8x8x8
     class AlgoInt8x8x16K4x4x16;     // Aarch64 Int8x8x16 Kernel 4x4x16
     class AlgoInt8x8x16MK4_16x12x4; // Aarch64 Int8x8x16 Kernel 16x12x16
@@ -59,13 +58,12 @@ private:
     class AlgoInt16x16x32K12x8x1;  // Aarch64 Int16x16x32 Kernel 12x8x1
     class AlgoInt16x16x32MK8_8x8;  // Aarch64 Int16x16x32 Format MK8 block 8x8
 
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     class AlgoQuint8K8x8x4DotProd;  // Aarch64 Quint8 Kernel
                                     // 8x8x4 DotProduct
     class AlgoQuint8GemvDotProd;    // Aarch64 Quint8 Gemv DotProduct
-#else
-    class AlgoQuint8K8x8x8;         // Aarch64 Quint8 Kernel 8x8x8
 #endif
+    class AlgoQuint8K8x8x8;         // Aarch64 Quint8 Kernel 8x8x8
     class AlgoInt8x8x16MK4_K8x8x8;  // Aarch64 Int8x8x16 Kernel 4x4x16
     class AlgoInt4x4x16K8x8x8;      // Aarch64 Int4x4x16 Kernel 4x4x16
     class AlgoPack;

+0 -2  dnn/src/aarch64/matrix_mul/quint8/kernel_8x8x8.h

@@ -9,7 +9,6 @@
  * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */
 
-#if !(__ARM_FEATURE_DOTPROD)
 #include "src/aarch64/matrix_mul/asm/common.h"
 #include "src/arm_common/simd_macro/marm_neon.h"
 
@@ -1395,4 +1394,3 @@ static void gemm_u8_8x8_transpose_pack_B_n(dt_uint8* outptr,
 } // namespace megdnn
 
 // vim: syntax=cpp.doxygen
-#endif

dnn/src/aarch64/matrix_mul/quint8/strategy.cpp  (+0 -2)

@@ -9,7 +9,6 @@
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

-#if !(__ARM_FEATURE_DOTPROD)
 #include "src/aarch64/matrix_mul/quint8/strategy.h"
 #include "src/aarch64/matrix_mul/asm/common.h"
 #include "src/aarch64/matrix_mul/quint8/kernel_8x8x8.h"
@@ -108,6 +107,5 @@ void gemm_u8_8x8::kern(const dt_uint8* packA, const dt_uint8* packB, size_t M,
 packA += K4;
 }
 }
-#endif

 // vim: syntax=cpp.doxygen

dnn/src/aarch64/matrix_mul/quint8/strategy.h  (+0 -2)

@@ -10,7 +10,6 @@
 */
 #pragma once

-#if !(__ARM_FEATURE_DOTPROD)
 #include "src/fallback/matrix_mul/gemm_common.h"

 namespace megdnn {
@@ -23,6 +22,5 @@ MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 8, 8, 8, false, true,
 }  // namespace matmul
 }  // namespace aarch64
 }  // namespace megdnn
-#endif

 // vim: syntax=cpp.doxygen

dnn/src/aarch64/matrix_mul/quint8_dot/gemv.cpp  (+2 -7)

@@ -10,15 +10,13 @@
 */

 #include "src/aarch64/matrix_mul/quint8_dot/gemv.h"
-#include <cstddef>
+#if MGB_ENABLE_DOT
 #include "src/arm_common/simd_macro/marm_neon.h"
 #include "src/common/utils.h"
 #include "src/common/unroll_macro.h"

-#if __ARM_FEATURE_DOTPROD
-
 namespace {
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void gemv_naive_n(const uint8_t* __restrict A, const uint8_t* __restrict B,
 int32_t* __restrict C, size_t M, size_t N, size_t K,
 size_t Astride, size_t Bstride, size_t Cstride,
@@ -146,7 +144,6 @@ void gemv_naive_n(const uint8_t* __restrict A, const uint8_t* __restrict B,
 acc[0] + acc[1] + acc[2] + acc[3] + zAB - acc_zA - acc_zB;
 }
 }
-
 }  // namespace

 bool megdnn::aarch64::matmul::is_gemv_like_preferred_quint8(
@@ -171,7 +168,5 @@ void megdnn::aarch64::matmul::gemv_like_quint8(
 return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride,
 zero_point_A, zero_point_B);
 }
-
 #endif

 // vim: syntax=cpp.doxygen
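gemv_naive_n above is now guarded by MGB_ENABLE_DOT plus the per-function attribute instead of a whole-file preprocessor check. A self-contained sketch of the udot inner product such a GEMV row reduces to — the helper name and attribute spelling are assumptions, not MegEngine code:

#include <arm_neon.h>
#include <cstddef>
#include <cstdint>

__attribute__((target("dotprod")))
uint32_t row_dot_u8(const uint8_t* a, const uint8_t* b, size_t k) {
    uint32x4_t acc = vdupq_n_u32(0);
    size_t i = 0;
    for (; i + 16 <= k; i += 16) {
        // vdotq_u32: lane j accumulates a[4j]*b[4j] + ... + a[4j+3]*b[4j+3]
        acc = vdotq_u32(acc, vld1q_u8(a + i), vld1q_u8(b + i));
    }
    uint32_t sum = vaddvq_u32(acc);
    for (; i < k; ++i) sum += uint32_t(a[i]) * b[i];  // scalar tail
    return sum;
}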

dnn/src/aarch64/matrix_mul/quint8_dot/gemv.h  (+2 -3)

@@ -10,10 +10,9 @@
 */
 #pragma once

-#include <cstddef>
-#include <cstdint>
+#include "src/common/utils.h"

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 namespace megdnn {
 namespace aarch64 {
 namespace matmul {


dnn/src/aarch64/matrix_mul/quint8_dot/kernel_8x8x4.h  (+5 -3)

@@ -9,8 +9,7 @@
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

-#if __ARM_FEATURE_DOTPROD
-
+#if MGB_ENABLE_DOT
 #include "src/aarch64/matrix_mul/asm/common.h"
 #include "src/arm_common/simd_macro/marm_neon.h"

@@ -56,7 +55,7 @@ namespace matmul_8x8x4 {
 // C = sum((A - zA) * (B - zB)) = sum(A * B) - sum(A) * zB - sum(B) * zA + zA *
 // zB * k
 // A -> v27, v28 | B -> v29, v30 | zA * zB * k -> v26
-
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_8x8(const uint8_t* packA, const uint8_t* packB, int K,
 int32_t* output, int LDC, bool is_first_k,
 uint8_t zero_point_A, uint8_t zero_point_B, uint32_t zAB) {
@@ -293,6 +292,7 @@ static void kern_8x8(const uint8_t* packA, const uint8_t* packB, int K,
 // zB * k
 // A -> v28 | B -> v29, v30 | zA * zB * k -> v26

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K,
 int32_t* output, int LDC, bool is_first_k, int m_remain,
 uint8_t zero_point_A, uint8_t zero_point_B, uint32_t zAB) {
@@ -495,6 +495,7 @@ static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K,
 // zB * k
 // A -> v27, v28 | B -> v29 | zA * zB * k -> v26

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_8x4(const uint8_t* packA, const uint8_t* packB, int K,
 int32_t* output, int LDC, bool is_first_k, int n_remain,
 uint8_t zero_point_A, uint8_t zero_point_B, uint32_t zAB) {
@@ -733,6 +734,7 @@ static void kern_8x4(const uint8_t* packA, const uint8_t* packB, int K,
 // zB * k
 // A -> v28 | B -> v29 | zA * zB * k -> v26

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K,
 int32_t* output, int LDC, bool is_first_k, int m_remain,
 int n_remain, uint8_t zero_point_A, uint8_t zero_point_B,
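The kernel comments above state the identity these quint8 kernels are built on: sum((A - zA) * (B - zB)) = sum(A * B) - sum(A) * zB - sum(B) * zA + zA * zB * k, which lets the SIMD path issue plain u8 dot products on raw values and fold the zero points in once per output. A scalar reference, useful for checking the assembly kernels and assuming nothing beyond that comment:

#include <cstddef>
#include <cstdint>

// Reference for one output element of the zero-point-corrected dot
// product over K values: sum((A[i]-zA)*(B[i]-zB)).
int32_t quant_dot_ref(const uint8_t* A, const uint8_t* B, size_t K,
                      uint8_t zA, uint8_t zB) {
    int32_t sumAB = 0, sumA = 0, sumB = 0;
    for (size_t i = 0; i < K; ++i) {
        sumAB += int32_t(A[i]) * B[i];  // sum(A * B) on raw values
        sumA += A[i];                   // sum(A)
        sumB += B[i];                   // sum(B)
    }
    return sumAB - int32_t(zB) * sumA - int32_t(zA) * sumB +
           int32_t(zA) * zB * int32_t(K);
}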


dnn/src/aarch64/matrix_mul/quint8_dot/strategy.cpp  (+5 -5)

@@ -16,14 +16,14 @@
 #include "src/arm_common/simd_macro/marm_neon.h"
 #include "src/common/utils.h"

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 using namespace megdnn;
 using namespace aarch64;
 using namespace aarch64::matmul;

-MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8);
+MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_u8_8x8_dot);

-void gemm_u8_8x8::pack_A(uint8_t* outptr, const uint8_t* inptr, int ldin,
+void gemm_u8_8x8_dot::pack_A(uint8_t* outptr, const uint8_t* inptr, int ldin,
 int y0, int ymax, int k0, int kmax,
 bool transpose) const {
 if (transpose) {
@@ -35,7 +35,7 @@ void gemm_u8_8x8::pack_A(uint8_t* outptr, const uint8_t* inptr, int ldin,
 }
 }

-void gemm_u8_8x8::pack_B(uint8_t* out, const uint8_t* in, int ldin, int x0,
+void gemm_u8_8x8_dot::pack_B(uint8_t* out, const uint8_t* in, int ldin, int x0,
 int xmax, int k0, int kmax, bool transpose) const {
 if (transpose) {
 matmul_8x8x4::gemm_u8_8x8_interleave_pack_helper(out, in, ldin, x0,
@@ -46,7 +46,7 @@ void gemm_u8_8x8::pack_B(uint8_t* out, const uint8_t* in, int ldin, int x0,
 }
 }

-void gemm_u8_8x8::kern(const uint8_t* packA, const uint8_t* packB, size_t M,
+void gemm_u8_8x8_dot::kern(const uint8_t* packA, const uint8_t* packB, size_t M,
 size_t N, size_t K, dt_int32* C, size_t LDC,
 bool is_first_k, const dt_int32*, dt_int32*) const {
 megdnn_assert(A_dtype.enumv() == B_dtype.enumv() &&


dnn/src/aarch64/matrix_mul/quint8_dot/strategy.h  (+2 -2)

@@ -11,13 +11,13 @@
 #pragma once
 #include "src/fallback/matrix_mul/gemm_common.h"

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 namespace megdnn {
 namespace aarch64 {
 namespace matmul {

 MEGDNN_REG_GEMM_STRATEGY(uint8_t, int32_t, int32_t, 8, 8, 4, false, true,
-gemm_u8_8x8);
+gemm_u8_8x8_dot);

 }  // namespace aarch64
 }  // namespace matmul


dnn/src/arm_common/conv_bias/fp32/direct_kernels/f32_direct_nchw_nchw44_kern_common.h  (+0 -3)

@@ -23,9 +23,6 @@
 #include "src/armv7/matrix_mul/asm/common.h"
 #endif

-#if MGB_ENABLE_CPUINFO
-#include "cpuinfo.h"
-#endif

 using namespace megdnn;
 using namespace arm_common;


dnn/src/arm_common/conv_bias/int8/algos.cpp  (+7 -1)

@@ -161,10 +161,13 @@ ConvBiasImpl::AlgoS8DirectStride2::dispatch_kerns(
 return {};
 }

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 /* ===================== dot stride1 algo ======================== */
 bool ConvBiasImpl::AlgoDotS8DirectStride1::usable(const NCBKernSizeParam& param,
 AlgoSelectionStrategy) const {
+if (!cpuinfo_has_arm_neon_dot()) {
+return false;
+}
 return direct_dotprod_int8_stride1::can_conv_direct_stride1_int8(param);
 }

@@ -195,6 +198,9 @@ ConvBiasImpl::AlgoDotS8DirectStride1::dispatch_kerns(
 /* ===================== dot stride2 algo ======================== */
 bool ConvBiasImpl::AlgoDotS8DirectStride2::usable(const NCBKernSizeParam& param,
 AlgoSelectionStrategy) const {
+if (!cpuinfo_has_arm_neon_dot()){
+return false;
+}
 return direct_dotprod_int8_stride2::can_conv_direct_stride2_int8(param);
 }
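The added cpuinfo_has_arm_neon_dot() guard is the runtime half of the scheme: MGB_ENABLE_DOT only means the dot kernels exist in the binary, so each usable() must still refuse them on CPUs that lack the instructions. A self-contained sketch of the split, with the probe and predicates stubbed (not the real MegEngine types):

// Hedged sketch: Param, cpu_has_dotprod() and shape_ok() are stand-ins
// for NCBKernSizeParam, cpuinfo_has_arm_neon_dot() and can_conv_*().
struct Param {};
static bool cpu_has_dotprod() { return false; }   // stub for cpuinfo probe
static bool shape_ok(const Param&) { return true; }

#define SKETCH_ENABLE_DOT 1  // hypothetical stand-in for MGB_ENABLE_DOT

bool usable_sketch(const Param& p) {
#if SKETCH_ENABLE_DOT
    if (!cpu_has_dotprod()) return false;  // runtime gate per algorithm
    return shape_ok(p);                    // then the usual shape checks
#else
    (void)p;
    return false;                          // kernels not compiled in
#endif
}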




dnn/src/arm_common/conv_bias/int8/algos.h  (+1 -1)

@@ -129,7 +129,7 @@ public:
 MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_CHANWISE_STRD2_NCHW44_S8)
 };

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT

 class ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44 final : public AlgoBase {
 public:


dnn/src/arm_common/conv_bias/int8/direct_dotprod.cpp  (+10 -1)

@@ -9,8 +9,8 @@
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/int8/direct_dotprod.h"
+#if MGB_ENABLE_DOT
 #include "src/arm_common/elemwise_op.h"
 #include "src/arm_common/simd_macro/marm_neon.h"
 #include "src/common/utils.h"
@@ -90,6 +90,7 @@ inline int8x16_t vqtbl1q_s8_v7(int8x16_t a, uint8x16_t index) {
 _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem);

 template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride1_2x2_int8_dot(
 const int8_t* src, const int8_t* filter, const int32_t* bias,
 int32_t* temp, int8_t* dst, const size_t IH, const size_t IW,
@@ -325,6 +326,7 @@ void conv_bias::conv_direct_stride1_2x2_int8_dot(
 }

 template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride1_3x3_int8_dot(
 const int8_t* src, const int8_t* filter, const int32_t* bias,
 int32_t* temp, int8_t* dst, const size_t IH, const size_t IW,
@@ -560,6 +562,7 @@ void conv_bias::conv_direct_stride1_3x3_int8_dot(
 }

 template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride2_2x2_int8_dot(
 const int8_t* src, const int8_t* filter, const int32_t* bias,
 int32_t* temp, int8_t* dst, const size_t IH, const size_t IW,
@@ -655,6 +658,7 @@ void conv_bias::conv_direct_stride2_2x2_int8_dot(
 }

 template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride2_3x3_int8_dot(
 const int8_t* src, const int8_t* filter, const int32_t* bias,
 int32_t* temp, int8_t* dst, const size_t IH, const size_t IW,
@@ -810,6 +814,7 @@ void conv_bias::conv_direct_stride2_3x3_int8_dot(
 _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem);

 template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride2_5x5_int8_dot(
 const int8_t* src, const int8_t* filter, const int32_t* bias,
 int32_t* temp, int8_t* dst, const size_t IH, const size_t IW,
@@ -1108,6 +1113,7 @@ void conv_bias::conv_direct_stride2_5x5_int8_dot(
 }

 template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride2_7x7_int8_dot(
 const int8_t* src, const int8_t* filter, const int32_t* bias,
 int32_t* temp, int8_t* dst, const size_t IH, const size_t IW,
@@ -1470,6 +1476,7 @@ void conv_bias::conv_direct_stride2_7x7_int8_dot(
 }

 template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride1_5x5_int8_dot(
 const int8_t* src, const int8_t* filter, const int32_t* bias,
 int32_t* temp, int8_t* dst, const size_t IH, const size_t IW,
@@ -1770,6 +1777,7 @@ void conv_bias::conv_direct_stride1_5x5_int8_dot(
 }

 template <bool first_ic, bool last_ic, BiasMode bias_mode, typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride1_7x7_int8_dot(
 const int8_t* src, const int8_t* filter, const int32_t* bias,
 int32_t* temp, int8_t* dst, const size_t IH, const size_t IW,
@@ -2115,6 +2123,7 @@ void conv_bias::conv_direct_stride1_7x7_int8_dot(
 #undef ST1_S32X4
 #undef ST2_S32X4X2

+
 #define INSTANTIATION(stride, i, first_ic, last_ic, bias, Op) \
 template void conv_bias::conv_direct_##stride##_##i##x##i##_int8_dot< \
 first_ic, last_ic, bias, Op>( \
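Every conv_direct_*_int8_dot kernel above now carries the attribute because its body expands to vdotq_s32 through the macros at the top of the file. A minimal demo of what a single vdotq_s32 computes (runs only on a dotprod-capable CPU; the attribute spelling may need to be "+dotprod" on GCC):

#include <arm_neon.h>
#include <cstdint>
#include <cstdio>

__attribute__((target("dotprod")))
void dot_demo() {
    int8_t a[16], b[16];
    for (int i = 0; i < 16; ++i) { a[i] = int8_t(i); b[i] = 1; }
    int32x4_t acc = vdupq_n_s32(0);
    // Lane j accumulates a[4j..4j+3] dot b[4j..4j+3], so the lanes become
    // {0+1+2+3, 4+5+6+7, 8+9+10+11, 12+13+14+15}.
    acc = vdotq_s32(acc, vld1q_s8(a), vld1q_s8(b));
    std::printf("%d %d %d %d\n", vgetq_lane_s32(acc, 0),
                vgetq_lane_s32(acc, 1), vgetq_lane_s32(acc, 2),
                vgetq_lane_s32(acc, 3));  // prints: 6 22 38 54
}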


dnn/src/arm_common/conv_bias/int8/direct_dotprod.h  (+1 -1)

@@ -8,8 +8,8 @@
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/opr_impl.h"
+#if MGB_ENABLE_DOT
 #include "src/fallback/conv_bias/common.h"

 namespace megdnn {


dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.cpp  (+1 -2)

@@ -10,9 +10,8 @@
 * implied.
 */

-#ifdef __ARM_FEATURE_DOTPROD
-
 #include "src/arm_common/elemwise_helper/kimpl/typecvt.h"
+#if MGB_ENABLE_DOT
 #include "src/common/unroll_macro.h"
 #include "src/common/utils.h"
 #include "src/fallback/conv_bias/common.h"


dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h  (+2 -3)

@@ -10,11 +10,10 @@
 * implied.
 */

-#if __ARM_FEATURE_DOTPROD
-
 #pragma once

 #include "src/arm_common/conv_bias/opr_impl.h"
+#if MGB_ENABLE_DOT

 namespace megdnn {
 namespace arm_common {
@@ -78,4 +77,4 @@ void copy_packed_src_int8_nchw44(int8_t* dst, const int dst_step,

 #endif

-// vim: syntax=cpp.doxygen
+// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/int8/direct_dotprod_nchw44_algo.cpp  (+4 -2)

@@ -10,9 +10,8 @@
 * implied.
 */

-#if __ARM_FEATURE_DOTPROD
-
 #include "src/arm_common/conv_bias/block_helper.h"
+#if MGB_ENABLE_DOT
 #include "src/arm_common/conv_bias/int8/algos.h"
 #include "src/arm_common/conv_bias/int8/direct_dotprod_nchw44.h"
 #include "src/arm_common/elemwise_op.h"
@@ -159,6 +158,9 @@ static void conv_kern(const WorkspaceBundle& bundle,
 bool ConvBiasImpl::AlgoDotS8Direct_NCHW44::usable(
 const NCBKernSizeParam& param,
 AlgoSelectionStrategy algo_selection_strategy) const {
+if (!cpuinfo_has_arm_neon_dot()){
+return false;
+}
 MEGDNN_MARK_USED_VAR(algo_selection_strategy);
 auto&& fm = param.filter_meta;
 auto FH = fm.spatial[0];


dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h  (+4 -2)

@@ -11,9 +11,9 @@
 * implied.
 */
 #pragma once
-#if __ARM_FEATURE_DOTPROD
 #include "megdnn/arch.h"
 #include "src/arm_common/conv_bias/intrinsic_helper.h"
+#if MGB_ENABLE_DOT
 #include "src/arm_common/elemwise_op.h"
 #include "src/arm_common/intrinsic_helper.h"
 #include "src/arm_common/neon_struct.h"
@@ -208,6 +208,7 @@ MEGDNN_ALWAYS_INLINE void store_ocx_owx_remain_static(int32x4_t res[][8],
 template <int res_row, int src_row, int src_start_idx, int weight_idx,
 typename T, typename T2, typename T3>
 struct ShiftCalHelper {
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static MEGDNN_ALWAYS_INLINE void impl(T& res, T2& src, T3& weight) {
 #define cb(step) \
 res[res_row][step] = \
@@ -221,6 +222,7 @@ struct ShiftCalHelper {

 template <int res_row, int src_row, int src_start_idx, int weight_idx,
 typename T, typename T2, typename T3>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 MEGDNN_ALWAYS_INLINE void cal_helper(T& res, T2& src, T3& weight) {
 ShiftCalHelper<res_row, src_row, src_start_idx, weight_idx, T, T2,
 T3>::impl(res, src, weight);
@@ -242,4 +244,4 @@ struct KernNeonSdotNCHW44 {
 }  // namespace arm_common
 }  // namespace megdnn
 #endif
-// vim: syntax=cpp.doxygen
+// vim: syntax=cpp.doxygen
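Note that the attribute is applied both to ShiftCalHelper::impl and to its MEGDNN_ALWAYS_INLINE caller cal_helper: GCC and Clang refuse to inline a callee whose target features are not a subset of the caller's, so every function on an always-inline chain needs the same annotation. A sketch of the constraint:

#include <arm_neon.h>

// always_inline forces the issue: if kernel() below lacked the dotprod
// target, GCC/Clang would reject the inlining ("target specific option
// mismatch" in GCC's wording).
__attribute__((target("dotprod"), always_inline))
static inline int32x4_t dot_step(int32x4_t acc, int8x16_t s, int8x16_t w) {
    return vdotq_s32(acc, s, w);
}

__attribute__((target("dotprod")))
int32x4_t kernel(int32x4_t acc, int8x16_t s, int8x16_t w) {
    return dot_step(acc, s, w);  // inlines fine: same target features
}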

dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s1.cpp  (+4 -2)

@@ -10,8 +10,8 @@
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h"
+#if MGB_ENABLE_DOT

 namespace megdnn {
 namespace arm_common {
@@ -20,6 +20,7 @@ template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain,
 int filter_size, int oc_interval, int ow_interval>
 struct KernNeonSdotNCHW44<dst_type, 1, bias_mode, Op, ow_remain, filter_size,
 oc_interval, ow_interval> {
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void impl(dst_type* dst, const int dst_step, const int8_t* src,
 const int ih, const int iw, const int8_t* filter,
 const int32_t* bias, const int ic, const Op& op) {
@@ -109,6 +110,7 @@ struct KernNeonSdotNCHW44<dst_type, 1, bias_mode, Op, ow_remain, filter_size,

 template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
 int filter_size>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow,
 const int8_t* src, const int ih, const int iw,
 const int8_t* filter, const int32_t* bias,
@@ -317,4 +319,4 @@ FOR_FILTER(1)
 }  // namespace arm_common
 }  // namespace megdnn
 #endif
-// vim: syntax=cpp.doxygen
+// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_s2.cpp  (+4 -2)

@@ -10,9 +10,9 @@
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
-#if __ARM_FEATURE_DOTPROD

 #include "src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw44_common.h"
+#if MGB_ENABLE_DOT
 namespace megdnn {
 namespace arm_common {
 namespace direct_dotprod_nchw44 {
@@ -20,6 +20,7 @@ template <typename dst_type, BiasMode bias_mode, typename Op, int ow_remain,
 int filter_size, int oc_interval, int ow_interval>
 struct KernNeonSdotNCHW44<dst_type, 2, bias_mode, Op, ow_remain, filter_size,
 oc_interval, ow_interval> {
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void impl(dst_type* dst, const int dst_step, const int8_t* src,
 const int ih, const int iw, const int8_t* filter,
 const int32_t* bias, const int ic, const Op& op) {
@@ -110,6 +111,7 @@ struct KernNeonSdotNCHW44<dst_type, 2, bias_mode, Op, ow_remain, filter_size,

 template <typename dst_type, int stride, BiasMode bias_mode, typename Op,
 int filter_size>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_direct_sdot_int8_nchw44(dst_type* dst, const int oh, const int ow,
 const int8_t* src, const int ih, const int iw,
 const int8_t* filter, const int32_t* bias,
@@ -319,4 +321,4 @@ FOR_FILTER(2)
 }  // namespace megdnn

 #endif
-// vim: syntax=cpp.doxygen
+// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s1.cpp  (+9 -2)

@@ -11,8 +11,8 @@
 * implied.
 */

-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h"
+#if MGB_ENABLE_DOT
 namespace megdnn {
 namespace arm_common {
 namespace dot_direct_nchw_nchw44 {
@@ -20,6 +20,7 @@ namespace dot_direct_nchw_nchw44 {
 template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
 typename T3, typename T4>
 struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 1, T, T2, T3, T4> {
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void impl(T& c, T2& src, T3& weight) {
 #define cb(step) \
 c[0][step] = Func::template impl<(src_idx + step) % 4>( \
@@ -35,6 +36,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 1, T, T2, T3, T4> {
 template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
 typename T3, typename T4>
 struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 1, T, T2, T3, T4> {
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void impl(T& c, T2& src, T3& weight) {
 #define cb(step) \
 c[0][step] = Func::template impl<(src_idx + step) % 4>( \
@@ -49,6 +51,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
 int ow_block>
 struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block,
 1> {
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
 const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
 int iw, int ld_dst_oc, const Op& op) {
@@ -97,6 +100,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
 int ow_block>
 struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block,
 1> {
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
 const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
 int iw, int ld_dst_oc, const Op& op) {
@@ -151,6 +155,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
 int ow_block>
 struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block,
 1> {
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
 const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
 int iw, int ld_dst_oc, const Op& op) {
@@ -200,6 +205,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
 int ow_block>
 struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block,
 1> {
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
 const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
 int iw, int ld_dst_oc, const Op& op) {
@@ -302,6 +308,7 @@ void pack_src_int8_nchw_nchw44_dot<1>(int8_t* sptr_base,
 }

 template <BiasMode bias_mode, typename Op, int filter_size, int stride>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter,
 const int32_t* bias, int32_t* temp,
 int8_t* dst, const int oc, const int ic,
@@ -445,4 +452,4 @@ DISPATCH_CONV_KERN(1);
 }  // namespace arm_common
 }  // namespace megdnn
 #endif
-// vim: syntax=cpp.doxygen
+// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/int8/direct_kernels/dot_direct_nchw_nchw44_s2.cpp  (+9 -2)

@@ -10,8 +10,8 @@
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h"
+#if MGB_ENABLE_DOT
 namespace megdnn {
 namespace arm_common {
 namespace dot_direct_nchw_nchw44 {
@@ -19,6 +19,7 @@ namespace dot_direct_nchw_nchw44 {
 template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
 typename T3, typename T4>
 struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 2, T, T2, T3, T4> {
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void impl(T& c, T2& src, T3& weight) {
 #define cb(step) \
 c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \
@@ -42,6 +43,7 @@ struct ShiftCalHelper<src_idx, weight_idx, 2, Func, 8, 2, T, T2, T3, T4> {
 template <int src_idx, int weight_idx, typename Func, typename T, typename T2,
 typename T3, typename T4>
 struct ShiftCalHelper<src_idx, weight_idx, 1, Func, 8, 2, T, T2, T3, T4> {
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void impl(T& c, T2& src, T3& weight) {
 #define cb(step) \
 c[0][step * 2] = Func::template impl<(src_idx + step) % 4>( \
@@ -60,6 +62,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
 int ow_block>
 struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 2, oc_block, ow_block,
 2> {
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
 const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
 int iw, int ld_dst_oc, const Op& op) {
@@ -111,6 +114,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
 int ow_block>
 struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 3, oc_block, ow_block,
 2> {
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
 const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
 int iw, int ld_dst_oc, const Op& op) {
@@ -169,6 +173,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
 int ow_block>
 struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 5, oc_block, ow_block,
 2> {
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
 const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
 int iw, int ld_dst_oc, const Op& op) {
@@ -224,6 +229,7 @@ template <BiasMode bias_mode, typename Op, int remain_w, int oc_block,
 int ow_block>
 struct KerNeonDotXXs2Nchw44Int8<bias_mode, Op, remain_w, 7, oc_block, ow_block,
 2> {
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void impl(const int8_t* src_ptr, const int8_t* weight_ptr,
 const int32_t* bias_ptr, int8_t* dst_ptr, int ic, int ih,
 int iw, int ld_dst_oc, const Op& op) {
@@ -289,6 +295,7 @@ void pack_src_int8_nchw_nchw44_dot<2>(
 }

 template <BiasMode bias_mode, typename Op, int filter_size, int stride>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_direct_int8_nchw_nchw44_dot(const int8_t* src, const int8_t* filter,
 const int32_t* bias, int32_t* temp,
 int8_t* dst, const int oc, const int ic,
@@ -434,4 +441,4 @@ DISPATCH_CONV_KERN(2);
 }  // namespace megdnn

 #endif
-// vim: syntax=cpp.doxygen
+// vim: syntax=cpp.doxygen

dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_algo.cpp  (+4 -1)

@@ -10,8 +10,8 @@
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied.
 */
-#if __ARM_FEATURE_DOTPROD
 #include "megdnn/oprs.h"
+#if MGB_ENABLE_DOT
 #include "src/arm_common/conv_bias/block_helper.h"
 #include "src/arm_common/conv_bias/int8/algos.h"
 #include "src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h"
@@ -175,6 +175,9 @@ static void do_conv_kern(const WorkspaceBundle& bundle,
 bool ConvBiasImpl::AlgoDotS8DirectNCHWNCHW44::usable(
 const NCBKernSizeParam& param, AlgoSelectionStrategy) const {
+if (!cpuinfo_has_arm_neon_dot()){
+return false;
+}
 return nchw_nchwxx_valid<NchwNchwxxType::NCHW44_INT8_DOT>(
 param.src_type.enumv(), param.filter_type.enumv(),
 param.dst_type.enumv(), param.filter_meta, param.bias_mode,


dnn/src/arm_common/conv_bias/int8/dot_direct_nchw_nchw44_kern.h  (+1 -1)

@@ -11,9 +11,9 @@
 * implied.
 */
 #pragma once
-#if __ARM_FEATURE_DOTPROD

 #include "src/arm_common/conv_bias/intrinsic_helper.h"
+#if MGB_ENABLE_DOT
 #include "src/arm_common/elemwise_op.h"
 #include "src/arm_common/simd_macro/marm_neon.h"
 #include "src/common/unroll_macro.h"


dnn/src/arm_common/conv_bias/int8/stride1_dotprod.cpp  (+1 -1)

@@ -8,9 +8,9 @@
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
-#if __ARM_FEATURE_DOTPROD

 #include "src/arm_common/conv_bias/int8/stride1_dotprod.h"
+#if MGB_ENABLE_DOT
 #include "megdnn/oprs.h"
 #include "src/arm_common/conv_bias/int8/direct_dotprod.h"
 #include "src/arm_common/conv_bias/int8/strategy.h"


dnn/src/arm_common/conv_bias/int8/stride1_dotprod.h  (+1 -1)

@@ -8,10 +8,10 @@
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
-#if __ARM_FEATURE_DOTPROD
 #pragma once

 #include "src/arm_common/conv_bias/opr_impl.h"
+#if MGB_ENABLE_DOT
 namespace megdnn {
 namespace arm_common {
 namespace direct_dotprod_int8_stride1 {


dnn/src/arm_common/conv_bias/int8/stride2_dotprod.cpp  (+1 -1)

@@ -9,8 +9,8 @@
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/int8/stride2_dotprod.h"
+#if MGB_ENABLE_DOT
 #include "megdnn/oprs.h"
 #include "src/arm_common/conv_bias/int8/direct_dotprod.h"
 #include "src/arm_common/conv_bias/int8/strategy.h"


dnn/src/arm_common/conv_bias/int8/stride2_dotprod.h  (+1 -1)

@@ -9,9 +9,9 @@
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

-#if __ARM_FEATURE_DOTPROD
 #pragma once
 #include "src/arm_common/conv_bias/opr_impl.h"
+#if MGB_ENABLE_DOT

 namespace megdnn {
 namespace arm_common {


dnn/src/arm_common/conv_bias/opr_impl.cpp  (+2 -2)

@@ -60,7 +60,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {
 AlgoS8x8x16ChanWiseStride1Stride2NCHW44
 s8x8x16_channel_wise_stride1_stride2_nchw44;

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 AlgoDotS8DirectStride1 ds8_direct_stride1;
 AlgoDotS8DirectStride2 ds8_direct_stride2;
 AlgoDotU8DirectStride1 du8_direct_stride1;
@@ -94,7 +94,7 @@ class ConvBiasImpl::AlgoPack : NonCopyableObj {

 public:
 AlgoPack() {
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 m_direct_algos.emplace_back(&ds8_direct_stride1);
 m_direct_algos.emplace_back(&ds8_direct_stride2);
 m_direct_algos.emplace_back(&du8_direct_stride1);


dnn/src/arm_common/conv_bias/opr_impl.h  (+1 -1)

@@ -70,7 +70,7 @@ private:
 class AlgoFP16WinogradF63;
 class AlgoFP16WinogradF23_8x8;
 #endif
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 class AlgoDotS8DirectNCHWNCHW44;
 class AlgoDotS8DirectStride1;
 class AlgoDotS8DirectStride2;
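This commit does not show where MGB_ENABLE_DOT itself is defined. Presumably (an assumption, not visible in this diff) the build defines it when the compiler can apply the per-function dotprod target attribute on AArch64, roughly along these lines:

// Assumption for illustration only; SKETCH_ENABLE_DOT is a hypothetical
// stand-in for MGB_ENABLE_DOT, which is defined elsewhere in the build.
#if defined(__aarch64__) && (defined(__clang__) || defined(__GNUC__))
#define SKETCH_ENABLE_DOT 1
#else
#define SKETCH_ENABLE_DOT 0
#endif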


dnn/src/arm_common/conv_bias/quint8/algos.cpp  (+9 -2)

@@ -11,7 +11,6 @@
 */

 #include "src/arm_common/conv_bias/quint8/algos.h"
-#include "midout.h"
 #include "src/arm_common/conv_bias/quint8/stride1.h"
 #include "src/arm_common/conv_bias/quint8/stride1_dotprod.h"
 #include "src/arm_common/conv_bias/quint8/stride2.h"
@@ -19,6 +18,8 @@
 #include "src/arm_common/elemwise_op.h"
 #include "src/fallback/conv_bias/common.h"

+#include "midout.h"
+
 MIDOUT_DECL(megdnn_arm_common_conv_bias_quint8)

 using namespace megdnn;
@@ -84,10 +85,13 @@ ConvBiasImpl::AlgoQU8DirectStride2::dispatch_kerns(
 MIDOUT_END();
 return {};
 }
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 /* ===================== stride1 algo ===================== */
 bool ConvBiasImpl::AlgoDotU8DirectStride1::usable(const NCBKernSizeParam& param,
 AlgoSelectionStrategy) const {
+if (!cpuinfo_has_arm_neon_dot()){
+return false;
+}
 return direct_dotprod_quint8_stride1::can_conv_direct_stride1_quint8(param);
 }

@@ -118,6 +122,9 @@ ConvBiasImpl::AlgoDotU8DirectStride1::dispatch_kerns(
 /* ===================== stride2 algo ===================== */
 bool ConvBiasImpl::AlgoDotU8DirectStride2::usable(const NCBKernSizeParam& param,
 AlgoSelectionStrategy) const {
+if (!cpuinfo_has_arm_neon_dot()){
+return false;
+}
 return direct_dotprod_quint8_stride2::can_conv_direct_stride2_quint8(param);
 }


dnn/src/arm_common/conv_bias/quint8/algos.h  (+1 -1)

@@ -55,7 +55,7 @@ public:
 }
 MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_DIRECT_STRD2_QU8)
 };
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase {

 public:


dnn/src/arm_common/conv_bias/quint8/direct_dotprod.cpp  (+9 -1)

@@ -9,8 +9,8 @@
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/quint8/direct_dotprod.h"
+#if MGB_ENABLE_DOT
 #include "src/arm_common/elemwise_op.h"
 #include "src/arm_common/simd_macro/marm_neon.h"
 #include "src/common/utils.h"
@@ -120,6 +120,7 @@ inline int8x16_t vqtbl1q_s8_v7(int8x16_t a, uint8x16_t index){

 template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode,
 typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride1_2x2_quint8_dot(
 const uint8_t* src, const uint8_t* filter, const int32_t* bias,
 int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW,
@@ -452,6 +453,7 @@ void conv_bias::conv_direct_stride1_2x2_quint8_dot(

 template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode,
 typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride1_3x3_quint8_dot(
 const uint8_t* src, const uint8_t* filter, const int32_t* bias,
 int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW,
@@ -691,6 +693,7 @@ void conv_bias::conv_direct_stride1_3x3_quint8_dot(

 template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode,
 typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride2_2x2_quint8_dot(
 const uint8_t* src, const uint8_t* filter, const int32_t* bias,
 int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW,
@@ -801,6 +804,7 @@ void conv_bias::conv_direct_stride2_2x2_quint8_dot(

 template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode,
 typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride2_3x3_quint8_dot(
 const uint8_t* src, const uint8_t* filter, const int32_t* bias,
 int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW,
@@ -1135,6 +1139,7 @@ void conv_bias::conv_direct_stride2_3x3_quint8_dot(

 template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode,
 typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride1_5x5_quint8_dot(
 const uint8_t* src, const uint8_t* filter, const int32_t* bias,
 int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW,
@@ -1443,6 +1448,7 @@ void conv_bias::conv_direct_stride1_5x5_quint8_dot(

 template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode,
 typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride1_7x7_quint8_dot(
 const uint8_t* src, const uint8_t* filter, const int32_t* bias,
 int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW,
@@ -1785,6 +1791,7 @@ void conv_bias::conv_direct_stride1_7x7_quint8_dot(

 template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode,
 typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride2_5x5_quint8_dot(
 const uint8_t* src, const uint8_t* filter, const int32_t* bias,
 int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW,
@@ -2090,6 +2097,7 @@ void conv_bias::conv_direct_stride2_5x5_quint8_dot(

 template <bool first_ic, bool last_ic, bool fused_kern, BiasMode bias_mode,
 typename Op>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void conv_bias::conv_direct_stride2_7x7_quint8_dot(
 const uint8_t* src, const uint8_t* filter, const int32_t* bias,
 int32_t* temp, uint8_t* dst, const size_t IH, const size_t IW,


dnn/src/arm_common/conv_bias/quint8/direct_dotprod.h  (+1 -1)

@@ -8,9 +8,9 @@
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
-#if __ARM_FEATURE_DOTPROD

 #include "src/arm_common/conv_bias/opr_impl.h"
+#if MGB_ENABLE_DOT
 #include "src/fallback/conv_bias/common.h"

 namespace megdnn {


dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.cpp  (+1 -1)

@@ -8,8 +8,8 @@
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/quint8/stride1_dotprod.h"
+#if MGB_ENABLE_DOT
 #include "megdnn/oprs.h"
 #include "src/arm_common/conv_bias/quint8/direct_dotprod.h"
 #include "src/arm_common/elemwise_op.h"


dnn/src/arm_common/conv_bias/quint8/stride1_dotprod.h  (+1 -1)

@@ -8,10 +8,10 @@
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
-#if __ARM_FEATURE_DOTPROD
 #pragma once

 #include "src/arm_common/conv_bias/opr_impl.h"
+#if MGB_ENABLE_DOT

 namespace megdnn {
 namespace arm_common {


dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.cpp  (+1 -1)

@@ -8,8 +8,8 @@
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/conv_bias/quint8/stride2_dotprod.h"
+#if MGB_ENABLE_DOT
 #include "megdnn/oprs.h"
 #include "src/arm_common/conv_bias/quint8/direct_dotprod.h"
 #include "src/arm_common/elemwise_op.h"


dnn/src/arm_common/conv_bias/quint8/stride2_dotprod.h  (+1 -1)

@@ -8,10 +8,10 @@
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
-#if __ARM_FEATURE_DOTPROD
 #pragma once

 #include "src/arm_common/conv_bias/opr_impl.h"
+#if MGB_ENABLE_DOT

 namespace megdnn {
 namespace arm_common {


dnn/src/arm_common/convolution/int8x8x32/algos.cpp  (+8 -2)

@@ -13,21 +13,24 @@
#include "src/arm_common/convolution/int8x8x32/algos.h" #include "src/arm_common/convolution/int8x8x32/algos.h"
#include "src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h" #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h"
#include "src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h" #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h"
#include "src/common/opr_delegate.h"


#include "midout.h" #include "midout.h"
#include "src/common/opr_delegate.h"


MIDOUT_DECL(megdnn_arm_conv_int8832_kimpl) MIDOUT_DECL(megdnn_arm_conv_int8832_kimpl)


using namespace megdnn; using namespace megdnn;
using namespace arm_common; using namespace arm_common;


#if __ARM_FEATURE_DOTPROD
#if MGB_ENABLE_DOT
/* ===================== ConvolutionBackwardData ===================== */ /* ===================== ConvolutionBackwardData ===================== */
/* ===================== direct stride 1 algo ===================== */ /* ===================== direct stride 1 algo ===================== */
bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::usable( bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::usable(
fallback::ConvolutionBackwardDataImpl*, fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const { const NCBKernSizeParam& param) const {
if (!cpuinfo_has_arm_neon_dot()){
return false;
}
return deconv::can_stride1_int8x8x32_dot(param); return deconv::can_stride1_int8x8x32_dot(param);
} }


@@ -57,6 +60,9 @@ ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1::dispatch_kern(
bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::usable( bool ConvolutionBackwardDataImpl::AlgoSdot8DirectStride2::usable(
fallback::ConvolutionBackwardDataImpl*, fallback::ConvolutionBackwardDataImpl*,
const NCBKernSizeParam& param) const { const NCBKernSizeParam& param) const {
if (!cpuinfo_has_arm_neon_dot()){
return false;
}
return deconv::can_stride2_int8x8x32_dot(param); return deconv::can_stride2_int8x8x32_dot(param);
} }
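This is the recurring two-level gate of the commit: MGB_ENABLE_DOT decides at build time whether the dot kernels exist in the binary at all, and cpuinfo_has_arm_neon_dot() (from the cpuinfo library) rejects each dot algo in usable() when the running CPU lacks SDOT/UDOT, so a single binary serves both CPU classes. A self-contained sketch of the runtime half (cpuinfo must be initialized once per process; megdnn handles that elsewhere, so the explicit call here is only for the standalone example):

    #include <cpuinfo.h>

    // runtime half of the gate, as a standalone helper
    bool dot_kernels_usable() {
        if (!cpuinfo_initialize())          // idempotent; false on probe failure
            return false;
        return cpuinfo_has_arm_neon_dot();  // CPU reports the dotprod extension
    }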




+ 1
- 1
dnn/src/arm_common/convolution/int8x8x32/algos.h

@@ -17,7 +17,7 @@
 namespace megdnn {
 namespace arm_common {

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 /* ===================== ConvolutionBackwardData ===================== */

 class ConvolutionBackwardDataImpl::AlgoSdot8DirectStride1 final


+ 6
- 3
dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.cpp

@@ -9,11 +9,9 @@
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */

-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h"
+#if MGB_ENABLE_DOT
 #include "src/common/utils.h"
-
-#include <cstring>
 #include "src/arm_common/simd_macro/marm_neon.h"

 using namespace megdnn;
@@ -94,6 +92,7 @@ inline int8x16_t vqtbl1q_s8_common(int8x16_t a, uint8x16_t index) {
     _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k1_idx, _elem); \
     _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem);

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) {
     MEGDNN_MARK_USED_VAR(IH);
@@ -328,6 +327,7 @@ void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst,
     }
 }

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) {
     MEGDNN_MARK_USED_VAR(IH);
@@ -530,6 +530,7 @@ void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst,
     _sum0##_c_idx = vdotq_s32(_sum0##_c_idx, _k##_k01_idx, _elem); \
     _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem);

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) {
     MEGDNN_MARK_USED_VAR(IH);
@@ -777,6 +778,7 @@ void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst,
     }
 }

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_7x7(const int8_t* src, const int8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) {
     MEGDNN_MARK_USED_VAR(IH);
@@ -1070,6 +1072,7 @@ void deconv_direct_7x7(const int8_t* src, const int8_t* filter, int32_t* dst,

 } // anonymous namespace
+
 size_t deconv::get_workspace_in_bytes_stride1_int8x8x32_dot(
         const NCBKernSizeParam& param) {
     return get_bundle(param).total_size_in_bytes();
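Every kernel that touches vdotq_s32 now carries MEGDNN_ATTRIBUTE_TARGET("dotprod") (defined in megdnn/arch.h), i.e. a per-function target attribute: only the annotated functions are code-generated with the dotprod extension enabled, so the file as a whole still builds under the baseline -march. A sketch of the mechanism; the expansion shown is the usual GCC/Clang spelling and is an assumption here, not quoted from megdnn/arch.h:

    // assumed expansion of the project macro (GCC/Clang function target attribute)
    #define MEGDNN_ATTRIBUTE_TARGET(feat) __attribute__((target(feat)))

    #include <arm_neon.h>  // visibility of vdotq_s32 additionally needs the
                           // __ARM_FEATURE_DOTPROD define trick shown later
                           // in marm_neon.h

    MEGDNN_ATTRIBUTE_TARGET("dotprod")
    int32x4_t dot_acc(int32x4_t acc, int8x16_t a, int8x16_t b) {
        return vdotq_s32(acc, a, b);  // SDOT: per lane, sum of four i8*i8 products
    }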


+ 1
- 1
dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride1.h

@@ -10,8 +10,8 @@
  */
 #pragma once

-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/convolution/opr_impl.h"
+#if MGB_ENABLE_DOT

 #include <cstddef>
 #include <cstdint>


+ 5
- 3
dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.cpp

@@ -9,11 +9,9 @@
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */

-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h"
+#if MGB_ENABLE_DOT
 #include "src/common/utils.h"
-
-#include <cstring>
 #include "src/arm_common/simd_macro/marm_neon.h"

 using namespace megdnn;
@@ -83,6 +81,7 @@ inline int8x16_t vqtbl1q_s8_common(int8x16_t a, uint8x16_t index) {
     _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k2_idx, _elem);

 template <bool even>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) {
     MEGDNN_MARK_USED_VAR(IH);
@@ -334,6 +333,7 @@ void deconv_direct_2x2(const int8_t* src, const int8_t* filter, int32_t* dst,
 }

 template <bool even>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) {
     MEGDNN_MARK_USED_VAR(IH);
@@ -558,6 +558,7 @@ void deconv_direct_3x3(const int8_t* src, const int8_t* filter, int32_t* dst,
     _sum1##_c_idx = vdotq_s32(_sum1##_c_idx, _k##_k11_idx, _elem);

 template <bool even>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) {
     MEGDNN_MARK_USED_VAR(IH);
@@ -835,6 +836,7 @@ void deconv_direct_5x5(const int8_t* src, const int8_t* filter, int32_t* dst,
 }

 template <bool even>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_7x7(const int8_t* src, const int8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC) {
     MEGDNN_MARK_USED_VAR(IH);


+ 1
- 1
dnn/src/arm_common/convolution/int8x8x32/conv_backdata_stride2.h

@@ -10,8 +10,8 @@
  */
 #pragma once

-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/convolution/opr_impl.h"
+#if MGB_ENABLE_DOT

 #include <cstddef>
 #include <cstdint>


+ 2
- 2
dnn/src/arm_common/convolution/opr_impl.cpp

@@ -24,7 +24,7 @@ using namespace arm_common;

 /* ===================== ConvolutionBackwardData ===================== */
 class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     AlgoSdot8DirectStride1 i8x8x32_direct_stride1_sdot;
     AlgoSdot8DirectStride2 i8x8x32_direct_stride2_sdot;
     AlgoUdot8DirectStride1 quint8_direct_stride1_udot;
@@ -37,7 +37,7 @@ class ConvolutionBackwardDataImpl::AlgoPack : NonCopyableObj {

 public:
     AlgoPack() {
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
         m_all_algos.emplace_back(&i8x8x32_direct_stride1_sdot);
         m_all_algos.emplace_back(&i8x8x32_direct_stride2_sdot);
         m_all_algos.emplace_back(&quint8_direct_stride1_udot);


+ 1
- 1
dnn/src/arm_common/convolution/opr_impl.h

@@ -56,7 +56,7 @@ public:
     MEGDNN_FB_DECL_GET_ALGO_FROM_DESC(ConvolutionBackwardDataImpl);

 private:
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     class AlgoSdot8DirectStride1;
     class AlgoSdot8DirectStride2;
     class AlgoUdot8DirectStride1;


+ 9
- 1
dnn/src/arm_common/convolution/quint8/algos.cpp

@@ -14,6 +14,7 @@
 #include "src/arm_common/convolution/quint8/conv_backdata_stride1.h"
 #include "src/arm_common/convolution/quint8/conv_backdata_stride2.h"
 #include "src/common/opr_delegate.h"
+
 #include "midout.h"

 MIDOUT_DECL(megdnn_arm_conv_quint8_kimpl)
@@ -21,7 +22,7 @@ MIDOUT_DECL(megdnn_arm_conv_quint8_kimpl)
 using namespace megdnn;
 using namespace arm_common;

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT

 /* ===================== ConvolutionBackwardData ===================== */

@@ -29,6 +30,10 @@ using namespace arm_common;
 bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::usable(
         fallback::ConvolutionBackwardDataImpl*,
         const NCBKernSizeParam& param) const {
+
+    if (!cpuinfo_has_arm_neon_dot()) {
+        return false;
+    }
     return deconv::can_stride1_quint8_dot(param);
 }

@@ -58,6 +63,9 @@ ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1::dispatch_kern(
 bool ConvolutionBackwardDataImpl::AlgoUdot8DirectStride2::usable(
         fallback::ConvolutionBackwardDataImpl*,
         const NCBKernSizeParam& param) const {
+    if (!cpuinfo_has_arm_neon_dot()) {
+        return false;
+    }
     return deconv::can_stride2_quint8_dot(param);
 }




+ 1
- 1
dnn/src/arm_common/convolution/quint8/algos.h

@@ -17,7 +17,7 @@
 namespace megdnn {
 namespace arm_common {

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 /* ===================== ConvolutionBackwardData ===================== */
 class ConvolutionBackwardDataImpl::AlgoUdot8DirectStride1 final
         : public AlgoBase {


+ 6
- 3
dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.cpp

@@ -9,11 +9,9 @@
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */

-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/convolution/quint8/conv_backdata_stride1.h"
+#if MGB_ENABLE_DOT
 #include "src/common/utils.h"
-
-#include <cstring>
 #include "src/arm_common/simd_macro/marm_neon.h"

 using namespace megdnn;
@@ -109,6 +107,7 @@ inline uint8x16_t vqtbl1q_u8_common(uint8x16_t a, uint8x16_t index) {
     _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem));

 template <bool last_oc = false>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC,
                        uint8_t src_zp, uint8_t filter_zp,
@@ -385,6 +384,7 @@ void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst,
 }

 template <bool last_oc = false>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC,
                        uint8_t src_zp, uint8_t filter_zp,
@@ -636,6 +636,7 @@ void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst,
     _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2);

 template <bool last_oc = false>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC,
                        uint8_t src_zp, uint8_t filter_zp,
@@ -907,6 +908,7 @@ void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst,
 }

 template <bool last_oc = false>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC,
                        uint8_t src_zp, uint8_t filter_zp,
@@ -1220,6 +1222,7 @@ void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst,

 } // anonymous namespace
+
 size_t deconv::get_workspace_in_bytes_stride1_quint8_dot(
         const NCBKernSizeParam& param) {
     return get_bundle(param).total_size_in_bytes();
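The quint8 kernels lean on vdotq2_u32, the marm_neon.h helper that runs UDOT against a zeroed accumulator, to compute the zero-point correction terms that the macros above subtract from each sum. A hedged sketch of that idiom in isolation (assumes marm_neon.h is included; dot_zp_term is a hypothetical name):

    // each output lane = filter_zp * (sum of 4 consecutive src bytes),
    // i.e. the zero-point contribution to one UDOT accumulation
    MEGDNN_ATTRIBUTE_TARGET("dotprod")
    uint32x4_t dot_zp_term(uint8x16_t src, uint8_t filter_zp) {
        uint8x16_t zp = vdupq_n_u8(filter_zp);
        return vdotq2_u32(zp, src);
    }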


+ 1
- 4
dnn/src/arm_common/convolution/quint8/conv_backdata_stride1.h

@@ -10,11 +10,8 @@
  */
 #pragma once

-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/convolution/opr_impl.h"
-
-#include <cstddef>
-#include <cstdint>
+#if MGB_ENABLE_DOT

 namespace megdnn {
 namespace arm_common {


+ 5
- 3
dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.cpp

@@ -9,11 +9,9 @@
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */

-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/convolution/quint8/conv_backdata_stride2.h"
+#if MGB_ENABLE_DOT
 #include "src/common/utils.h"
-
-#include <cstring>
 #include "src/arm_common/simd_macro/marm_neon.h"

 using namespace megdnn;
@@ -110,6 +108,7 @@ inline uint8x16_t vqtbx1q_u8_common(uint8x16_t a, uint8x16_t t,
     _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, vdotq2_u32(_filter_zp, _elem));

 template <bool even, bool last_oc = false>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC,
                        uint8_t src_zp, uint8_t filter_zp,
@@ -402,6 +401,7 @@ void deconv_direct_2x2(const uint8_t* src, const uint8_t* filter, int32_t* dst,
 }

 template <bool even, bool last_oc = false>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC,
                        uint8_t src_zp, uint8_t filter_zp,
@@ -673,6 +673,7 @@ void deconv_direct_3x3(const uint8_t* src, const uint8_t* filter, int32_t* dst,
     _sum1##_c_idx = vsubq_u32(_sum1##_c_idx, _elem2);

 template <bool even, bool last_oc = false>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC,
                        uint8_t src_zp, uint8_t filter_zp,
@@ -972,6 +973,7 @@ void deconv_direct_5x5(const uint8_t* src, const uint8_t* filter, int32_t* dst,
 }

 template <bool even, bool last_oc = false>
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void deconv_direct_7x7(const uint8_t* src, const uint8_t* filter, int32_t* dst,
                        size_t IH, size_t IW, size_t OH, size_t OW, size_t IC,
                        uint8_t src_zp, uint8_t filter_zp,


+ 1
- 4
dnn/src/arm_common/convolution/quint8/conv_backdata_stride2.h

@@ -10,11 +10,8 @@
  */
 #pragma once

-#if __ARM_FEATURE_DOTPROD
 #include "src/arm_common/convolution/opr_impl.h"
-
-#include <cstddef>
-#include <cstdint>
+#if MGB_ENABLE_DOT

 namespace megdnn {
 namespace arm_common {


+ 7
- 1
dnn/src/arm_common/matrix_mul/algos.cpp

@@ -14,8 +14,10 @@
 #include "src/arm_common/matrix_mul/fp16/hgemv.h"
 #include "src/arm_common/matrix_mul/fp32/exec_sgemv.h"
 #include "src/arm_common/matrix_mul/int8/gemv.h"
+
 #include "midout.h"
+

 MIDOUT_DECL(megdnn_arm_hgemv)
 MIDOUT_DECL(megdnn_arm_exec_int8816)
 MIDOUT_DECL(megdnn_arm_exec_int8832)
@@ -158,7 +160,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32GemvMK4::get_kern(
     return int8x8x32_gemv_mk4_kern;
 }

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 /* =================== Int8x8x32 Gemv MK4_DOT algo ==================== */
 namespace {
 void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) {
@@ -176,6 +178,10 @@ void int8x8x32_gemv_mk4_dot_kern(const MatrixMulImpl::KernParam& kern_param) {

 bool MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot::usable(
         const KernSizeParam& kern_size_param) const {
+
+    if (!cpuinfo_has_arm_neon_dot()) {
+        return false;
+    }
     auto M = kern_size_param.M;
     auto N = kern_size_param.N;
     auto K = kern_size_param.K;


+ 1
- 1
dnn/src/arm_common/matrix_mul/algos.h

@@ -63,7 +63,7 @@ public:
     MEGDNN_DECL_ALGO_TYPE(ARM_COMMON_INT8X8X32_GEMV_MK4)
 };

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {


+ 27
- 8
dnn/src/arm_common/matrix_mul/int8/gemv.cpp

@@ -9,7 +9,6 @@
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */

-#include <cstddef>
 #include "src/arm_common/simd_macro/marm_neon.h"
 #include "src/arm_common/matrix_mul/int8/gemv.h"
 #include "src/common/utils.h"
@@ -21,7 +20,6 @@ MIDOUT_DECL(megdnn_arm_common_int8_gemv)
 using namespace megdnn;
 using namespace arm_common;

-#if !__ARM_FEATURE_DOTPROD

 namespace {

@@ -170,12 +168,11 @@ void gemv_naive_n_mk4(const int8_t* __restrict A, const int8_t* __restrict B,
 }

 } // namespace
-#endif

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 namespace {
-void gemv_naive_n(const int8_t* __restrict A, const int8_t* __restrict B,
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
+void gemv_naive_n_dot(const int8_t* __restrict A, const int8_t* __restrict B,
                   int32_t* __restrict C, size_t M, size_t N, size_t K,
                   size_t Astride, size_t Bstride, size_t Cstride) {
     megdnn_assert(N == 1 && Bstride == 1);
@@ -244,7 +241,8 @@ void gemv_naive_n_dot(const int8_t* __restrict A, const int8_t* __restrict B,
     }
 }

-void gemv_naive_n_mk4(const int8_t* __restrict A, const int8_t* __restrict B,
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
+void gemv_naive_n_mk4_dotprod(const int8_t* __restrict A, const int8_t* __restrict B,
                       int32_t* __restrict C, size_t M, size_t N, size_t K,
                       size_t Astride, size_t Bstride, size_t Cstride) {
     constexpr size_t PACK_SIZE = 4;
@@ -323,6 +321,7 @@ void gemv_naive_n_mk4_dotprod(const int8_t* __restrict A, const int8_t* __restrict B,
     }
 }

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 void gemv_naive_n_mk4_dot(const int8_t* __restrict A,
                           const int8_t* __restrict B, int32_t* __restrict C,
                           size_t M, size_t N, size_t K, size_t Astride,
@@ -403,7 +402,16 @@ void arm_common::gemv_like(const int8_t* __restrict A,
     megdnn_assert(N == 1);
     MIDOUT_BEGIN(megdnn_arm_common_int8_gemv,
                  midout_iv("INT8_gemv_like"_hash)) {
+#if MGB_ENABLE_DOT
+        if (cpuinfo_has_arm_neon_dot()) {
+            return gemv_naive_n_dot(A, B, C, M, N, K, Astride, Bstride,
+                                    Cstride);
+        } else {
+            return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
+        }
+#else
         return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
+#endif
     }
     MIDOUT_END();
 }
@@ -416,12 +424,22 @@ void arm_common::gemv_like_mk4(const int8_t* __restrict A,
     megdnn_assert(N == 1);
     MIDOUT_BEGIN(megdnn_arm_common_int8_gemv,
                  midout_iv("INT8_gemv_like_mk4"_hash)) {
+#if MGB_ENABLE_DOT
+        if (cpuinfo_has_arm_neon_dot()) {
+            return gemv_naive_n_mk4_dotprod(A, B, C, M, N, K, Astride, Bstride,
+                                            Cstride);
+        } else {
+            return gemv_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride,
+                                    Cstride);
+        }
+#else
         return gemv_naive_n_mk4(A, B, C, M, N, K, Astride, Bstride, Cstride);
+#endif
     }
     MIDOUT_END();
 }

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 void arm_common::gemv_like_mk4_dot(const int8_t* __restrict A,
                                    const int8_t* __restrict B,
                                    int32_t* __restrict C, size_t M, size_t N,
@@ -437,4 +455,5 @@ void arm_common::gemv_like_mk4_dot(const int8_t* __restrict A,
 }
 #endif
+
 // vim: syntax=cpp.doxygen
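gemv.cpp is where the commit moves from mutually exclusive implementations to runtime dispatch: the plain NEON gemv_naive_n/gemv_naive_n_mk4 are now always compiled, the dot variants get distinct names plus the target attribute, and the public entry points branch per call. Condensed to its shape (simplified from gemv_like above; midout bookkeeping dropped):

    void gemv_like(const int8_t* A, const int8_t* B, int32_t* C, size_t M,
                   size_t N, size_t K, size_t Astride, size_t Bstride,
                   size_t Cstride) {
    #if MGB_ENABLE_DOT                   // dot build: both kernels linked in,
        if (cpuinfo_has_arm_neon_dot())  // one predictable branch per call
            return gemv_naive_n_dot(A, B, C, M, N, K, Astride, Bstride, Cstride);
    #endif
        return gemv_naive_n(A, B, C, M, N, K, Astride, Bstride, Cstride);
    }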

+ 1
- 1
dnn/src/arm_common/matrix_mul/int8/gemv.h

@@ -28,7 +28,7 @@ void gemv_like_mk4(const int8_t* __restrict A, const int8_t* __restrict B,
                    int32_t* __restrict C, size_t M, size_t N, size_t K,
                    size_t Astride, size_t Bstride, size_t Cstride);

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 void gemv_like_mk4_dot(const int8_t* __restrict A, const int8_t* __restrict B,
                        int32_t* __restrict C, size_t M, size_t N, size_t K,
                        size_t Astride, size_t Bstride, size_t Cstride);


+ 2
- 2
dnn/src/arm_common/matrix_mul/opr_impl.cpp

@@ -22,7 +22,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
 #endif
     AlgoInt8x8x32Gemv int8x8x32_gemv;
     AlgoInt8x8x32GemvMK4 int8x8x32_gemv_mk4;
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     AlgoInt8x8x32GemvMK4Dot int8x8x32_gemv_mk4_dot;
 #endif
     AlgoGevm gevm;
@@ -37,7 +37,7 @@ public:
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         m_all_algos.emplace_back(&f16gemv);
 #endif
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
         m_all_algos.emplace_back(&int8x8x32_gemv_mk4_dot);
 #endif
         m_all_algos.emplace_back(&int8x8x32_gemv);


+ 1
- 1
dnn/src/arm_common/matrix_mul/opr_impl.h

@@ -42,7 +42,7 @@ protected:
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
     class AlgoF16Gemv;
 #endif
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     class AlgoInt8x8x32GemvMK4Dot;  // Arm_common Int8x8x32 Gemv NCHW44_DOT
 #endif
     class AlgoInt8x8x16;  // Arm_common Int 8x8x16


+ 3
- 2
dnn/src/arm_common/neon_struct.h

@@ -69,9 +69,10 @@ struct Vfmaq_laneq_f32 {
         return vfmaq_laneq_f32(a, b, v, lane);
     }
 };
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 struct Vdotq_laneq_s32 {
     template <const int lane>
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     static __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
         return vdotq_laneq_s32(a, b, v, lane);
     }
@@ -82,4 +83,4 @@ struct Vdotq_laneq_s32 {
 } // namespace megdnn

 #undef __ai
-// vim: syntax=cpp.doxygen
+// vim: syntax=cpp.doxygen
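Vdotq_laneq_s32 funnels the lane index through a template parameter so it stays a compile-time constant (the intrinsic requires that), and the target attribute has to be restated on impl because function target attributes do not propagate from callers into callees. A hypothetical call site, assuming marm_neon.h and neon_struct.h are included:

    // accumulate against lane 2 of v, broadcast across the SDOT
    MEGDNN_ATTRIBUTE_TARGET("dotprod")
    int32x4_t acc_lane2(int32x4_t acc, int8x16_t w, int8x16_t v) {
        return megdnn::Vdotq_laneq_s32::impl<2>(acc, w, v);
    }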

+ 20
- 7
dnn/src/arm_common/simd_macro/marm_neon.h

@@ -10,7 +10,12 @@
  * implied.
  */
 #pragma once
+#if MGB_ENABLE_DOT
+#if defined(__ARM_FEATURE_DOTPROD)
+#undef __ARM_FEATURE_DOTPROD
+#endif
+#define __ARM_FEATURE_DOTPROD 1
+#endif
 #include <arm_neon.h>
 #include "megdnn/arch.h"
 #include "src/common/unroll_macro.h"
@@ -249,13 +254,14 @@ __ai float16x8_t vdupq_n_f16(__fp16 a) {

 #endif  // __ARM_FEATURE_FP16_VECTOR_ARITHMETIC

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 __ai int32x4_t vdotq2_s32(int8x16_t a, int8x16_t b) {
     int32x4_t c = vdupq_n_s32(0);
     return vdotq_s32(c, a, b);
 }

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 __ai uint32x4_t vdotq2_u32(uint8x16_t a, uint8x16_t b) {
     uint32x4_t c = vdupq_n_u32(0);
     return vdotq_u32(c, a, b);
@@ -275,11 +281,13 @@ __ai uint32x4_t vdotq2_u32(uint8x16_t a, uint8x16_t b) {
         c;                                          \
     })

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 __ai int32x2_t vdot2_s32(int8x8_t a, int8x8_t b) {
     int32x2_t c = vdup_n_s32(0);
     return vdot_s32(c, a, b);
 }

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 __ai uint32x2_t vdot2_u8(uint8x8_t a, uint8x8_t b) {
     uint32x2_t c = vdup_n_u32(0);
     return vdot_u32(c, a, b);
@@ -298,8 +306,7 @@ __ai uint32x2_t vdot2_u8(uint8x8_t a, uint8x8_t b) {
         c = vdot_lane_u32(c, a, b, lane);           \
         c;                                          \
     })
-
-#endif  // __ARM_FEATURE_DOTPROD
+#endif  // MGB_ENABLE_DOT

 #if __GNUC__ < 8
 #undef vld1q_f32_x2
@@ -575,7 +582,7 @@ struct Vfmsq_laneq_f32_armv7<3> {
 #define vfmsq_laneq_f32(a, b, v, lane) \
     Vfmsq_laneq_f32_armv7<lane>::impl(a, b, v)

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 namespace {
 template <int lane>
 struct Vdotq_laneq_s32_armv7 {
@@ -583,24 +590,28 @@ struct Vdotq_laneq_s32_armv7 {
 };
 template <>
 struct Vdotq_laneq_s32_armv7<0> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
         return vdotq_lane_s32(a, b, vget_low_s32(v), 0);
     }
 };
 template <>
 struct Vdotq_laneq_s32_armv7<1> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
         return vdotq_lane_s32(a, b, vget_low_s32(v), 1);
     }
 };
 template <>
 struct Vdotq_laneq_s32_armv7<2> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
         return vdotq_lane_s32(a, b, vget_high_s32(v), 0);
     }
 };
 template <>
 struct Vdotq_laneq_s32_armv7<3> {
+    MEGDNN_ATTRIBUTE_TARGET("dotprod")
     __ai int32x4_t impl(int32x4_t a, int8x16_t b, int8x16_t v) {
         return vdotq_lane_s32(a, b, vget_high_s32(v), 1);
     }
 };
@@ -765,7 +776,9 @@ __ai float32x4_t Vfmsq_f32(float32x4_t& a, float32x4_t& b, float32x4_t& v) {
             :);
     return a;
 }
+#if MGB_ENABLE_DOT
+#undef __ARM_FEATURE_DOTPROD
+#endif
 #undef __ai
 #pragma GCC diagnostic pop
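The bracketing at the top and bottom of marm_neon.h is what makes the per-function attribute workable with a baseline -march: before including <arm_neon.h> the header force-defines __ARM_FEATURE_DOTPROD so the compiler's NEON header declares the vdot* intrinsics, and after all wrappers are defined it undefines the macro again so no other code mistakes it for a real compiler guarantee. A minimal self-contained reproduction of the envelope (sketch; assumes a TU compiled with e.g. -march=armv8-a, i.e. without +dotprod globally):

    #if MGB_ENABLE_DOT
    #ifdef __ARM_FEATURE_DOTPROD
    #undef __ARM_FEATURE_DOTPROD
    #endif
    #define __ARM_FEATURE_DOTPROD 1  // coax <arm_neon.h> into declaring vdotq_s32
    #endif
    #include <arm_neon.h>

    #if MGB_ENABLE_DOT
    __attribute__((target("dotprod")))  // codegen permission, function-local only
    static inline int32x4_t dot_zero(int8x16_t a, int8x16_t b) {
        return vdotq_s32(vdupq_n_s32(0), a, b);
    }
    #undef __ARM_FEATURE_DOTPROD  // don't leak the forced define to includers
    #endif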




+ 13
- 1
dnn/src/armv7/matrix_mul/algos.cpp

@@ -19,6 +19,9 @@
 #include "src/armv7/matrix_mul/quint8/strategy.h"
 #include "src/common/utils.h"
 #include "src/fallback/matrix_mul/gemm_impl.h"
+#if MGB_ENABLE_CPUINFO
+#include "cpuinfo.h"
+#endif

 #include "midout.h"

@@ -744,7 +747,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x4x1,
                                      armv7::matmul::gemm_s16x16x32_12x4,
                                      int16_t, int32_t,
                                      AlgoDataType::INT16X16X32, DEFAULT);
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 /* ===================== Int8 K6x8x4 algo ===================== */
 namespace {
 void int8_k6x8x4_kern(const MatrixMulImpl::KernParam& kern_param) {
@@ -769,6 +772,9 @@ void int8_k6x8x4_kern(const MatrixMulImpl::KernParam& kern_param) {

 bool MatrixMulImpl::AlgoInt8x8x32K6x8x4::usable(
         const KernSizeParam& kern_size_param) const {
+    if (!cpuinfo_has_arm_neon_dot()) {
+        return false;
+    }
     return can_be_treated_as_int8x8x32(kern_size_param);
 }

@@ -827,6 +833,9 @@ void quint8_dot_k4x8x4_kern(const MatrixMulImpl::KernParam& kern_param) {

 bool MatrixMulImpl::AlgoQuint8DotK4x8x4::usable(
         const KernSizeParam& kern_size_param) const {
+    if (!cpuinfo_has_arm_neon_dot()) {
+        return false;
+    }
     return kern_size_param.A_type.enumv() == DTypeEnum::Quantized8Asymm &&
            kern_size_param.B_type.enumv() == DTypeEnum::Quantized8Asymm &&
            kern_size_param.C_type.enumv() == DTypeEnum::QuantizedS32 &&
@@ -891,6 +900,9 @@ void int8_mk4_8x4x4_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {

 bool MatrixMulImpl::AlgoInt8x8x32MK4_8x4x4DotProd::usable(
         const KernSizeParam& kern_size_param) const {
+    if (!cpuinfo_has_arm_neon_dot()) {
+        return false;
+    }
     return kern_size_param.A_type.enumv() == kern_size_param.B_type.enumv() &&
            (kern_size_param.A_type.enumv() == DTypeEnum::Int8 ||
             kern_size_param.A_type.enumv() == DTypeEnum::QuantizedS8) &&
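Note that the cpuinfo include here is itself guarded by MGB_ENABLE_CPUINFO, so cpuinfo_has_arm_neon_dot() only resolves when the cpuinfo backend is built in. A configuration enabling MGB_ENABLE_DOT without MGB_ENABLE_CPUINFO would need a conservative stub along these lines (hypothetical, not part of this commit; shown only to make the dependency explicit):

    #if MGB_ENABLE_CPUINFO
    #include "cpuinfo.h"
    #else
    // hypothetical fallback: with no way to probe the CPU, report "no dot"
    // so the non-dot algos are selected
    static inline bool cpuinfo_has_arm_neon_dot() { return false; }
    #endif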


+ 1
- 1
dnn/src/armv7/matrix_mul/algos.h

@@ -86,7 +86,7 @@ public:
     MEGDNN_DECL_ALGO_TYPE(ARMV7_F16_MK8_4X8)
 };
 #endif
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 class MatrixMulImpl::AlgoInt8x8x32K6x8x4 final : public AlgoBase {
 public:
     AlgoAttribute attribute() const override {


+ 0
- 1
dnn/src/armv7/matrix_mul/asm/common.h

@@ -10,7 +10,6 @@
  * implied.
  */
 #pragma once
-#include <arm_neon.h>
 #include <cmath>
 #include <cstdint>
 #include <type_traits>


+ 0
- 1
dnn/src/armv7/matrix_mul/fp32/strategy_4x12.cpp

@@ -10,7 +10,6 @@
  * implied.
  */

-#include "src/arm_common/simd_macro/marm_neon.h"
 #include "src/armv7/matrix_mul/asm/common.h"
 #include "src/armv7/matrix_mul/fp32/strategy.h"
 #include "src/common/utils.h"


+ 3
- 1
dnn/src/armv7/matrix_mul/int8/kernel_6x8x4.h

@@ -9,7 +9,7 @@
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT

 #include "src/arm_common/simd_macro/marm_neon.h"
 #include "src/armv7/matrix_mul/asm/common.h"
@@ -43,6 +43,7 @@ namespace matmul_dot_6x8x4 {
 //
 // Accumulator

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_6x8(const int8_t* packA, const int8_t* packB, int K,
                      int32_t* output, int LDC, bool is_first_k,
                      size_t m_remain = 6) {
@@ -274,6 +275,7 @@ static void kern_6x8(const int8_t* packA, const int8_t* packB, int K,
 //
 // Accumulator

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_6x4(const int8_t* packA, const int8_t* packB, int K,
                      int32_t* output, int LDC, bool is_first_k,
                      size_t n_remain = 8, size_t m_remain = 6) {


+ 3
- 2
dnn/src/armv7/matrix_mul/int8/kernel_mk4_dot_8x4x4.h

@@ -10,7 +10,7 @@
  * implied.
  */

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT

 #include "src/arm_common/simd_macro/marm_neon.h"
 #include "src/armv7/matrix_mul/asm/common.h"
@@ -42,7 +42,7 @@ namespace matmul_mk4_dot_8x4x4 {
 // |q14[0-4]|
 // +--------+
 // Accumulator
-
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
                      int32_t* output, int LDC, bool is_first_k, int n_remain) {
     K /= 4;
@@ -211,6 +211,7 @@ static void kern_8x4(const int8_t* packA, const int8_t* packB, int K,
 // +--------+
 // Accumulator

+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_4x4(const int8_t* packA, const int8_t* packB, int K,
                      int32_t* output, int LDC, bool is_first_k, int n_remain) {
     K /= 4;


+ 1
- 1
dnn/src/armv7/matrix_mul/int8/strategy.cpp

@@ -175,7 +175,7 @@ void gemm_s8_4x8::kern(const dt_int8* packA, const dt_int8* packB, size_t M,
     }
 }

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 // ===========================gemm_s8_6x8======================================
 MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_dots8_6x8);
 void gemm_dots8_6x8::pack_A(dt_int8* out, const dt_int8* in, int ldin, int y0,


+ 1
- 1
dnn/src/armv7/matrix_mul/int8/strategy.h

@@ -23,7 +23,7 @@ MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 8, 8, false, true,

 MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 4, 2, 16, false, false,
                          gemm_mk4_s8_4x2);
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 MEGDNN_REG_GEMM_STRATEGY(dt_int8, dt_int32, dt_int32, 6, 8, 4, false, false,
                          gemm_dots8_6x8);




+ 2
- 2
dnn/src/armv7/matrix_mul/opr_impl.cpp

@@ -27,7 +27,7 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
     AlgoF16K4x16x1 f16_k4x16x1;
     AlgoF16MK8_4x8 f16_mk8_4x8;
 #endif
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     AlgoInt8x8x32K6x8x4 int8_k6x8x4;
     AlgoQuint8DotK4x8x4 quint8_k4x8x4;
     AlgoInt8x8x32MK4_8x4x4DotProd int8x8x32_mk4_8x4x4_dotprod;
@@ -57,7 +57,7 @@ public:
         m_all_algos.emplace_back(&f16_k4x16x1);
         m_all_algos.emplace_back(&f16_mk8_4x8);
 #endif
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
         m_all_algos.emplace_back(&int8x8x32_mk4_8x4x4_dotprod);
         m_all_algos.emplace_back(&int8_k6x8x4);
         m_all_algos.emplace_back(&quint8_k4x8x4);


+ 1
- 1
dnn/src/armv7/matrix_mul/opr_impl.h

@@ -49,7 +49,7 @@ private:
     class AlgoF16K4x16x1;  // Armv7 F16 Kernel 4x16x1
     class AlgoF16MK8_4x8;  // Armv7 F16 MK8 Format block 4x8
 #endif
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     class AlgoInt8x8x32K6x8x4;            // Armv7 Int8 Kernel 6x8x4
     class AlgoQuint8DotK4x8x4;            // Armv7 Quint8 Kernel 4x8x4
     class AlgoInt8x8x32MK4_8x4x4DotProd;  // Armv7 nchw44 Int8x8x32 Kernel 8x4x4


+ 3
- 2
dnn/src/armv7/matrix_mul/quint8/kernel_dot_4x8x4.h

@@ -9,7 +9,7 @@
  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  */

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT

 #include "src/arm_common/simd_macro/marm_neon.h"
 #include "src/armv7/matrix_mul/asm/common.h"
@@ -41,7 +41,7 @@ namespace matmul_dot_4x8x4 {
 // +-------+-------+ - - - - +--------+--------+--------+
 //
 // Accumulator
-
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K,
                      int32_t* output, int LDC, bool is_first_k, uint8_t zA,
                      uint8_t zB, uint32_t zAB, size_t m_remain = 4) {
@@ -257,6 +257,7 @@ static void kern_4x8(const uint8_t* packA, const uint8_t* packB, int K,
 // +-------+-------+ - - - - +--------+--------+--------+
 //
 // Accumulator
+MEGDNN_ATTRIBUTE_TARGET("dotprod")
 static void kern_4x4(const uint8_t* packA, const uint8_t* packB, int K,
                      int32_t* output, int LDC, bool is_first_k, uint8_t zA,
                      uint8_t zB, uint32_t zAB, size_t m_remain = 4,


+ 1
- 1
dnn/src/armv7/matrix_mul/quint8/strategy.cpp

@@ -88,7 +88,7 @@ void gemm_u8_4x8::kern(const dt_uint8* packA, const dt_uint8* packB, size_t M,
     }
 }

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 // ===========================gemm_dot_quint8_4x8======================================
 MEGDNN_REG_GEMM_STRATEGY_IMPL(gemm_dot_quint8_4x8);
 void gemm_dot_quint8_4x8::pack_A(dt_uint8* out, const dt_uint8* in, int ldin,


+ 1
- 1
dnn/src/armv7/matrix_mul/quint8/strategy.h

@@ -17,7 +17,7 @@ namespace matmul {

 MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 4, 8, 8, false, true,
                          gemm_u8_4x8);
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 MEGDNN_REG_GEMM_STRATEGY(dt_uint8, dt_int32, dt_int32, 4, 8, 4, false, false,
                          gemm_dot_quint8_4x8);
 #endif


+ 7
- 0
dnn/src/common/utils.h

@@ -60,6 +60,13 @@
 #include <windows.h>
 #endif

+
+#if MEGDNN_AARCH64 || MEGDNN_ARMV7
+#if MGB_ENABLE_CPUINFO
+#include "cpuinfo.h"
+#endif
+#endif
+
 #if __cplusplus >= 201703L || __clang_major__ >= 4
 #define MEGDNN_FALLTHRU [[fallthrough]];
 #elif __GNUC__ >= 7


+ 1
- 1
dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.cpp

@@ -148,7 +148,7 @@ struct GemvLike<stype, btype, param::ConvBias::Format::NCHW44> {
     }
 };

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 template <typename stype, typename btype>
 struct GemvLike<stype, btype, param::ConvBias::Format::NCHW44_DOT> {
     inline static void do_gemv(const stype* A, const stype* B, btype* C,


+ 3
- 3
dnn/test/aarch64/matrix_mul.cpp

@@ -87,7 +87,7 @@ TEST_F(AARCH64, MATRIX_MUL_F16_MK8) {
 }
 #endif

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(AARCH64, MATRIX_MUL_INT8X8X32_K8X12X4_DOTPROD) {
     matrix_mul::check_matrix_mul(dtype::Int8{}, dtype::Int8{}, dtype::Int32{},
                                  handle(), "AARCH64_INT8X8X32_K8X12X4_DOTPROD");
@@ -690,7 +690,7 @@ TEST_F(AARCH64, BENCHMARK_GEMV) {
     run(M, K, N);
 }

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(AARCH64, BENCHMARK_TRANSPOSED_MATRIX_MUL_INT_8X8X32) {
     constexpr size_t RUNS = 50;
     param::MatrixMul param;
@@ -803,7 +803,7 @@ TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_INT8X8X32_MK4_8X12X4) {
         std::cout << std::endl;
     }
 }
-#endif  // __ARM_FEATURE_DOTPROD
+#endif  // MGB_ENABLE_DOT

 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 TEST_F(AARCH64, BENCHMARK_MATRIX_MUL_F16_MK8) {


+ 5
- 6
dnn/test/arm_common/conv_bias.cpp

@@ -166,7 +166,7 @@ static void benchmark_convbias(Handle* handle, std::string int_name,
                 .set_display(false);
     }
     auto nchw44_algo_regx = ".*(DIRECT|NCHW_NCHW44).*";
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     if (!is_fp32) {
         nchw44_algo_regx = ".*DOT.*";
     }
@@ -1852,7 +1852,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_QINT8_STRIDE1_NCHW44) {

 #endif

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 #if MEGDNN_WITH_BENCHMARK
 TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_INT8_STRIDE1_WITHDOTPROD) {
     // have to remove preferred restrict in usable func before run the benchmark
@@ -2440,7 +2440,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_QUANTIZEDSYM) {
     dtype::QuantizedS8 stype(2.5f);
     dtype::QuantizedS32 dtype(6.25f);
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     benchmark_conv1x1("AARCH64_INT8X8X32_K8X12X4_DOTPROD", handle(), stype,
                       dtype, dtype, dtype);
 #else
@@ -2460,7 +2460,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_S1_QUANTIZEDASYM) {
     dtype::QuantizedS32 dtype(1.2 * 1.2);

 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     benchmark_conv1x1("AARCH64_QUINT8_K8X8X4_DOTPROD", handle(), stype, dtype,
                       dtype, dtype);
 #else
@@ -2565,7 +2565,7 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_CONV1X1_GEMV_FP32) {
     }
 }

-#ifndef __ARM_FEATURE_DOTPROD
+//! enable none dot algo now
 TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_1X1_S1_NCHW_VS_NCHW44_INT8x8x32) {
     std::vector<TestArg> conv_bias_1x1_args_nchw44 =
             get_conv_bias_1x1_benchmark_args(4);
@@ -2634,7 +2634,6 @@ TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_1X1_S1_NCHW_VS_NCHW44_INT8x8x32) {
             computations / conv1x1_nchw44, conv1x1_nchw / conv1x1_nchw44);
     }
 }
-#endif

 TEST_F(ARM_COMMON, BENCHMARK_CONV_BIAS_WINOGRAD_VS_IM2COL_INT8) {
     auto&& args = get_winograd_benchmark_args(3, 8);


+ 1
- 1
dnn/test/arm_common/conv_bias_multi_thread.cpp

@@ -500,7 +500,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_QUINT8_STRIDE2) {
 }

 /****************************dot qint8 direct*************************/
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_DOT_NCHW_NCHW44) {
     auto args = get_nchw44_conv_bias_args({2, 3, 5, 7}, QUAN_NLMODE,
                                           BR_AND_NO_BIASMODE, 2, false, true);


+ 4
- 4
dnn/test/arm_common/conv_bias_multi_thread_benchmark.cpp

@@ -655,7 +655,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44) {
     bench_case(1, 512, 256, 28, 28, 3, 4, 1, 2);
 }

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS, BENCHMARK_CONVBIAS_INT8_NCHW44_DOT) {
     constexpr size_t RUNS = 40;
     std::vector<DType> data_type = {
@@ -892,7 +892,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
     benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}},
                    {1, {4}}, data_type);
 }
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
        BENCHMARK_CONVBIAS_INT8_INT8_INT8_STRIDE1_WITHDOTPROD) {
     constexpr size_t RUNS = 50;
@@ -1157,7 +1157,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
     benchmark_impl(param, shapes_and_computation, algo_name, RUNS, {2, {4, 5}},
                    {1, {4}}, data_type);
 }
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
        BENCHMARK_CONVBIAS_QUINT8_QUINT8_QUINT8_STRIDE1_WITHDOTPROD) {
     constexpr size_t RUNS = 50;
@@ -1977,7 +1977,7 @@ TEST_F(ARM_COMMON_BENCHMARK_MULTI_THREADS,
     dtype::QuantizedS32 btype(0.04f);
     dtype::Quantized8Asymm dtype(1.4f, 110);
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     conv1x1_multithread_benchmark("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:8",
                                   stype, ftype, btype, dtype);
 #else


+ 9
- 11
dnn/test/arm_common/conv_bias_multi_thread_conv1x1.cpp View File

@@ -20,7 +20,7 @@ using namespace megdnn;
 using namespace test;
 using namespace conv_bias;

-#ifdef __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_CONV1x1_QUANTIZEDSYM_MK4_DOT) {
     UniformIntRNG rng{-50, 50};
@@ -138,7 +138,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDSYM) {
             dtype::QuantizedS8(2.5f), dtype::QuantizedS32(6.25f),          \
             dtype::QuantizedS8(60.25f), name);
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:24");
 #else
     cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
@@ -174,7 +174,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUANTIZEDASYM) {
             name);
     float epsilon = 0.001;
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:48");
 #else
     cb("CONV1x1:AARCH64_QUINT8_K8X8X8:24");
@@ -210,13 +210,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_QUINT8x8x32) {
             dtype::QuantizedS32(1.2 * 1.3), {}, name);

 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH64_QUINT8_K8X8X4_DOTPROD:24");
 #else
     cb("CONV1x1:AARCH64_QUINT8_K8X8X8:48");
 #endif
 #elif MEGDNN_ARMV7
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH32_QUINT8_K4X8X4:48");
 #endif
     cb("CONV1x1:ARMV7_QUINT8_K4X8X8:24");
@@ -287,14 +287,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) {
 #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);

 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH64_INT8X8X32_K8X12X4_DOTPROD:48");
 #else
     cb("CONV1x1:AARCH64_INT8X8X32_K8X8X8:24");
     cb("CONV1x1:AARCH64_INT8X8X32_K4X4X16:24");
 #endif
 #elif MEGDNN_ARMV7
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("CONV1x1:AARCH32_INT8_K6X8X4:48");
 #endif
     cb("CONV1x1:ARMV7_INT8X8X32_K4X8X8:24");
@@ -312,8 +312,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32) {
     }
     checker_conv_bias_mul_int8x8x32(gemv_args, handle(), "CONV1x1_GEMV");
 }

-#ifndef __ARM_FEATURE_DOTPROD
-//! enable none dot algo now
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) {
     using namespace conv_bias;
     std::vector<conv_bias::TestArg> args =
@@ -345,7 +344,6 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_MK4) {
 #endif
 #undef cb
 }
-#endif

 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44) {
     using namespace conv_bias;
@@ -364,7 +362,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44) {
                                   "CONV1x1_GEMV");
 }

-#ifdef __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_1X1_S1_INT8x8x32_NCHW44_DOT) {
     using namespace conv_bias;
     std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args(
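Two details in this file are worth noting. First, the old guard here was #ifdef (a presence test), which is correct for a compiler-predefined feature macro; MGB_ENABLE_DOT is presumably always defined by the build system with a 0/1 value, so the new code tests it with #if (a value test) instead. A hypothetical illustration of the difference, under that 0/1 assumption:

#if defined(__ARM_FEATURE_DOTPROD)  // feature macro: exists only in +dotprod builds
#endif

#if MGB_ENABLE_DOT                  // option macro: assumed always defined, to 0 or 1
#endif

Second, the #ifndef __ARM_FEATURE_DOTPROD wrapper around CONV_BIAS_1X1_S1_INT8x8x32_MK4 (with its "enable none dot algo now" note) is dropped entirely: the non-dot MK4 algorithm no longer disappears from dot-capable builds, so its test can run unconditionally.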


+11 -10  dnn/test/arm_common/conv_bias_multi_thread_im2col.cpp

@@ -135,7 +135,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) {

     float epsilon = 0.001;
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD");
 #else
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8");
@@ -148,7 +148,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM) {
 #undef cb
 }

-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT

 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT) {
     UniformIntRNG rng{-50, 50};
@@ -173,6 +173,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_MK4_DOT) {
 #if MEGDNN_AARCH64
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
 #elif MEGDNN_ARMV7
+    epsilon = 1;
     cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X4X4_DOTPROD:96");
 #endif
 #undef cb
@@ -194,6 +195,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS,
 #if MEGDNN_AARCH64
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_MK4_8X12X4_DOTPROD:96");
 #elif MEGDNN_ARMV7
+    epsilon = 1;
     cb("IM2COLMATMUL:AARCH32_INT8_MK4_8X4X4_DOTPROD:96");
 #endif
 #undef cb
@@ -273,7 +275,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUANTIZEDASYM) {
                     dtype::Quantized8Asymm(50.3f, (uint8_t)120), name);
     float epsilon = 0.001;
 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD");
 #else
     cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8");
@@ -305,13 +307,13 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_QUINT8x8x32) {
                     dtype::QuantizedS32(1.2 * 1.3), {}, name);

 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X4_DOTPROD");
 #else
     cb("IM2COLMATMUL:AARCH64_QUINT8_K8X8X8");
 #endif
 #elif MEGDNN_ARMV7
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH32_QUINT8_K4X8X4");
 #endif
     cb("IM2COLMATMUL:ARMV7_QUINT8_K4X8X8");
@@ -392,7 +394,7 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_FP16) {
 #endif

 #if MEGDNN_AARCH64 || MEGDNN_ARMV7
-#if !__ARM_FEATURE_DOTPROD
-//! enable none dot algo now
 TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32NCHW44_S2) {
     using namespace conv_bias;
     std::vector<conv_bias::TestArg> args = get_nchw44_conv_bias_args(
@@ -483,10 +485,9 @@ TEST_F(ARM_COMMON_MULTI_THREADS,

 #endif
 #endif
-#endif

 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
 TEST_F(ARM_COMMON_MULTI_THREADS,
        CONV_BIAS_IM2COLMATMUL_QUANTIZEDSYM_NCHW44DOT_FUSE) {
     UniformIntRNG rng{-50, 50};
@@ -516,14 +517,14 @@ TEST_F(ARM_COMMON_MULTI_THREADS, CONV_BIAS_IM2COLMATMUL_INT8x8x32) {
 #define cb(name) checker_conv_bias_mul_int8x8x32(args, handle(), name);

 #if MEGDNN_AARCH64
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X12X4_DOTPROD");
 #else
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_K8X8X8");
     cb("IM2COLMATMUL:AARCH64_INT8X8X32_K4X4X16");
 #endif
 #elif MEGDNN_ARMV7
-#if __ARM_FEATURE_DOTPROD
+#if MGB_ENABLE_DOT
     cb("IM2COLMATMUL:AARCH32_INT8_K6X8X4");
 #endif
     cb("IM2COLMATMUL:ARMV7_INT8X8X32_K4X8X8");


Some files were not shown because too many files changed in this diff
