
refactor(dnn/arm): refactor CPU heuristic algo selection

GitOrigin-RevId: 60d2646bb3

Branch: release-1.1
Author: Megvii Engine Team (4 years ago)
Commit: 2a3f4d099a
47 changed files with 856 additions and 172 deletions
  1. dnn/include/megdnn/oprs/base.h (+12, -0)
  2. dnn/src/aarch64/conv_bias/fp16/algos.h (+4, -0)
  3. dnn/src/aarch64/conv_bias/fp32/algos.h (+4, -0)
  4. dnn/src/aarch64/conv_bias/int8/algos.h (+3, -0)
  5. dnn/src/aarch64/conv_bias/opr_impl.cpp (+2, -3)
  6. dnn/src/aarch64/conv_bias/quint8/algos.h (+3, -0)
  7. dnn/src/aarch64/matrix_mul/algos.cpp (+32, -18)
  8. dnn/src/aarch64/matrix_mul/algos.h (+4, -4)
  9. dnn/src/arm_common/conv_bias/f16/algos.h (+11, -4)
 10. dnn/src/arm_common/conv_bias/fp32/algos.h (+26, -8)
 11. dnn/src/arm_common/conv_bias/int8/algos.h (+35, -3)
 12. dnn/src/arm_common/conv_bias/int8x8x16/algos.h (+18, -0)
 13. dnn/src/arm_common/conv_bias/opr_impl.cpp (+92, -15)
 14. dnn/src/arm_common/conv_bias/opr_impl.h (+4, -1)
 15. dnn/src/arm_common/conv_bias/quint8/algos.h (+12, -0)
 16. dnn/src/arm_common/matrix_mul/algos.h (+14, -8)
 17. dnn/src/arm_common/matrix_mul/opr_impl.cpp (+4, -3)
 18. dnn/src/armv7/conv_bias/int8/algos.h (+3, -0)
 19. dnn/src/armv7/conv_bias/quint8/algos.h (+4, -0)
 20. dnn/src/armv7/matrix_mul/algos.cpp (+23, -14)
 21. dnn/src/armv7/matrix_mul/algos.h (+3, -3)
 22. dnn/src/armv7/matrix_mul/opr_impl.h (+0, -1)
 23. dnn/src/common/utils.h (+28, -0)
 24. dnn/src/fallback/conv_bias/algos.h (+26, -0)
 25. dnn/src/fallback/conv_bias/common.h (+4, -1)
 26. dnn/src/fallback/conv_bias/conv1x1/algos.cpp (+2, -1)
 27. dnn/src/fallback/conv_bias/conv1x1/algos.h (+5, -0)
 28. dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h (+10, -0)
 29. dnn/src/fallback/conv_bias/im2col/algos.h (+18, -8)
 30. dnn/src/fallback/conv_bias/opr_impl.cpp (+90, -14)
 31. dnn/src/fallback/conv_bias/opr_impl.h (+19, -0)
 32. dnn/src/fallback/convolution/algos.h (+19, -2)
 33. dnn/src/fallback/convolution/opr_impl.cpp (+83, -15)
 34. dnn/src/fallback/convolution/opr_impl.h (+31, -0)
 35. dnn/src/fallback/matrix_mul/algos.cpp (+1, -1)
 36. dnn/src/fallback/matrix_mul/algos.h (+9, -1)
 37. dnn/src/fallback/matrix_mul/gemm_common.h (+17, -13)
 38. dnn/src/fallback/matrix_mul/opr_impl.cpp (+59, -5)
 39. dnn/src/fallback/matrix_mul/opr_impl.h (+18, -1)
 40. dnn/src/x86/conv_bias/f32/algos.h (+17, -2)
 41. dnn/src/x86/conv_bias/int8/algo_usable_preferred.cpp (+0, -2)
 42. dnn/src/x86/conv_bias/int8/algos.h (+24, -0)
 43. dnn/src/x86/conv_bias/opr_impl.cpp (+39, -2)
 44. dnn/src/x86/conv_bias/opr_impl.h (+2, -0)
 45. dnn/src/x86/matrix_mul/algos.cpp (+13, -10)
 46. dnn/src/x86/matrix_mul/algos.h (+4, -4)
 47. src/opr/impl/dnn/convolution.cpp (+5, -5)

dnn/include/megdnn/oprs/base.h (+12, -0)

@@ -76,6 +76,18 @@ enum class AlgoSelectionStrategy {
     FULL_RUN = 2,
 };
 
+/**
+ * \brief separate algo by datatype for Matmul and conv
+ */
+enum class AlgoDataType : uint32_t {
+    FLOAT32 = 1 << 0,
+    FLOAT16 = 1 << 1,
+    QINT8X8X32 = 1 << 2,
+    QUINT8X8X32 = 1 << 3,
+    INT8X8X16 = 1 << 4,
+    INT16X16X32 = 1 << 5,
+};
+
 /*!
  * \brief Abstract representation of an algorithm for implementing
  * the operator

dnn/src/aarch64/conv_bias/fp16/algos.h (+4, -0)

@@ -27,6 +27,10 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
 
     SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT};
+    }
 };
 } // namespace aarch64
 } // namespace megdnn


dnn/src/aarch64/conv_bias/fp32/algos.h (+4, -0)

@@ -32,6 +32,10 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
 
     SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
+    }
 };
 
 } // namespace aarch64


dnn/src/aarch64/conv_bias/int8/algos.h (+3, -0)

@@ -45,6 +45,9 @@ public:
         return static_cast<ConvBiasImpl*>(conv_bias_opr)
                 ->is_matmul_quantized_prefer(param);
     }
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL};
+    }
 };
 
 } // namespace aarch64


dnn/src/aarch64/conv_bias/opr_impl.cpp (+2, -3)

@@ -50,10 +50,9 @@ SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
     auto&& algos = arm_common::ConvBiasImpl::algo_pack();
     algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(),
                  sl_algo_pack.direct_algos.end());
-    //! We put matmul algos at the end. Because matmul will get privilege when
+    //! We put matmul algos at the begin. Because matmul will get privilege when
     //! prefer return true. See
-    //! fallback::ConvolutionImpl::ncb_1g_get_all_algorithms for more details.
-    algos.insert(algos.end(), sl_algo_pack.matmul_algos.begin(),
+    algos.insert(algos.begin(), sl_algo_pack.matmul_algos.begin(),
                  sl_algo_pack.matmul_algos.end());
     return std::move(algos);
 }
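
The insertion point matters because the fallback layer scans the pack front
to back and lets an algorithm whose is_preferred(param) check fires beat
every later entry (see fallback::ConvolutionImpl::ncb_1g_get_all_algorithms,
changed by this commit but not shown on this page). A rough sketch of that
scan, with assumed signatures:

    // Hypothetical sketch only; the real loop lives in the fallback layer.
    ConvBiasImpl::AlgoBase* pick_preferred(ConvBiasImpl* opr,
                                           const NCBKernSizeParam& param) {
        for (auto* algo : opr->algo_pack()) {
            if (algo->usable(param, AlgoSelectionStrategy::HEURISTIC) &&
                algo->is_preferred(param)) {
                return algo;  // earlier entries win, hence matmul at the front
            }
        }
        return nullptr;
    }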


dnn/src/aarch64/conv_bias/quint8/algos.h (+3, -0)

@@ -45,6 +45,9 @@ public:
         return static_cast<ConvBiasImpl*>(conv_bias_opr)
                 ->is_matmul_quantized_prefer(param);
     }
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL};
+    }
 };
 } // namespace aarch64
 } // namespace megdnn


dnn/src/aarch64/matrix_mul/algos.cpp (+32, -18)

@@ -89,7 +89,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern(
 }
 MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_aarch64_matmul_kern,
                                      "AlgoF32K8x12x1Impl"_hash,
-                                     aarch64::matmul::sgemm_8x12, float, float);
+                                     aarch64::matmul::sgemm_8x12, float, float,
+                                     AlgoDataType::FLOAT32, DEFAULT);
 
 /* ===================== F32_MK4_8X12X1 algo ===================== */
 bool MatrixMulImpl::AlgoF32MK4_8x12x1::usable(
@@ -151,7 +152,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4_8x12x1,
                                      megdnn_aarch64_matmul_kern,
                                      "AlgoF32MK4_8x12x1Impl"_hash,
                                      aarch64::matmul::sgemm_mk4_8x12, float,
-                                     float);
+                                     float, AlgoDataType::FLOAT32, MK4);
 
 /* ===================== F32K4X16X1 algo ===================== */
 
@@ -210,7 +211,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K4x16x1::get_kern(
 }
 MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K4x16x1, megdnn_aarch64_matmul_kern,
                                      "AlgoF32K4x16x1Impl"_hash,
-                                     aarch64::matmul::sgemm_4x16, float, float);
+                                     aarch64::matmul::sgemm_4x16, float, float,
+                                     AlgoDataType::FLOAT32, MK4);
 
 /* ===================== F32MK4_4x16 algo ===================== */
 
@@ -328,7 +330,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K8x24x1::get_kern(
 MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF16K8x24x1, megdnn_aarch64_matmul_kern,
                                      "AlogF16K8x24x1Impl"_hash,
                                      aarch64::matmul::hgemm_8x24, dt_float16,
-                                     dt_float16);
+                                     dt_float16, AlgoDataType::FLOAT16,
+                                     DEFAULT);
 /* ===================== F16_MK8_8x8 algo ===================== */
 
 bool MatrixMulImpl::AlgoF16MK8_8x8::usable(
@@ -449,7 +452,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x12x4DotProd,
                                      megdnn_aarch64_matmul_kern,
                                      "AlgoInt8x8x32K8x12x4DotProdImpl"_hash,
                                      aarch64::matmul::gemm_s8_8x12, int8_t,
-                                     int32_t);
+                                     int32_t, AlgoDataType::QINT8X8X32,
+                                     DEFAULT);
 
 /* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */
 namespace {
@@ -520,7 +524,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x12x4DotProd,
                                      megdnn_aarch64_matmul_kern,
                                      "AlgoInt8x8x32MK4_8x12x4DotProdImpl"_hash,
                                      aarch64::matmul::gemm_mk4_s8_8x12, int8_t,
-                                     int32_t);
+                                     int32_t, AlgoDataType::QINT8X8X32,
+                                     MK4_DOT);
 #else
 
 /* ===================== Int8x8x32 MK4 4x4x16 algo ===================== */
@@ -593,7 +598,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_4x4x16,
                                      megdnn_aarch64_matmul_kern,
                                      "AlgoInt8x8x32MK4_4x4x16Impl"_hash,
                                      aarch64::matmul::gemm_mk4_s8_4x4, int8_t,
-                                     int32_t);
+                                     int32_t, AlgoDataType::QINT8X8X32,
+                                     MK4);
 
 /* ===================== Int8x8x32 K4x4x16 algo ===================== */
 namespace {
@@ -656,7 +662,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x4x16,
                                      megdnn_aarch64_matmul_kern,
                                      "AlgoInt8x8x32K4x4x16Impl"_hash,
                                      aarch64::matmul::gemm_s8_4x4, int8_t,
-                                     int32_t);
+                                     int32_t, AlgoDataType::QINT8X8X32,
+                                     DEFAULT);
 /* ===================== Int8x8x32 K8x8x8 algo ===================== */
 namespace {
 void int8x8x32_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) {
@@ -717,7 +724,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x8x8,
                                      megdnn_aarch64_matmul_kern,
                                      "AlgoInt8x8x32K8x8x8Impl"_hash,
                                      aarch64::matmul::gemm_s8_8x8, int8_t,
-                                     int32_t);
+                                     int32_t, AlgoDataType::QINT8X8X32,
+                                     DEFAULT);
 #endif
 
 /* ===================== Int8x8x16 K8x8x8 algo ===================== */
@@ -785,7 +793,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K8x8x8,
                                      megdnn_aarch64_matmul_kern,
                                      "AlgoInt8x8x16K8x8x8Impl"_hash,
                                      aarch64::matmul::gemm_s8x8x16_8x8, int8_t,
-                                     int16_t);
+                                     int16_t, AlgoDataType::INT8X8X16, DEFAULT);
 /* ===================== Int8x8x16 K4x4x16 algo ===================== */
 namespace {
 void int8x8x16_k4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) {
@@ -852,7 +860,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x4x16,
                                      megdnn_aarch64_matmul_kern,
                                      "AlgoInt8x8x16K4x4x16Impl"_hash,
                                      aarch64::matmul::gemm_s8x8x16_4x4, int8_t,
-                                     int16_t);
+                                     int16_t, AlgoDataType::INT8X8X16, DEFAULT);
 
 /* ===================== Int8x8x16 K16x12x4 algo ===================== */
 namespace {
@@ -929,7 +937,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_kern(
 MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
         AlgoInt8x8x16MK4_16x12x4, megdnn_aarch64_matmul_kern,
         "AlgoInt8x8x16MK4_16x12x4Impl"_hash,
-        aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53, int8_t, int16_t, int16_t);
+        aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53, int8_t, int16_t, int16_t,
+        AlgoDataType::INT8X8X16, MK4);
 
 /* ===================== Int8x8x16 MK4 4x4x8 algo ===================== */
 namespace {
@@ -1007,7 +1016,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_4x4x8,
                                      megdnn_aarch64_matmul_kern,
                                      "AlgoInt8x8x16MK4_4x4x8_Impl"_hash,
                                      aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72,
-                                     int8_t, int16_t);
+                                     int8_t, int16_t, AlgoDataType::INT8X8X16,
+                                     MK4);
 
 /* ===================== Int16x16x32 K12x8x1 algo ===================== */
 namespace {
@@ -1078,7 +1088,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x8x1,
                                      megdnn_aarch64_matmul_kern,
                                      "AlgoInt16x16x32K12x8x1Impl"_hash,
                                      aarch64::matmul::gemm_s16_12x8x1, int16_t,
-                                     int32_t);
+                                     int32_t, AlgoDataType::INT16X16X32,
+                                     DEFAULT);
 
 /* ===================== Int16x16x32MK8_8x8 algo ===================== */
 
@@ -1201,7 +1212,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x4DotProd,
                                      megdnn_aarch64_matmul_kern,
                                      "AlgoQuint8K8x8x4DotProdImpl"_hash,
                                      aarch64::matmul::gemm_u8_8x8, uint8_t,
-                                     int32_t);
+                                     int32_t, AlgoDataType::QUINT8X8X32,
+                                     DEFAULT);
 /* ===================== Quint8 Gemv DotProd algo ===================== */
 namespace {
 void quint8_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) {
@@ -1307,7 +1319,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x8,
                                      megdnn_aarch64_matmul_kern,
                                      "AlgoQuint8K8x8x8Impl"_hash,
                                      aarch64::matmul::gemm_u8_8x8, uint8_t,
-                                     int32_t);
+                                     int32_t, AlgoDataType::QUINT8X8X32,
+                                     DEFAULT);
 #endif
 
 /* ===================== Int8x8x16 K8x8x8 algo ===================== */
@@ -1378,6 +1391,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_kern(
 MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_K8x8x8,
                                      megdnn_aarch64_matmul_kern,
                                      "AlgoInt8x8x16MK4_K8x8x8Impl"_hash,
-                                     aarch64::matmul::gemm_s8x8x16_mk4_8x8x8, int8_t,
-                                     int16_t);
+                                     aarch64::matmul::gemm_s8x8x16_mk4_8x8x8,
+                                     int8_t, int16_t, AlgoDataType::INT8X8X16,
+                                     MK4);
 // vim: syntax=cpp.doxygen

dnn/src/aarch64/matrix_mul/algos.h (+4, -4)

@@ -61,7 +61,7 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     void* type() const override { return sm_arm_common_algo_type; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4)
+    MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4, AlgoDataType::FLOAT32, MK4)
 };
 
 class MatrixMulImpl::AlgoF32Gemv final
@@ -88,7 +88,7 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     void* type() const override { return sm_arm_common_algo_type; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2)
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::FLOAT16, MK8)
 };
 
 #endif
@@ -253,7 +253,7 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     void* type() const override { return sm_arm_common_algo_type; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2)
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::INT16X16X32, MK8)
 };
 
 #if __ARM_FEATURE_DOTPROD
@@ -281,7 +281,7 @@ public:
     void* type() const override { return sm_arm_common_algo_type; }
     AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QUINT8X8X32, DEFAULT)
 };
 #else
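
The two extra macro arguments record which AlgoDataType the kernel serves and
which param::MatrixMul::Format it expects; the bare MK4/MK8/DEFAULT tokens
are scoped onto the Format enum inside the macro. MEGDNN_OVERRIDE_MATMUL_DESC
itself is defined in dnn/src/fallback/matrix_mul/gemm_common.h (+17/-13 in
this commit, not shown on this page); a plausible expansion, for orientation
only:

    // Hypothetical expansion; the real definition may differ.
    #define MEGDNN_OVERRIDE_MATMUL_DESC(_m, _n, _k, _packa, _dtype, _fmt) \
        MatmulDescription matmul_description() const override {          \
            MatmulDescription mdesc;                                     \
            mdesc.packmode = packmode();                                 \
            mdesc.innerblocksize = {_m, _n, _k};                         \
            mdesc.packa_type_size = _packa;                              \
            mdesc.algo_type = {_dtype, param::MatrixMul::Format::_fmt};  \
            return mdesc;                                                \
        }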




dnn/src/arm_common/conv_bias/f16/algos.h (+11, -4)

@@ -29,7 +29,7 @@ public:
         }
         return m_name.c_str();
     }
-    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
+    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
 };
 
 class ConvBiasImpl::AlgoFP16WinogradF45 final : public AlgoBase {
@@ -44,7 +44,7 @@ public:
         }
         return m_name.c_str();
     }
-    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
+    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
 
 };
 class ConvBiasImpl::AlgoFP16WinogradF63 final : public AlgoBase {
@@ -60,7 +60,7 @@ public:
         return m_name.c_str();
     }
 
-    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
+    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
 };
 class ConvBiasImpl::AlgoFP16WinogradF23_8x8 final : public AlgoBase {
 public:
@@ -74,7 +74,7 @@ public:
         }
         return m_name.c_str();
     }
-    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
+    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16);
 };
 
 class ConvBiasImpl::AlgoF16Direct final : public AlgoBase {
@@ -90,6 +90,10 @@ public:
 
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase {
@@ -103,6 +107,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT};
+    }
 };
 
 } // namespace arm_common


dnn/src/arm_common/conv_bias/fp32/algos.h (+26, -8)

@@ -29,7 +29,7 @@ public:
         }
         return m_name.c_str();
     }
-    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
+    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
 };
 
 class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase {
@@ -44,7 +44,7 @@ public:
         }
         return m_name.c_str();
     }
-    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
+    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
 };
 
 class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase {
@@ -59,7 +59,7 @@ public:
         }
         return m_name.c_str();
     }
-    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
+    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
 };
 
 class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase {
@@ -74,7 +74,7 @@ public:
         }
         return m_name.c_str();
     }
-    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
+    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
 };
 
 class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase {
@@ -89,7 +89,7 @@ public:
         }
         return m_name.c_str();
     }
-    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
+    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
 };
 
 //===================== NCHW44 Winograd Support =====================//
@@ -106,7 +106,7 @@ public:
         }
         return m_name.c_str();
     }
-    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
+    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
 };
 
 class ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44 final : public AlgoBase {
@@ -122,7 +122,7 @@ public:
         }
         return m_name.c_str();
     }
-    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
+    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
 };
 
 class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase {
@@ -138,7 +138,7 @@ public:
         }
         return m_name.c_str();
     }
-    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
+    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
 };
 // ================================================================= //
 
@@ -154,6 +154,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase {
@@ -168,6 +171,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase {
@@ -182,6 +188,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase {
@@ -197,6 +206,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase {
@@ -212,6 +224,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase {
@@ -226,6 +241,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
    virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
+    }
 };
 
 } // namespace arm_common


dnn/src/arm_common/conv_bias/int8/algos.h (+35, -3)

@@ -29,6 +29,10 @@ public:
             const NCBKernSizeParam& param) const override;
 
     bool is_preferred(const NCBKernSizeParam& param) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase {
@@ -42,6 +46,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase {
@@ -55,6 +62,9 @@ public:
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
     bool is_preferred(const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase {
@@ -68,6 +78,9 @@ public:
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
     bool is_preferred(const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44 final : public AlgoBase {
@@ -79,6 +92,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44 final : public AlgoBase {
@@ -90,6 +106,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
+    }
 };
 
 #if __ARM_FEATURE_DOTPROD
@@ -104,6 +123,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam&) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase {
@@ -117,6 +139,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam&) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase {
@@ -131,6 +156,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam&) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase {
@@ -148,6 +176,10 @@ public:
             const NCBKernSizeParam& param) const override;
 
     bool is_preferred(const NCBKernSizeParam& param) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
+    }
 };
 #endif
 
@@ -163,7 +195,7 @@ public:
         }
         return m_name.c_str();
     }
-    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
+    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32);
 };
 
 //=======================input int8 compute fp32 output int8============
@@ -180,7 +212,7 @@ public:
         }
         return m_name.c_str();
     }
-    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
+    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32);
 };
 
 //=======================input int8 compute int16 output int8============
@@ -198,7 +230,7 @@ public:
         return m_name.c_str();
     }
 
-    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
+    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32);
 };
 
 } // namespace arm_common


dnn/src/arm_common/conv_bias/int8x8x16/algos.h (+18, -0)

@@ -36,6 +36,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase {
@@ -48,6 +51,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase {
@@ -71,6 +77,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoI8x8x16Stride2Filter2 final : public AlgoBase {
@@ -84,6 +93,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final : public AlgoBase {
@@ -96,6 +108,9 @@ public:
             const NCBKernSizeParam& param) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44 final : public AlgoBase {
@@ -111,6 +126,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT};
+    }
 };
 
 } // namespace arm_common


dnn/src/arm_common/conv_bias/opr_impl.cpp (+92, -15)

@@ -10,6 +10,7 @@
  * implied.
  */
 
+#include "megdnn/opr_param_defs.h"
 #include "src/arm_common/conv_bias/int8/algos.h"
 #include "src/arm_common/conv_bias/int8x8x16/algos.h"
 #include "src/arm_common/conv_bias/quint8/algos.h"
@@ -122,9 +123,11 @@ public:
 
     static CpuOprDelegationStorage<2> storage;
     auto matmul_opr = storage.get<MatrixMul, 0>();
+    using MatmulFormat = param::MatrixMul::Format;
     auto&& matmul_algos =
             static_cast<arm_common::MatrixMulImpl*>(matmul_opr)
-                    ->algo_pack();
+                    ->select_algo_type(
+                            {AlgoDataType::FLOAT32, MatmulFormat::MK4});
     for (auto&& algo : matmul_algos) {
         if (algo->type() == nullptr)
             continue;
@@ -133,38 +136,62 @@ public:
                     static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                     tile_size));
             winograd_algos.emplace_back(refhold.back().get());
-            refhold.emplace_back(new AlgoFP32WinogradF63(
+            refhold.emplace_back(new AlgoFP32WinogradF63_4x4(
                    static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                     tile_size));
             winograd_algos.emplace_back(refhold.back().get());
-            refhold.emplace_back(new AlgoFP32WinogradF63_4x4(
+            refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44(
                     static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                     tile_size));
             winograd_algos.emplace_back(refhold.back().get());
-            refhold.emplace_back(new AlgoFP32WinogradF54(
+            refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44(
                    static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                     tile_size));
             winograd_algos.emplace_back(refhold.back().get());
-            refhold.emplace_back(new AlgoFP32WinogradF45(
+            //! uncomment this when low precision mode is done
+#if 0
+            refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44(
                     static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                     tile_size));
             winograd_algos.emplace_back(refhold.back().get());
-            refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44(
+#endif
+            //! Qint8x8x32 winograd compute with fp32
+            refhold.emplace_back(new AlgoS8CF32WinogradF23_4x4_NCHW44(
                     static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                     tile_size));
             winograd_algos.emplace_back(refhold.back().get());
-            refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44(
+        }
+    }
+    matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr)
+                           ->select_algo_type({AlgoDataType::FLOAT32,
+                                               MatmulFormat::DEFAULT});
+    for (auto&& algo : matmul_algos) {
+        if (algo->type() == nullptr)
+            continue;
+        for (uint32_t tile_size : {16, 8, 24, 32}) {
+            refhold.emplace_back(new AlgoFP32WinogradF63(
                     static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                     tile_size));
             winograd_algos.emplace_back(refhold.back().get());
-            //! uncomment this when low precision mode is done
-#if 0
-            refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44(
+            refhold.emplace_back(new AlgoFP32WinogradF54(
                     static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                     tile_size));
             winograd_algos.emplace_back(refhold.back().get());
-#endif
+            refhold.emplace_back(new AlgoFP32WinogradF45(
+                    static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
+                    tile_size));
+            winograd_algos.emplace_back(refhold.back().get());
+        }
+    }
 
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+    matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr)
+                           ->select_algo_type({AlgoDataType::FLOAT16,
+                                               MatmulFormat::DEFAULT});
+    for (auto&& algo : matmul_algos) {
+        if (algo->type() == nullptr)
+            continue;
+        for (uint32_t tile_size : {16, 8, 24, 32}) {
             refhold.emplace_back(new AlgoFP16WinogradF23(
                     static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                     tile_size));
@@ -177,19 +204,33 @@ public:
                     static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                     tile_size));
             winograd_algos.emplace_back(refhold.back().get());
+        }
+    }
+    matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr)
+                           ->select_algo_type({AlgoDataType::FLOAT16,
+                                               MatmulFormat::MK8});
+    for (auto&& algo : matmul_algos) {
+        if (algo->type() == nullptr)
+            continue;
+        for (uint32_t tile_size : {16, 8, 24, 32}) {
             refhold.emplace_back(new AlgoFP16WinogradF23_8x8(
                     static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                     tile_size));
             winograd_algos.emplace_back(refhold.back().get());
+        }
+    }
 #endif
+    matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr)
+                           ->select_algo_type({AlgoDataType::INT16X16X32,
+                                               MatmulFormat::MK8});
+    for (auto&& algo : matmul_algos) {
+        if (algo->type() == nullptr)
+            continue;
+        for (uint32_t tile_size : {16, 8, 24, 32}) {
             refhold.emplace_back(new AlgoS8WinogradF23_8x8(
                     static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                     tile_size));
             winograd_algos.emplace_back(refhold.back().get());
-            refhold.emplace_back(new AlgoS8CF32WinogradF23_4x4_NCHW44(
-                    static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
-                    tile_size));
-            winograd_algos.emplace_back(refhold.back().get());
             refhold.emplace_back(new AlgoS8WinogradF23_8x8_NCHW44(
                     static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo),
                     tile_size));
@@ -240,6 +281,42 @@ bool ConvBiasImpl::is_matmul_quantized_prefer(
     return conv_direct_unusable;
 }
 
+SmallVector<AlgoCategory> ConvBiasImpl::suggest_algo_category_order(
+        const NCBKernSizeParam& param) const {
+    auto IC = param.filter_meta.icpg;
+    auto OC = param.filter_meta.ocpg;
+    auto FH = param.filter_meta.spatial[0];
+    auto FW = param.filter_meta.spatial[1];
+    //! TODO: now winograd only support fast-run
+    if (param.filter_meta.format == param::ConvBias::Format::NCHW_WINOGRAD ||
+        param.filter_meta.format == param::ConvBias::Format::NCHW44_WINOGRAD ||
+        param.filter_meta.format == param::ConvBias::Format::NCHW88_WINOGRAD) {
+        return {AlgoCategory::WINOGRAD};
+    }
+    //! im2col
+    bool im2col_prefer = (IC >= 32 || OC >= 32);
+    //! quantized algo use matmul when direct algo is unusable
+    if (param.src_type.category() == DTypeCategory::QUANTIZED) {
+        im2col_prefer = is_matmul_quantized_prefer(param);
+    }
+    //! conv1x1
+    im2col_prefer |= (FH == 1 && FW == 1);
+    //! nchw44 and nchw44-dot hybrid mode is direct
+    if (param.filter_meta.format == param::ConvBias::Format::NCHW44 ||
+        param.filter_meta.format == param::ConvBias::Format::NCHW44_DOT) {
+        if (IC < 4) {
+            im2col_prefer = false;
+        }
+    }
+    if (im2col_prefer) {
+        return {AlgoCategory::IM2COL, AlgoCategory::DIRECT,
+                AlgoCategory::NAIVE};
+    } else {
+        return {AlgoCategory::DIRECT, AlgoCategory::IM2COL,
+                AlgoCategory::NAIVE};
+    }
+}
+
 const char* ConvBiasImpl::get_algorithm_set_name() const {
     // arm common version 0
     return "AC0";


dnn/src/arm_common/conv_bias/opr_impl.h (+4, -1)

@@ -28,6 +28,9 @@ public:
 
     bool is_matmul_quantized_prefer(
             const ConvBiasImpl::NCBKernSizeParam& ncb_param) const override;
+
+    SmallVector<AlgoCategory> suggest_algo_category_order(
+            const NCBKernSizeParam& param) const override;
     class AlgoPack;
 
 protected:
@@ -90,7 +93,7 @@ private:
     class AlgoF16Direct;
     class AlgoF16DirectStride1;
 #endif
-};
+};
 
 } // namespace arm_common
 } // namespace megdnn


dnn/src/arm_common/conv_bias/quint8/algos.h (+12, -0)

@@ -29,6 +29,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase {
@@ -42,6 +45,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT};
+    }
 };
 #if __ARM_FEATURE_DOTPROD
 class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase {
@@ -56,6 +62,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase {
@@ -69,6 +78,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector<NCBKern> dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT};
+    }
 };
 #endif
 } // namespace arm_common


dnn/src/arm_common/matrix_mul/algos.h (+14, -8)

@@ -26,7 +26,7 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     void* type() const override { return sm_arm_common_algo_type; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::INT8X8X16, DEFAULT)
 };
 
 class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase {
@@ -40,7 +40,7 @@ public:
     void* type() const override { return sm_arm_common_algo_type; }
     AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT)
 };
 
 class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase {
@@ -54,7 +54,7 @@ public:
     void* type() const override { return sm_arm_common_algo_type; }
     AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4)
 };
 
 #if __ARM_FEATURE_DOTPROD
@@ -69,7 +69,7 @@ public:
     void* type() const override { return sm_arm_common_algo_type; }
     AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4_DOT)
 };
 #endif
 
@@ -87,7 +87,7 @@ public:
     void* type() const override { return sm_arm_common_algo_type; }
     AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT)
 };
 
 class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase {
@@ -101,7 +101,7 @@ public:
     void* type() const override { return sm_arm_common_algo_type; }
     AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4)
+    MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4)
 };
 
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -116,7 +116,7 @@ public:
     void* type() const override { return sm_arm_common_algo_type; }
     AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::FLOAT16, DEFAULT)
 };
 #endif
 
@@ -131,7 +131,13 @@ public:
     void* type() const override { return sm_arm_common_algo_type; }
     AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(1, 1, 1, 4)
+    MEGDNN_OVERRIDE_MATMUL_DESC(
+            1, 1, 1, 4,
+            static_cast<AlgoDataType>(
+                    static_cast<uint32_t>(AlgoDataType::FLOAT16) |
+                    static_cast<uint32_t>(AlgoDataType::FLOAT32) |
+                    static_cast<uint32_t>(AlgoDataType::QINT8X8X32)),
+            DEFAULT)
 };
 
 } // namespace arm_common
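
AlgoGevm serves fp16, fp32 and qint8 operands, so its descriptor must carry
all three AlgoDataType flags, and because a scoped enum defines no bitwise
operators the static_cast chain above is required. The 28 lines this commit
adds to dnn/src/common/utils.h plausibly provide helpers of the following
kind, though they are not visible on this page:

    // Hypothetical helpers, shown for orientation only.
    constexpr AlgoDataType operator|(AlgoDataType lhs, AlgoDataType rhs) {
        return static_cast<AlgoDataType>(static_cast<uint32_t>(lhs) |
                                         static_cast<uint32_t>(rhs));
    }
    constexpr uint32_t operator&(AlgoDataType lhs, AlgoDataType rhs) {
        return static_cast<uint32_t>(lhs) & static_cast<uint32_t>(rhs);
    }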


dnn/src/arm_common/matrix_mul/opr_impl.cpp (+4, -3)

@@ -25,7 +25,7 @@ void* const MatrixMulImpl::sm_arm_common_algo_type =
 class MatrixMulImpl::AlgoPack : NonCopyableObj {
     AlgoInt8x8x16 int8x8x16;
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
-    AlgoF16Gemv f16gemv;
+    AlgoF16Gemv f16gemv;
 #endif
     AlgoInt8x8x32Gemv int8x8x32_gemv;
     AlgoInt8x8x32GemvMK4 int8x8x32_gemv_mk4;
@@ -34,10 +34,11 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj {
 #endif
     AlgoGevm gevm;
     AlgoF32GemvMK4 f32_gemv_mk4;
+
 public:
     AlgoPack() {
         all_algos.emplace_back(&int8x8x16);
-#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
         all_algos.emplace_back(&f16gemv);
 #endif
 #if __ARM_FEATURE_DOTPROD
@@ -47,7 +48,7 @@ public:
         all_algos.emplace_back(&int8x8x32_gemv_mk4);
         all_algos.emplace_back(&f32_gemv_mk4);
         all_algos.emplace_back(&gevm);
-    }
+    }
     SmallVector<AlgoBase*> all_algos;
 };
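
select_algo_type() is the new entry point used by the winograd algo pack in
arm_common/conv_bias/opr_impl.cpp above: instead of walking the whole
algo_pack(), callers request only the matmul algos that match a given
{AlgoDataType, Format} pair. The implementation lands in
dnn/src/fallback/matrix_mul/opr_impl.cpp (+59 lines, not shown on this
page); a sketch under assumed names:

    // Hypothetical sketch; descriptor accessor and field names are assumed.
    SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::select_algo_type(
            AlgoTypePack index) {
        SmallVector<AlgoBase*> result;
        for (auto* algo : algo_pack()) {
            auto type = algo->matmul_description().algo_type;
            bool dtype_match = (static_cast<uint32_t>(type.data_type) &
                                static_cast<uint32_t>(index.data_type)) != 0;
            if (dtype_match && type.format == index.format)
                result.push_back(algo);  // algo serves this type pack
        }
        return result;
    }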




dnn/src/armv7/conv_bias/int8/algos.h (+3, -0)

@@ -37,6 +37,9 @@ public:
         size_t group = param.filter_meta.group;
         return {{kimpl, {group, 1_z, 1_z}}};
     }
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL};
+    }
 };
 
 } // namespace armv7


dnn/src/armv7/conv_bias/quint8/algos.h (+4, -0)

@@ -38,6 +38,10 @@ public:
         size_t group = param.filter_meta.group;
         return {{kimpl, {group, 1_z, 1_z}}};
     }
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL};
+    }
 };
 
 } // namespace armv7


+ 23 - 14  dnn/src/armv7/matrix_mul/algos.cpp

@@ -85,7 +85,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32::get_kern(
 MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32, megdnn_armv7_matmul_kern,
                                      "AlgoF32Impl"_hash,
-                                     armv7::matmul::sgemm_4x12, float, float);
+                                     armv7::matmul::sgemm_4x12, float, float,
+                                     AlgoDataType::FLOAT32, DEFAULT);

 /* ===================== F32 algo mk4 K4x12 ===================== */
@@ -154,7 +155,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4Pack4x12,
                                      megdnn_armv7_matmul_kern,
                                      "AlgoF32MK4Pack4x12"_hash,
                                      armv7::matmul::sgemm_mk4_pack_4x12, float,
-                                     float);
+                                     float, AlgoDataType::FLOAT32, MK4);

 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
 /* ===================== F16 K4x16x1 algo ===================== */
@@ -215,7 +216,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K4x16x1::get_kern(
 MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF16K4x16x1, megdnn_armv7_matmul_kern,
                                      "AlgoF16K4x16x1"_hash,
                                      armv7::matmul::hgemm_4x16, dt_float16,
-                                     dt_float16);
+                                     dt_float16, AlgoDataType::FLOAT16,
+                                     DEFAULT);

 #endif
@@ -280,7 +282,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x2x16,
                                      megdnn_armv7_matmul_kern,
                                      "AlgoInt8x8x32K4x2x16"_hash,
                                      armv7::matmul::gemm_s8_4x2, int8_t,
-                                     int32_t);
+                                     int32_t, AlgoDataType::QINT8X8X32,
+                                     DEFAULT);
 /* ===================== Int8x8x32 Kernel 4x8x8 algo ===================== */

 namespace {
@@ -342,7 +345,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x8x8,
                                      megdnn_armv7_matmul_kern,
                                      "AlgoInt8x8x32K4x8x8"_hash,
                                      armv7::matmul::gemm_s8_4x8, int8_t,
-                                     int32_t);
+                                     int32_t, AlgoDataType::QINT8X8X32,
+                                     DEFAULT);
 /* ===================== Quint8 Kernel 4x8x8 algo ===================== */

 namespace {
@@ -402,7 +406,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K4x8x8::get_kern(
 MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K4x8x8, megdnn_armv7_matmul_kern,
                                      "AlgoQuint8K4x8x8"_hash,
                                      armv7::matmul::gemm_u8_4x8, uint8_t,
-                                     int32_t);
+                                     int32_t, AlgoDataType::QUINT8X8X32,
+                                     DEFAULT);
 /* ===================== Int8x8x16 Kernel 2x4x16 algo ===================== */

 namespace {
@@ -468,7 +473,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x2x16,
                                      megdnn_armv7_matmul_kern,
                                      "AlgoInt8x8x16K4x2x16"_hash,
                                      armv7::matmul::gemm_s8x8x16_4x2, int8_t,
-                                     int16_t);
+                                     int16_t, AlgoDataType::INT8X8X16, DEFAULT);
 /* ===================== Int8x8x16 Kernel 4x8x8 algo ===================== */

 namespace {
@@ -534,7 +539,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x8x8,
                                      megdnn_armv7_matmul_kern,
                                      "AlgoInt8x8x16K4x8x8"_hash,
                                      armv7::matmul::gemm_s8x8x16_4x8, int8_t,
-                                     int16_t);
+                                     int16_t, AlgoDataType::INT8X8X16, DEFAULT);

 /* =================== Int8x8x16 Kernel MK4 8x8x4 algo ===================*/
@@ -602,7 +607,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x16MK4_8x8x4,
                                             megdnn_armv7_matmul_kern,
                                             "AlgoInt8x8x16MK4_8x8x4"_hash,
                                             armv7::matmul::gemm_s8x8x16_mk4_8x8,
-                                            int8_t, int16_t, int16_t);
+                                            int8_t, int16_t, int16_t,
+                                            AlgoDataType::INT8X8X16, MK4);

 /* ===================== Int16x16x32 Kernel 12x4x1 algo ===================== */
@@ -668,7 +674,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x4x1,
                                      megdnn_armv7_matmul_kern,
                                      "AlgoInt16x16x32K12x4x1"_hash,
                                      armv7::matmul::gemm_s16x16x32_12x4,
-                                     int16_t, int32_t);
+                                     int16_t, int32_t,
+                                     AlgoDataType::INT16X16X32, DEFAULT);
 #if __ARM_FEATURE_DOTPROD
 /* ===================== Int8 K6x8x4 algo ===================== */
 namespace {
@@ -724,7 +731,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K6x8x4,
                                      megdnn_armv7_matmul_kern,
                                      "AlgoInt8x8x32K6x8x4"_hash,
                                      armv7::matmul::gemm_dots8_6x8, int8_t,
-                                     int32_t);
+                                     int32_t, AlgoDataType::QINT8X8X32,
+                                     DEFAULT);
 /* ===================== Quint8 K4x8x4 algo ===================== */
 namespace {
 void quint8_dot_k4x8x4_kern(const MatrixMulImpl::KernParam& kern_param) {
@@ -786,7 +794,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8DotK4x8x4,
                                      megdnn_armv7_matmul_kern,
                                      "AlgoQuint8DotK4x8x4"_hash,
                                      armv7::matmul::gemm_dot_quint8_4x8,
-                                     uint8_t, int32_t);
+                                     uint8_t, int32_t,
+                                     AlgoDataType::QUINT8X8X32, DEFAULT);

 /* ======================== Int8 MK4 8x4x4 dot algo ======================== */
 namespace {
@@ -854,7 +863,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x4x4DotProd,
                                      megdnn_armv7_matmul_kern,
                                      "AlgoInt8x8x32MK4_8x4x4DotProd"_hash,
                                      armv7::matmul::gemm_mk4_dots8_8x4, int8_t,
-                                     int32_t);
+                                     int32_t, AlgoDataType::QINT8X8X32, MK4_DOT);
 #endif

 /* ===================== F32 algo K4x8 ===================== */
@@ -1099,6 +1108,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_4x2x16,
                                      megdnn_armv7_matmul_kern,
                                      "AlgoInt8x8x32MK4_4x2x16"_hash,
                                      armv7::matmul::gemm_mk4_s8_4x2, int8_t,
-                                     int32_t);
+                                     int32_t, AlgoDataType::QINT8X8X32, MK4);

 // vim: syntax=cpp.doxygen

+ 3 - 3  dnn/src/armv7/matrix_mul/algos.h

@@ -50,7 +50,7 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     void* type() const override { return sm_arm_common_algo_type; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4)
+    MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4)
 };

 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -73,7 +73,7 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     void* type() const override { return sm_arm_common_algo_type; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2)
+    MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::FLOAT16, MK8)
 };
 #endif
 #if __ARM_FEATURE_DOTPROD
@@ -205,7 +205,7 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     void* type() const override { return sm_arm_common_algo_type; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2)
+    MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::INT16X16X32, MK8)
 };

 class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase {


+ 0 - 1  dnn/src/armv7/matrix_mul/opr_impl.h

@@ -18,7 +18,6 @@ namespace armv7 {
 class MatrixMulImpl : public arm_common::MatrixMulImpl {
 public:
     using arm_common::MatrixMulImpl::MatrixMulImpl;
-
     SmallVector<AlgoBase*> algo_pack() override;

 private:


+ 28 - 0  dnn/src/common/utils.h

@@ -110,6 +110,11 @@ void __log__(LogLevel level, const char* file, const char* func, int line,
     } while (0)
 #endif  // megdnn_ENABLE_LOGGING

+template <typename T>
+constexpr int32_t cast_int(T data) {
+    return static_cast<int32_t>(data);
+}
+
 /* helper functions */
 /**
  * \brief Get the next `stride' index lexicographically.
@@ -187,6 +192,29 @@ std::unique_ptr<T> make_unique(Args&&... args) {
     return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
 }

+/*!
+ * \brief check whether the source enum contains the target data type enum
+ */
+inline bool contain_data_type(detail::AlgoDataType source,
+                              detail::AlgoDataType target) {
+    return static_cast<bool>(static_cast<uint32_t>(source) &
+                             static_cast<uint32_t>(target));
+}
+
+/*!
+ * \brief count the number of data types contained in the source enum
+ */
+template <typename T>
+size_t nr_type_contain(T index) {
+    uint32_t sr_index = static_cast<uint32_t>(index);
+    size_t nr_type = 0;
+    while (sr_index != 0) {
+        nr_type++;
+        sr_index &= (sr_index - 1);
+    }
+    return nr_type;
+}
+
 /**
  * \brief Aligned workspace bundle.
  *
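The two helpers above treat AlgoDataType as a bit set: contain_data_type is a plain mask test, and nr_type_contain counts set bits by repeatedly clearing the lowest one (Kernighan's trick). A self-contained sketch of that behaviour, with the enum values re-declared locally so the snippet compiles on its own:

#include <cassert>
#include <cstddef>
#include <cstdint>

// Local stand-ins for megdnn::detail::AlgoDataType and the two helpers.
enum class AlgoDataType : uint32_t {
    FLOAT32 = 1 << 0,
    FLOAT16 = 1 << 1,
    QINT8X8X32 = 1 << 2,
};

inline bool contain_data_type(AlgoDataType source, AlgoDataType target) {
    return static_cast<bool>(static_cast<uint32_t>(source) &
                             static_cast<uint32_t>(target));
}

template <typename T>
size_t nr_type_contain(T index) {
    uint32_t sr_index = static_cast<uint32_t>(index);
    size_t nr_type = 0;
    while (sr_index != 0) {
        nr_type++;
        sr_index &= (sr_index - 1);  // clear the lowest set bit
    }
    return nr_type;
}

int main() {
    // an algo that serves both float flavours at once
    auto both_floats = static_cast<AlgoDataType>(
            static_cast<uint32_t>(AlgoDataType::FLOAT32) |
            static_cast<uint32_t>(AlgoDataType::FLOAT16));
    assert(contain_data_type(both_floats, AlgoDataType::FLOAT16));
    assert(!contain_data_type(both_floats, AlgoDataType::QINT8X8X32));
    assert(nr_type_contain(both_floats) == 2);
    assert(nr_type_contain(AlgoDataType::FLOAT32) == 1);
    return 0;
}

This is why select_algo_type (below) can demand exactly one bit in the requested type while algos advertise several.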


+ 26 - 0  dnn/src/fallback/conv_bias/algos.h

@@ -26,6 +26,16 @@ public:
             AlgoSelectionStrategy algo_selection_strategy) const override;
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        auto support_data_type = static_cast<AlgoDataType>(
+                static_cast<uint32_t>(AlgoDataType::FLOAT16) |
+                static_cast<uint32_t>(AlgoDataType::FLOAT32) |
+                static_cast<uint32_t>(AlgoDataType::INT8X8X16) |
+                static_cast<uint32_t>(AlgoDataType::QINT8X8X32) |
+                static_cast<uint32_t>(AlgoDataType::QUINT8X8X32));
+        return {support_data_type, AlgoCategory::NAIVE};
+    }
 };

 class ConvBiasImpl::AlgoWinogradF32 final : public AlgoBase {
@@ -46,6 +56,10 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override;

+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD};
+    }
+
 private:
     MatrixMulImpl::AlgoBase* m_matmul_algo;
     mutable std::string m_name;
@@ -70,6 +84,10 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override;

+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD};
+    }
+
 private:
     MatrixMulImpl::AlgoBase* m_matmul_algo;
     mutable std::string m_name;
@@ -94,6 +112,10 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override;

+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD};
+    }
+
 private:
     MatrixMulImpl::AlgoBase* m_matmul_algo;
     mutable std::string m_name;
@@ -118,6 +140,10 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override;

+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD};
+    }
+
 private:
     MatrixMulImpl::AlgoBase* m_matmul_algo;
     mutable std::string m_name;


+ 4 - 1  dnn/src/fallback/conv_bias/common.h

@@ -140,7 +140,7 @@ using BiasMode = ConvBiasForward::BiasMode;
         break;                                                                \
     }

-#define MEGDNN_WINOGRAD_ALGO_FUN_DECLARE()                                    \
+#define MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(_algo_data_type)                     \
     bool is_reproducible() const override { return true; }                    \
     bool usable(const NCBKernSizeParam& param,                                \
                 AlgoSelectionStrategy algo_selection_strategy) const override; \
@@ -153,6 +153,9 @@ using BiasMode = ConvBiasForward::BiasMode;
             const override;                                                   \
     virtual SmallVector<NCBKern> dispatch_preprocess_kerns(                   \
             const NCBKernSizeParam& param) const override;                    \
+    ConvAlgoTypePack get_algo_type() const override {                         \
+        return {_algo_data_type, AlgoCategory::WINOGRAD};                     \
+    }                                                                         \
                                                                               \
 private:                                                                      \
     fallback::MatrixMulImpl::AlgoBase* m_matmul_algo;                         \


+ 2 - 1  dnn/src/fallback/conv_bias/conv1x1/algos.cpp

@@ -288,7 +288,8 @@ bool ConvBiasImpl::AlgoConv1x1::is_preferred(
     size_t OH = param.osz[0];
     size_t OW = param.osz[1];
     if (OH * OW != 1) {
-        return true;
+        return m_matmul_algo->algoset() !=
+               MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV;
     } else {
 #if (MEGDNN_ARMV7 || MEGDNN_AARCH64)
         if (param.src_type.enumv() == DTypeEnum::Int8 &&


+ 5 - 0  dnn/src/fallback/conv_bias/conv1x1/algos.h

@@ -56,6 +56,11 @@ public:
     SmallVector<NCBKern> dispatch_preprocess_kerns(
             const NCBKernSizeParam& param) const override;

+    ConvAlgoTypePack get_algo_type() const override {
+        return {m_matmul_algo->matmul_description().algo_type.data_type,
+                AlgoCategory::IM2COL};
+    }
+
 protected:
     size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const;




+ 10 - 0  dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h

@@ -34,6 +34,16 @@ public:

     bool is_preferred(const NCBKernSizeParam&) const override;

+    ConvAlgoTypePack get_algo_type() const override {
+        auto support_data_type = static_cast<AlgoDataType>(
+                static_cast<uint32_t>(AlgoDataType::FLOAT16) |
+                static_cast<uint32_t>(AlgoDataType::FLOAT32) |
+                static_cast<uint32_t>(AlgoDataType::INT8X8X16) |
+                static_cast<uint32_t>(AlgoDataType::QINT8X8X32) |
+                static_cast<uint32_t>(AlgoDataType::QUINT8X8X32));
+        return {support_data_type, AlgoCategory::IM2COL};
+    }
+
 protected:
     size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const;
 };


+ 18 - 8  dnn/src/fallback/conv_bias/im2col/algos.h

@@ -48,15 +48,25 @@ public:
     SmallVector<NCBKern> dispatch_preprocess_kerns(
             const NCBKernSizeParam& param) const override;
     bool is_preferred(const NCBKernSizeParam& param) const override {
-        if (param.src_type.category() == DTypeCategory::QUANTIZED) {
-            static CpuOprDelegationStorage<1> storage;
-            auto conv_bias_opr = storage.get<ConvBias, 0>();
-            return static_cast<ConvBiasImpl*>(conv_bias_opr)
-                    ->is_matmul_quantized_prefer(param);
+        size_t OH = param.osz[0];
+        size_t OW = param.osz[1];
+        //! gemm is preferred when OH * OW > 1
+        //! gemv is preferred when OH * OW == 1
+        if ((m_matmul_algo->algoset() !=
+                     MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV &&
+             OH * OW > 1) ||
+            (m_matmul_algo->algoset() ==
+                     MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV &&
+             OH * OW == 1)) {
+            return true;
+        } else {
+            return false;
         }
-        auto&& fm = param.filter_meta;
-        auto OC = fm.ocpg, IC = fm.icpg;
-        return OC >= 32 || IC >= 32;
+    }
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {m_matmul_algo->matmul_description().algo_type.data_type,
+                AlgoCategory::IM2COL};
     }

 private:
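The rewritten is_preferred above reduces to one predicate over the matmul flavour and the output size. A minimal stand-alone sketch of that rule, with is_gemv standing in for the algoset() == ALGO_TYPE_GEMV comparison:

#include <cassert>
#include <cstddef>

// A GEMM-backed im2col algo wants a spatial output (OH * OW > 1); a
// GEMV-backed one wants the single-output-pixel case (OH * OW == 1).
bool im2col_is_preferred(bool is_gemv, size_t OH, size_t OW) {
    return (!is_gemv && OH * OW > 1) || (is_gemv && OH * OW == 1);
}

int main() {
    assert(im2col_is_preferred(false, 7, 7));   // GEMM on a 7x7 output map
    assert(im2col_is_preferred(true, 1, 1));    // GEMV when the output is 1x1
    assert(!im2col_is_preferred(true, 7, 7));   // GEMV loses on spatial output
    assert(!im2col_is_preferred(false, 1, 1));  // GEMM loses the 1x1 case
    return 0;
}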


+ 90 - 14  dnn/src/fallback/conv_bias/opr_impl.cpp

@@ -48,11 +48,26 @@ void incr_ptr(T*& dst, ptrdiff_t delta) {

 }  // namespace

+#if MEGDNN_X86
+#define SKIP_GEMV()
+//! As we haven't a direct conv for int8x8x16 yet, disabling gemv here may make
+//! it fall back to the naive implementation, which can be very slow, so we
+//! just enable im2col for gemv in the x86 backend.
+//! FIXME: remove it when we add direct conv support for int8x8x16
+#else
+#define SKIP_GEMV()                                                            \
+    if (algo->algoset() == MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { \
+        continue;                                                              \
+    }
+#endif
+
 class ConvBiasImpl::AlgoPack : NonCopyableObj {
     AlgoNaive algo_naive;
     SmallVector<std::unique_ptr<AlgoBase>> refhold;

 public:
+
     AlgoPack() {
         refhold.emplace_back(new AlgoConv1x1Gemv());
         all_algos.emplace_back(refhold.back().get());
@@ -110,8 +125,6 @@ public:
         all_algos.emplace_back(refhold.back().get());
 #endif
         }
-        //! reverse matmul algo, when the algo is_prefer can be selected first
-        std::reverse(all_algos.begin(), all_algos.end());
         all_algos.emplace_back(&algo_naive);
     }
     SmallVector<AlgoBase*> all_algos;
@@ -121,6 +134,22 @@ SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() {
     static AlgoPack sl_algo_pack;
     return sl_algo_pack.all_algos;
 }
+
+SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::select_algo_type(
+        ConvAlgoTypePack target_type) {
+    megdnn_assert(nr_type_contain(target_type.data_type),
+                  "ConvBias algo selection only supports one type");
+    SmallVector<ConvBiasImpl::AlgoBase*> algos;
+    for (auto&& algo : algo_pack()) {
+        auto algo_type = algo->get_algo_type();
+        if (contain_data_type(algo_type.data_type, target_type.data_type) &&
+            algo_type.algo_category == target_type.algo_category) {
+            algos.push_back(algo);
+        }
+    }
+    return algos;
+}
+
 bool ConvBiasImpl::is_naive_algo(ConvBiasImpl::Algorithm* algo) {
     return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0;
 }
@@ -248,12 +277,32 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic(
 ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb(
         const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
         bool reproducible) {
-    for (auto i : get_all_algorithms_with_ncb(param)) {
-        if (static_cast<AlgoBase*>(i)->usable_reproducible(
-                    param, AlgoSelectionStrategy::HEURISTIC, reproducible) &&
-            NCB_ALGO_FUNC(get_workspace, i, param) <=
-                    workspace_limit_in_bytes) {
-            return i;
+    auto algo_data_type = param.deduce_algo_data_type();
+    auto suggest_category_order = suggest_algo_category_order(param);
+    for (auto category : suggest_category_order) {
+        auto&& origin_algos = select_algo_type({algo_data_type, category});
+        ConvBiasImpl::Algorithm* heuristic_algo = nullptr;
+        for (auto i : origin_algos) {
+            bool usable_reproducible =
+                    static_cast<AlgoBase*>(i)->usable_reproducible(
+                            param, AlgoSelectionStrategy::HEURISTIC,
+                            reproducible);
+            if (usable_reproducible &&
+                static_cast<AlgoBase*>(i)->get_workspace(param) <=
+                        workspace_limit_in_bytes) {
+                //! remember the first usable algo; it is chosen as the
+                //! target algo if no preferred algo is found
+                if (!heuristic_algo) {
+                    heuristic_algo = i;
+                }
+                //! choose the first preferred algo
+                if (i->is_preferred(param)) {
+                    return i;
+                }
+            }
+        }
+        if (heuristic_algo) {
+            return heuristic_algo;
         }
     }
     return nullptr;
@@ -300,9 +349,8 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param(
             sizeof(ConvolutionImpl::CanonizedFilterMeta),
             "sizeof CanonizedFilterMeta in convolution and conv_bias "
             "should be equal");
-    CanonizedFilterMeta fm = check_layout_fwd(src, filter, dst);
-    ConvolutionImpl::CanonizedFilterMeta conv_fm;
-    conv_fm.copy_from(fm);
+    auto&& fm = check_layout_fwd(src, filter, dst);
+    auto& conv_fm = reinterpret_cast<ConvolutionImpl::CanonizedFilterMeta&>(fm);

     param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT;
     if (param().format == Param::Format::NCHW_WINOGRAD ||
@@ -367,7 +415,7 @@ ConvBiasImpl::NCBKernParam ConvBiasImpl::make_ncb_kern_param(

 void ConvBiasImpl::exec_with_ncb_kern(const NCBKernParam& param,
                                       ConvBiasImpl::Algorithm* algo) {
-    auto ncb_kerns = NCB_ALGO_FUNC(dispatch_kerns, algo, param);
+    auto&& ncb_kerns = NCB_ALGO_FUNC(dispatch_kerns, algo, param);
     for (auto&& kernel : ncb_kerns) {
         auto run = [kernel, param](size_t index, size_t thread_id) {
             CpuNDRange ndrange_id(kernel.global_size, index);
@@ -380,7 +428,7 @@ void ConvBiasImpl::exec_with_ncb_kern(const NCBKernParam& param,

 void ConvBiasImpl::exec_preprocess_with_ncb_kern(
         const NCBKernParam& param, ConvBiasImpl::Algorithm* algo) {
-    auto ncb_kerns = NCB_ALGO_FUNC(dispatch_preprocess_kerns, algo, param);
+    auto&& ncb_kerns = NCB_ALGO_FUNC(dispatch_preprocess_kerns, algo, param);
     for (auto&& kernel : ncb_kerns) {
         auto run = [kernel, param](size_t index, size_t thread_id) {
             CpuNDRange ndrange_id(kernel.global_size, index);
@@ -405,7 +453,6 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb(
             }
         }
     }
-    std::reverse(prefer_algos.begin(), prefer_algos.end());
     //! Prefer algo inserted from begin
     algos.insert(algos.begin(), prefer_algos.begin(), prefer_algos.end());
     return algos;
@@ -425,6 +472,35 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm(
     return m_prev_selected_algo;
 }

+SmallVector<AlgoCategory> ConvBiasImpl::suggest_algo_category_order(
+        const NCBKernSizeParam& param) const {
+    auto IC = param.filter_meta.icpg;
+    auto OC = param.filter_meta.ocpg;
+    auto FH = param.filter_meta.spatial[0];
+    auto FW = param.filter_meta.spatial[1];
+    //! TODO: winograd is only supported in fast-run for now
+    if (param.filter_meta.format == param::ConvBias::Format::NCHW_WINOGRAD ||
+        param.filter_meta.format == param::ConvBias::Format::NCHW44_WINOGRAD ||
+        param.filter_meta.format == param::ConvBias::Format::NCHW88_WINOGRAD) {
+        return {AlgoCategory::WINOGRAD};
+    }
+    //! im2col + matmul
+    bool im2col_prefer = (IC >= 32 || OC >= 32);
+    //! quantized algos use matmul when the direct algo is unusable
+    if (param.src_type.category() == DTypeCategory::QUANTIZED) {
+        im2col_prefer = is_matmul_quantized_prefer(param);
+    }
+    //! conv1x1
+    im2col_prefer |= (FH == 1 && FW == 1);
+    if (im2col_prefer) {
+        return {AlgoCategory::IM2COL, AlgoCategory::DIRECT,
+                AlgoCategory::NAIVE};
+    } else {
+        return {AlgoCategory::DIRECT, AlgoCategory::IM2COL,
+                AlgoCategory::NAIVE};
+    }
+}
+
 const char* ConvBiasImpl::get_algorithm_set_name() const {
     // fallback version 0
     return "F0";


+ 19 - 0  dnn/src/fallback/conv_bias/opr_impl.h

@@ -18,6 +18,8 @@
 #include "src/fallback/matrix_mul/opr_impl.h"
 #include "src/naive/conv_bias/opr_impl.h"

+#include <unordered_map>
+
 namespace megdnn {
 namespace fallback {

@@ -44,6 +46,7 @@ class ConvBiasImpl : public naive::ConvBiasForwardImpl {
 public:
     using naive::ConvBiasForwardImpl::ConvBiasForwardImpl;
     using AlgoSelectionStrategy = detail::AlgoSelectionStrategy;
+    using AlgoDataType = detail::AlgoDataType;

     //! implemented by exec_with_ncb_kern()
     void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
@@ -94,6 +97,8 @@ public:
                                       size_t workspace_limit_in_bytes,
                                       bool reproducible) override;

+
+
     //! size param for kernels with non-contiguous batch
     struct NCBKernSizeParam : ConvolutionImpl::NCBKernSizeParam {
         NCBKernSizeParam() = default;
@@ -244,6 +249,9 @@ public:
             return (!reproducible || is_reproducible()) &&
                    usable(param, algo_selection_strategy);
         }
+
+        //! get the type of the algo
+        virtual ConvAlgoTypePack get_algo_type() const = 0;
     };

     /**
@@ -251,6 +259,17 @@ public:
      */
     virtual SmallVector<AlgoBase*> algo_pack();

+    /**
+     * \brief select algo according to input algo type
+     */
+    SmallVector<AlgoBase*> select_algo_type(ConvAlgoTypePack algo_type);
+
+    /**
+     * \brief suggest algo category order according to the param
+     */
+    virtual SmallVector<AlgoCategory> suggest_algo_category_order(
+            const NCBKernSizeParam& param) const;
+
 protected:
     virtual void exec_with_ncb_kern(const NCBKernParam& param,
                                     ConvBiasImpl::Algorithm* algo);


+ 19 - 2  dnn/src/fallback/convolution/algos.h

@@ -83,6 +83,10 @@ public:

     SmallVector<NCBKern> dispatch_kern(
             const NCBKernSizeParam& /*param*/) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::NAIVE};
+    }
 };

 class ConvolutionImpl::AlgoNaive final : public AlgoBase {
@@ -96,11 +100,17 @@ public:

     SmallVector<NCBKern> dispatch_kern(
             const NCBKernSizeParam& /*param*/) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        auto support_data_type = static_cast<AlgoDataType>(
+                static_cast<uint32_t>(AlgoDataType::INT8X8X16) |
+                static_cast<uint32_t>(AlgoDataType::QINT8X8X32) |
+                static_cast<uint32_t>(AlgoDataType::QUINT8X8X32));
+        return {support_data_type, AlgoCategory::NAIVE};
+    }
 };

 class ConvolutionImpl::AlgoDefault final : public AlgoBase {
-    static ConvBiasImpl::NCBKernSizeParam init_conv_bias_param(
-            const NCBKernSizeParam& param);
     WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const;
     static SmallVector<NCBKern> get_kimpl(ConvBiasImpl::AlgoBase* algo,
                                           const NCBKernSizeParam& param);
@@ -136,6 +146,13 @@ public:
     //! select matmul to the highest preference
     bool is_preferred(const NCBKernSizeParam& param) const override;

+    static ConvBiasImpl::NCBKernSizeParam init_conv_bias_param(
+            const NCBKernSizeParam& param);
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return m_algorithm->get_algo_type();
+    }
+
 private:
     std::string m_name;
     ConvBiasImpl::AlgoBase* m_algorithm;


+ 83 - 15  dnn/src/fallback/convolution/opr_impl.cpp

@@ -23,6 +23,7 @@
 #include "midout.h"

 #include <cstring>
+#include <unordered_map>

 MIDOUT_DECL(megdnn_fb_convbwd_float)

@@ -75,6 +76,22 @@ SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::algo_pack() {
     static AlgoPack sl_algo_pack;
     return sl_algo_pack.all_algos;
 }
+
+SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::select_algo_type(
+        ConvAlgoTypePack target_type) {
+    megdnn_assert(nr_type_contain(target_type.data_type),
+                  "Convolution algo selection only supports one type");
+    SmallVector<ConvolutionImpl::AlgoBase*> algos;
+    for (auto&& algo : algo_pack()) {
+        auto algo_type = algo->get_algo_type();
+        if (contain_data_type(algo_type.data_type, target_type.data_type) &&
+            algo_type.algo_category == target_type.algo_category) {
+            algos.push_back(algo);
+        }
+    }
+    return algos;
+}
+
 bool ConvolutionImpl::is_naive_algo(ConvolutionImpl::Algorithm* algo) {
     return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0;
 }
@@ -249,9 +266,9 @@ ConvolutionImpl::NCBKernParam ConvolutionImpl::make_ncb_kern_param(

 void ConvolutionImpl::exec_preprocess_with_ncb_kern(const NCBKernParam& param,
                                                     Algorithm* algo) {
-    auto kerns = NCB_ALGO_FUNC(dispatch_preprocess_kern, algo, param);
-    auto fallback_handle = handle();
-    for (auto kernel : kerns) {
+    auto&& kerns = NCB_ALGO_FUNC(dispatch_preprocess_kern, algo, param);
+    auto&& fallback_handle = handle();
+    for (auto&& kernel : kerns) {
         megdnn_assert(
                 param.filter_meta.format == Param::Format::NCHW ||
                 param.filter_meta.format == Param::Format::NHWC ||
@@ -270,9 +287,9 @@ void ConvolutionImpl::exec_preprocess_with_ncb_kern(const NCBKernParam& param,

 void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param,
                                          Algorithm* algo) {
-    auto kerns = NCB_ALGO_FUNC(dispatch_kern, algo, param);
-    auto fallback_handle = handle();
-    for (auto kernel : kerns) {
+    auto&& kerns = NCB_ALGO_FUNC(dispatch_kern, algo, param);
+    auto&& fallback_handle = handle();
+    for (auto&& kernel : kerns) {
         megdnn_assert(
                 param.filter_meta.format == Param::Format::NCHW ||
                 param.filter_meta.format == Param::Format::NHWC ||
@@ -292,13 +309,32 @@ void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param,
 ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb(
         const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
         bool reproducible) {
-    for (auto i : get_all_algorithms_with_ncb(param)) {
-        bool usable_reproducible =
-                static_cast<AlgoBase*>(i)->usable_reproducible(
-                        param, AlgoSelectionStrategy::HEURISTIC, reproducible);
-        if (usable_reproducible && NCB_ALGO_FUNC(get_workspace, i, param) <=
-                                           workspace_limit_in_bytes) {
-            return i;
+    auto algo_data_type = param.deduce_algo_data_type();
+    auto suggest_category_order = suggest_algo_category_order(param);
+    for (auto category : suggest_category_order) {
+        auto&& origin_algos = select_algo_type({algo_data_type, category});
+        ConvolutionImpl::Algorithm* heuristic_algo = nullptr;
+        for (auto i : origin_algos) {
+            bool usable_reproducible =
+                    static_cast<AlgoBase*>(i)->usable_reproducible(
+                            param, AlgoSelectionStrategy::HEURISTIC,
+                            reproducible);
+            if (usable_reproducible &&
+                static_cast<AlgoBase*>(i)->get_workspace(param) <=
+                        workspace_limit_in_bytes) {
+                //! remember the first usable algo; it is chosen as the
+                //! target algo if no preferred algo is found
+                if (!heuristic_algo) {
+                    heuristic_algo = i;
+                }
+                //! choose the first preferred algo
+                if (i->is_preferred(param)) {
+                    return i;
+                }
+            }
+        }
+        if (heuristic_algo) {
+            return heuristic_algo;
         }
     }
     return nullptr;
@@ -317,8 +353,6 @@ ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) {
             }
         }
     }
-    std::reverse(prefer_algos.begin(), prefer_algos.end());
-    //! Prefer algo inserted from begin
     ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end());
     return ret;
 }
@@ -337,11 +371,45 @@ ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm(
     return m_prev_selected_algo;
 }

+SmallVector<AlgoCategory> ConvolutionImpl::suggest_algo_category_order(
+        const NCBKernSizeParam& param) const {
+    static CpuOprDelegationStorage<1> storage;
+    auto conv_bias_opr = storage.get<ConvBias, 0>();
+    auto conv_bias_param =
+            ConvolutionImpl::AlgoDefault::init_conv_bias_param(param);
+    return static_cast<ConvBiasImpl*>(conv_bias_opr)
+            ->suggest_algo_category_order(conv_bias_param);
+}
+
 const char* ConvolutionImpl::get_algorithm_set_name() const {
     // fallback version 0
     return "F0";
 }

+ConvolutionImpl::AlgoDataType
+ConvolutionImpl::NCBKernSizeParam::deduce_algo_data_type() const {
+    if (src_type.enumv() == DTypeEnum::Float32) {
+        return ConvolutionImpl::AlgoDataType::FLOAT32;
+#if !MEGDNN_DISABLE_FLOAT16
+    } else if (src_type.enumv() == DTypeEnum::Float16) {
+        return ConvolutionImpl::AlgoDataType::FLOAT16;
+#endif
+    } else if (src_type.enumv() == DTypeEnum::Int8 ||
+               src_type.enumv() == DTypeEnum::QuantizedS8) {
+        if (dst_type.enumv() == DTypeEnum::Int16) {
+            return ConvolutionImpl::AlgoDataType::INT8X8X16;
+        } else {
+            return ConvolutionImpl::AlgoDataType::QINT8X8X32;
+        }
+    } else if (src_type.enumv() == DTypeEnum::Quantized8Asymm) {
+        return ConvolutionImpl::AlgoDataType::QUINT8X8X32;
+    } else {
+        megdnn_throw(ssprintf(
+                "megdnn does not support data type of %s * %s -> %s\n",
+                src_type.name(), filter_type.name(), dst_type.name()));
+    }
+}
+
 /* ===================== ConvolutionBackwardData ===================== */

 void* const ConvolutionBackwardDataImpl::sm_fallback_deconv_algo_type =


+ 31 - 0  dnn/src/fallback/convolution/opr_impl.h

@@ -10,11 +10,28 @@
  */
 #pragma once

+#include "megdnn/oprs/base.h"
 #include "src/common/utils.h"
 #include "src/fallback/handle.h"
 #include "src/naive/convolution/opr_impl.h"

 namespace megdnn {

+/**
+ * \brief Convolution algo category
+ */
+enum class AlgoCategory : int32_t {
+    DIRECT = 0,
+    IM2COL = 1,
+    WINOGRAD = 2,
+    NAIVE = 3,
+};
+
+struct ConvAlgoTypePack {
+    detail::AlgoDataType data_type : 32;
+    AlgoCategory algo_category : 32;
+};
+
 namespace fallback {

 /*!
@@ -33,6 +50,7 @@ class ConvolutionImpl : public naive::ConvolutionForwardImpl {
 public:
     using naive::ConvolutionForwardImpl::ConvolutionForwardImpl;
     using AlgoSelectionStrategy = detail::AlgoSelectionStrategy;
+    using AlgoDataType = detail::AlgoDataType;

     //! implemented by exec_with_ncb_kern()
     void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
@@ -86,6 +104,8 @@ public:
         size_t nr_threads;
         //! weight_preprocess info
         const PreprocessedFilter* preprocessed_filter;
+        //! get the data type category of the param for selecting the algo
+        AlgoDataType deduce_algo_data_type() const;
     };

     //! memory param for kernels with non-contiguous batch
@@ -211,6 +231,9 @@ public:
             return (!reproducible || is_reproducible()) &&
                    usable(param, algo_selection_strategy);
         }
+
+        //! get the type of the algo
+        virtual ConvAlgoTypePack get_algo_type() const = 0;
     };

     /**
@@ -218,6 +241,11 @@ public:
      */
     virtual SmallVector<AlgoBase*> algo_pack();

+    /**
+     * \brief select algo according to input algo type
+     */
+    SmallVector<AlgoBase*> select_algo_type(ConvAlgoTypePack algo_type);
+
 protected:
     virtual void exec_with_ncb_kern(const NCBKernParam& param, Algorithm* algo);

@@ -258,6 +286,9 @@ private:
                              _megdnn_tensor_out dst,
                              const PreprocessedFilter* preprocessed_filter,
                              _megdnn_workspace workspace);
+
+    SmallVector<AlgoCategory> suggest_algo_category_order(
+            const NCBKernSizeParam& param) const;
 };

 class ConvolutionBackwardDataImpl : public naive::ConvolutionBackwardDataImpl {
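ConvAlgoTypePack packs the 32-bit data-type mask next to the 32-bit category, and select_algo_type keeps an algo when its mask contains the requested type and its category matches exactly. A minimal sketch of that matching rule (all names here are local stand-ins for the real types):

#include <cassert>
#include <cstdint>

enum class AlgoDataType : uint32_t { FLOAT32 = 1 << 0, FLOAT16 = 1 << 1 };
enum class AlgoCategory : int32_t { DIRECT, IM2COL, WINOGRAD, NAIVE };

struct ConvAlgoTypePack {
    AlgoDataType data_type : 32;
    AlgoCategory algo_category : 32;
};

// the filter used by select_algo_type: type containment plus exact category
bool matches(const ConvAlgoTypePack& algo, const ConvAlgoTypePack& target) {
    bool type_ok = (static_cast<uint32_t>(algo.data_type) &
                    static_cast<uint32_t>(target.data_type)) != 0;
    return type_ok && algo.algo_category == target.algo_category;
}

int main() {
    // an algo advertising both float flavours in the im2col category
    ConvAlgoTypePack algo{static_cast<AlgoDataType>(0b11),
                          AlgoCategory::IM2COL};
    assert(matches(algo, {AlgoDataType::FLOAT16, AlgoCategory::IM2COL}));
    assert(!matches(algo, {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}));
    return 0;
}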


+ 1 - 1  dnn/src/fallback/matrix_mul/algos.cpp

@@ -76,7 +76,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern(
 MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_fb_matmul_f32_kern,
                                      5, matmul::fallback::sgemm_8x12, float,
-                                     float);
+                                     float, AlgoDataType::FLOAT32, DEFAULT);

 /* ===================== gemv algo ===================== */
 bool MatrixMulImpl::AlgoGemv::usable(


+ 9 - 1  dnn/src/fallback/matrix_mul/algos.h

@@ -37,7 +37,15 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
+    MEGDNN_OVERRIDE_MATMUL_DESC(
+            8, 16, 1, 4,
+            static_cast<AlgoDataType>(
+                    static_cast<uint32_t>(AlgoDataType::FLOAT16) |
+                    static_cast<uint32_t>(AlgoDataType::FLOAT32) |
+                    static_cast<uint32_t>(AlgoDataType::INT8X8X16) |
+                    static_cast<uint32_t>(AlgoDataType::QINT8X8X32) |
+                    static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)),
+            DEFAULT)
 };

 }  // namespace fallback


+ 17 - 13  dnn/src/fallback/matrix_mul/gemm_common.h

@@ -352,13 +352,15 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
                DType dtype_c)                                                 \
             : A_dtype(dtype_a), B_dtype(dtype_b), C_dtype(dtype_c) {}

-#define MEGDNN_OVERRIDE_MATMUL_DESC(_m, _n, _k, _packa_type_size)             \
-    MatmulDescription matmul_description() const override {                   \
-        MatmulDescription mdesc;                                              \
-        mdesc.packmode = packmode();                                          \
-        mdesc.innerblocksize = {_m, _n, _k};                                  \
-        mdesc.packa_type_size = _packa_type_size;                             \
-        return mdesc;                                                         \
+#define MEGDNN_OVERRIDE_MATMUL_DESC(_m, _n, _k, _packa_type_size, _data_type, \
+                                    _format)                                  \
+    MatmulDescription matmul_description() const override {                   \
+        MatmulDescription mdesc;                                              \
+        mdesc.packmode = packmode();                                          \
+        mdesc.innerblocksize = {_m, _n, _k};                                  \
+        mdesc.packa_type_size = _packa_type_size;                             \
+        mdesc.algo_type = {_data_type, Param::Format::_format};               \
+        return mdesc;                                                         \
     }

 #define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL()                                     \
@@ -373,7 +375,7 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,

 #define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(                          \
         _algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type,    \
-        _packa_type)                                                          \
+        _packa_type, _support_data_type, _format)                             \
                                                                               \
     MatrixMulImpl::kern_naked_t MatrixMulImpl::_algo_name::get_kern_naked(    \
             const KernSizeParam&) const {                                     \
@@ -474,14 +476,16 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
         mdesc.innerblocksize = {_strategy::KERNEL_H, _strategy::KERNEL_W,     \
                                 _strategy::UNROLL_K};                         \
         mdesc.packa_type_size = sizeof(_packa_type);                          \
+        mdesc.algo_type = {_support_data_type, Param::Format::_format};       \
         return mdesc;                                                         \
     }

-#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(                                 \
-        _algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type)    \
-    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(_algo_name, _midout_name,     \
-                                                _mid_index, _strategy,        \
-                                                _i_type, _c_type, _i_type)
+#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(                                 \
+        _algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type,    \
+        _support_data_type, _format)                                          \
+    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(                              \
+            _algo_name, _midout_name, _mid_index, _strategy, _i_type,         \
+            _c_type, _i_type, _support_data_type, _format)
 }  // namespace matmul
 }  // namespace megdnn
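For a concrete picture of the macro change, this is what a call like MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4) from the armv7 header expands to, hand-written against stub types that only mimic the real megdnn definitions so the snippet compiles on its own:

#include <cstddef>
#include <cstdint>

// stub types standing in for the real megdnn definitions
enum class AlgoDataType : uint32_t { FLOAT32 = 1 << 0 };
enum class PackMode { DEFAULT, NO_PACK };
struct Param { enum class Format { DEFAULT, MK4 }; };
struct InnerBlockSize { size_t m, n, k; };
struct AlgoTypePack { AlgoDataType data_type; Param::Format format; };
struct MatmulDescription {
    PackMode packmode;
    InnerBlockSize innerblocksize;
    AlgoTypePack algo_type;
    size_t packa_type_size;
};

struct AlgoBase {
    virtual ~AlgoBase() = default;
    virtual PackMode packmode() const { return PackMode::DEFAULT; }
    virtual MatmulDescription matmul_description() const = 0;
};

struct AlgoF32MK4 : AlgoBase {
    // hand-expanded body of the updated macro
    MatmulDescription matmul_description() const override {
        MatmulDescription mdesc;
        mdesc.packmode = packmode();
        mdesc.innerblocksize = {4, 8, 4};          // _m, _n, _k
        mdesc.packa_type_size = 4;                 // _packa_type_size
        mdesc.algo_type = {AlgoDataType::FLOAT32,  // _data_type
                           Param::Format::MK4};    // Param::Format::_format
        return mdesc;
    }
};

The extra algo_type field is the hook that lets MatrixMulImpl::select_algo_type filter algorithms by data type and format without instantiating them.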




+ 59 - 5  dnn/src/fallback/matrix_mul/opr_impl.cpp

@@ -38,6 +38,22 @@ SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() {
     return s_algo_pack.all_algos;
 }

+SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::select_algo_type(
+        AlgoTypePack index) {
+    megdnn_assert(nr_type_contain(index.data_type),
+                  "Matmul algo selection only supports one type");
+    SmallVector<MatrixMulImpl::AlgoBase*> algos;
+    for (auto&& algo : algo_pack()) {
+        auto algo_desc = algo->matmul_description();
+        if (contain_data_type(algo_desc.algo_type.data_type,
+                              index.data_type) &&
+            algo_desc.algo_type.format == index.format) {
+            algos.push_back(algo);
+        }
+    }
+    return algos;
+}
+
 std::vector<MatrixMul::Algorithm*> MatrixMulImpl::get_all_algorithms(
         const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) {
     std::vector<Algorithm*> gemm_algos, gemv_algos;
@@ -71,17 +87,25 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic(
                 "require reproducible algorithm, but given algorithm is not "
                 "reproducible");
     }

-    auto algos = get_all_algorithms(A, B, C);
+    AlgoTypePack algo_type;
+    algo_type.data_type = kern_size_param.deduce_algo_data_type();
+    algo_type.format = kern_size_param.format;
+    auto algos = select_algo_type(algo_type);
+    Algorithm* heuristic_algo = nullptr;
     for (auto&& algo : algos) {
-        if (static_cast<AlgoBase*>(algo)->preferred_reproducible(
+        if (static_cast<AlgoBase*>(algo)->usable(kern_size_param) &&
+            static_cast<AlgoBase*>(algo)->preferred_reproducible(
                     kern_size_param, reproducible) &&
             static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param) <=
                     workspace_limit_in_bytes) {
-            return algo;
+            if (algo->algoset() == AlgoBase::AlgoSet::ALGO_TYPE_GEMV) {
+                return algo;
+            } else if (!heuristic_algo) {
+                heuristic_algo = algo;
+            }
         }
     }
-    return nullptr;
+    return heuristic_algo;
 }

 MatrixMulImpl::KernSizeParam MatrixMulImpl::make_kern_size_param(
@@ -150,4 +174,34 @@ void MatrixMulImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
     naive::MatrixMulForwardImpl::exec(A, B, C, workspace);
 }

+MatrixMulImpl::AlgoDataType
+MatrixMulImpl::KernSizeParam::deduce_algo_data_type() const {
+    megdnn_assert(A_type.enumv() == B_type.enumv(),
+                  "Matmul A type and B type must have the same ctype\n");
+    if (A_type.enumv() == DTypeEnum::Float32) {
+        return MatrixMulImpl::AlgoDataType::FLOAT32;
+#if !MEGDNN_DISABLE_FLOAT16
+    } else if (A_type.enumv() == DTypeEnum::Float16) {
+        return MatrixMulImpl::AlgoDataType::FLOAT16;
+#endif
+    } else if (A_type.enumv() == DTypeEnum::Int8 ||
+               A_type.enumv() == DTypeEnum::QuantizedS8) {
+        if (C_type.enumv() == DTypeEnum::Int16) {
+            return MatrixMulImpl::AlgoDataType::INT8X8X16;
+        } else {
+            megdnn_assert(C_type.enumv() == DTypeEnum::Int32 ||
+                          C_type.enumv() == DTypeEnum::QuantizedS32);
+            return MatrixMulImpl::AlgoDataType::QINT8X8X32;
+        }
+    } else if (A_type.enumv() == DTypeEnum::Quantized8Asymm) {
+        return MatrixMulImpl::AlgoDataType::QUINT8X8X32;
+    } else if (A_type.enumv() == DTypeEnum::Int16) {
+        return MatrixMulImpl::AlgoDataType::INT16X16X32;
+    } else {
+        megdnn_throw(ssprintf(
+                "megdnn matmul does not support data type of %s * %s -> %s\n",
+                A_type.name(), B_type.name(), C_type.name()));
+    }
+}
+
 // vim: syntax=cpp.doxygen
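Restated as a plain switch, the mapping added above looks like this: a stand-alone sketch in which DTypeEnum is an illustrative stand-in and the MEGDNN_DISABLE_FLOAT16 guard is dropped.

#include <cstdint>
#include <stdexcept>

enum class DTypeEnum {
    Float32, Float16, Int8, QuantizedS8, Int16,
    Quantized8Asymm, Int32, QuantizedS32
};
enum class AlgoDataType : uint32_t {
    FLOAT32 = 1 << 0, FLOAT16 = 1 << 1, QINT8X8X32 = 1 << 2,
    QUINT8X8X32 = 1 << 3, INT8X8X16 = 1 << 4, INT16X16X32 = 1 << 5,
};

AlgoDataType deduce_algo_data_type(DTypeEnum a_type, DTypeEnum c_type) {
    switch (a_type) {
        case DTypeEnum::Float32:
            return AlgoDataType::FLOAT32;
        case DTypeEnum::Float16:
            return AlgoDataType::FLOAT16;
        case DTypeEnum::Int8:
        case DTypeEnum::QuantizedS8:
            // int8 inputs split on the output type: s16 accumulation is its
            // own category, s32 accumulation is treated as qint8x8x32
            return c_type == DTypeEnum::Int16 ? AlgoDataType::INT8X8X16
                                              : AlgoDataType::QINT8X8X32;
        case DTypeEnum::Quantized8Asymm:
            return AlgoDataType::QUINT8X8X32;
        case DTypeEnum::Int16:
            return AlgoDataType::INT16X16X32;
        default:
            throw std::runtime_error("unsupported matmul data type");
    }
}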

+ 18 - 1  dnn/src/fallback/matrix_mul/opr_impl.h

@@ -10,14 +10,23 @@
  * implied.
  */
 #pragma once
+#include "megdnn/opr_param_defs.h"
 #include "src/common/utils.h"
 #include "src/naive/matrix_mul/opr_impl.h"
+#include <unordered_map>

 namespace megdnn {
-namespace fallback {

+struct AlgoTypePack {
+    detail::AlgoDataType data_type : 32;
+    param::MatrixMul::Format format : 32;
+};
+
+namespace fallback {
 class MatrixMulImpl : public naive::MatrixMulForwardImpl {
 public:
     using naive::MatrixMulForwardImpl::MatrixMulForwardImpl;
+    using AlgoDataType = detail::AlgoDataType;

     bool is_thread_safe() const override { return true; }

@@ -34,6 +43,8 @@ public:
         bool trA, trB;
         Param::ComputeMode compute_mode;
         Param::Format format;
+        //! get the data type category of the param for selecting the algo
+        AlgoDataType deduce_algo_data_type() const;
     };

     struct KernParam : public KernSizeParam {
@@ -110,6 +121,7 @@ public:
     struct MatmulDescription {
         PackMode packmode;
         InnerBlockSize innerblocksize;
+        AlgoTypePack algo_type;
         size_t packa_type_size;
     };

@@ -146,6 +158,11 @@ public:
      */
     virtual SmallVector<AlgoBase*> algo_pack();

+    /**
+     * \brief select algo according to input algo type
+     */
+    SmallVector<AlgoBase*> select_algo_type(AlgoTypePack algo_type);
+
 protected:
     KernSizeParam make_kern_size_param(const TensorLayout& A,
                                        const TensorLayout& B,


+ 17 - 2  dnn/src/x86/conv_bias/f32/algos.h

@@ -48,6 +48,10 @@ public:
     }

     void* type() const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
+    }
 };

 /* ===================== direct-stride2 algo ===================== */
@@ -81,6 +85,10 @@ public:
     }

     void* type() const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
+    }
 };
 /* =========================== winograd ======================== */
 class ConvBiasImpl::AlgoFP32WinogradF63_8x8 final : public AlgoBase {
@@ -96,7 +104,7 @@ public:
         return m_name.c_str();
     }
     void* type() const override;
-    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
+    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
 };

 class ConvBiasImpl::AlgoFP32WinogradF23_8x8 final : public AlgoBase {
@@ -112,7 +120,7 @@ public:
         return m_name.c_str();
     }
     void* type() const override;
-    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
+    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
 };

 /* ===================== matmul algo ===================== */
@@ -151,6 +159,9 @@ public:
     }

     void* type() const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::IM2COL};
+    }
 };

 #if MEGDNN_X86_WITH_MKL_DNN
@@ -192,6 +203,10 @@ public:
         return {{kern, {1_z, 1_z, 1_z}}};
     }
     void* type() const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
+    }
 };
 #endif
 // vim: syntax=cpp.doxygen

+ 0 - 2  dnn/src/x86/conv_bias/int8/algo_usable_preferred.cpp

@@ -224,8 +224,6 @@ bool mkldnn_matmul_qint8_preferred(
         const ConvBiasImpl::NCBKernSizeParam& param) {
     auto is_preferred = true;
     auto&& fm = param.filter_meta;
-    megdnn_assert_internal(fm.group == 1 && fm.dilation[0] == 1 &&
-                           fm.dilation[1] == 1);

     // single channel conv should never use matrix mul
     if (fm.ocpg == 1 || fm.icpg == 1)


+ 24 - 0  dnn/src/x86/conv_bias/int8/algos.h

@@ -34,6 +34,10 @@ public:
     }
     void* type() const override;
     bool is_preferred(const NCBKernSizeParam& param) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
+    }
 };

 /* ===================== avx2 stride2 chanwise algo ===================== */
@@ -55,6 +59,10 @@ public:
     }
     void* type() const override;
     bool is_preferred(const NCBKernSizeParam& param) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
+    }
 };

 /* ===================== avx2 stride1 direct algo ===================== */
@@ -76,6 +84,10 @@ public:
     }
     void* type() const override;
     bool is_preferred(const NCBKernSizeParam& param) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
+    }
 };

 /* ================== avx2 int8 direct conv stride2 algo ================== */
@@ -97,6 +109,10 @@ public:
     }
     void* type() const override;
     bool is_preferred(const NCBKernSizeParam& param) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
+    }
 };

 #if MEGDNN_X86_WITH_MKL_DNN
@@ -134,6 +150,10 @@ public:
     }
     void* type() const override;
     bool is_preferred(const NCBKernSizeParam& param) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
+    }
 };
 /* ===================== mkldnn qint8 matmul algo ===================== */
 class ConvBiasImpl::AlgoMkldnnMatmulQint8 final : public AlgoBase {
@@ -160,6 +180,10 @@ public:
     bool is_preferred(const NCBKernSizeParam& param) const override;

     void* type() const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL};
+    }
 };
 #endif


+ 39
- 2
dnn/src/x86/conv_bias/opr_impl.cpp View File

@@ -103,10 +103,10 @@ public:
#endif #endif
all_algos.emplace_back(&stride1_direct); all_algos.emplace_back(&stride1_direct);
all_algos.emplace_back(&stride2_direct); all_algos.emplace_back(&stride2_direct);
all_algos.emplace_back(&avx2_stride1_direct_int8);
all_algos.emplace_back(&avx2_stride2_direct);
all_algos.emplace_back(&avx2_stride1_chanwsie_qint8); all_algos.emplace_back(&avx2_stride1_chanwsie_qint8);
all_algos.emplace_back(&avx2_stride2_chanwsie_qint8); all_algos.emplace_back(&avx2_stride2_chanwsie_qint8);
all_algos.emplace_back(&avx2_stride1_direct_int8);
all_algos.emplace_back(&avx2_stride2_direct);
all_algos.emplace_back(&matmul); all_algos.emplace_back(&matmul);


static CpuOprDelegationStorage<> storage; static CpuOprDelegationStorage<> storage;
@@ -182,4 +182,41 @@ bool ConvBiasImpl::is_matmul_quantized_prefer(
             !chanwise_avx2_stride2_qint8_usable_preferred(param));
 }

+SmallVector<AlgoCategory>
+ConvBiasImpl::suggest_algo_category_order(const NCBKernSizeParam& param) const {
+    auto IC = param.filter_meta.icpg;
+    auto OC = param.filter_meta.ocpg;
+    auto FH = param.filter_meta.spatial[0];
+    auto FW = param.filter_meta.spatial[1];
+    //! TODO: for now winograd only supports fast-run
+    if (param.filter_meta.format == param::ConvBias::Format::NCHW_WINOGRAD ||
+        param.filter_meta.format == param::ConvBias::Format::NCHW44_WINOGRAD ||
+        param.filter_meta.format == param::ConvBias::Format::NCHW88_WINOGRAD) {
+        return {AlgoCategory::WINOGRAD};
+    }
+    //! nchw88 uses mkl-dnn, whose algo category is direct
+    if (param.filter_meta.format == param::ConvBias::Format::NCHW88) {
+        return {AlgoCategory::DIRECT, AlgoCategory::IM2COL};
+    }
+    //! im2col + matmul
+    bool im2col_prefer = (IC >= 32 || OC >= 32);
+    //! quantized algos use matmul when the direct algo is unusable
+    if (param.src_type.category() == DTypeCategory::QUANTIZED) {
+        im2col_prefer = is_matmul_quantized_prefer(param);
+    }
+    //! conv1x1
+    im2col_prefer |= (FH == 1 && FW == 1);
+    //! x86 8x8x16 is not optimized, so it falls back to im2col+matmul
+    if (param.deduce_algo_data_type() == AlgoDataType::INT8X8X16) {
+        im2col_prefer = true;
+    }
+    if (im2col_prefer) {
+        return {AlgoCategory::IM2COL, AlgoCategory::DIRECT,
+                AlgoCategory::NAIVE};
+    } else {
+        return {AlgoCategory::DIRECT, AlgoCategory::IM2COL,
+                AlgoCategory::NAIVE};
+    }
+}

// vim: syntax=cpp.doxygen
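
suggest_algo_category_order() only returns an ordered preference list; a caller still has to walk it. A hypothetical driver loop under that assumption — the real dispatcher lives in dnn/src/fallback and differs in detail; Algo, Param, and usable here are stand-ins:

    #include <vector>

    enum class AlgoCategory { DIRECT, IM2COL, WINOGRAD, NAIVE };

    struct Param {};  // shape/dtype details elided

    struct Algo {
        AlgoCategory category;
        bool (*usable)(const Param&);  // per-algo usability check
    };

    // Walk categories in the suggested order; within a category, take the
    // first algorithm whose usability check passes.
    const Algo* select(const std::vector<AlgoCategory>& order,
                       const std::vector<Algo>& all, const Param& param) {
        for (AlgoCategory cat : order)
            for (const Algo& a : all)
                if (a.category == cat && a.usable(param))
                    return &a;
        return nullptr;  // caller falls back to a naive reference impl
    }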

+ 2
- 0
dnn/src/x86/conv_bias/opr_impl.h View File

@@ -24,6 +24,8 @@ public:

     bool is_thread_safe() const override { return true; }
     SmallVector<AlgoBase*> algo_pack() override;
+    SmallVector<AlgoCategory> suggest_algo_category_order(
+            const NCBKernSizeParam& param) const override;

     class AlgoDirect;
     class AlgoDirectStride2;


+ 13
- 10
dnn/src/x86/matrix_mul/algos.cpp View File

@@ -184,11 +184,10 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Vnni::get_kern(
     return int8x8x32_kern_vnni;
 }

-MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32Vnni,
-                                            megdnn_x86_matmul_kern,
-                                            "AlgoInt8x8x32Vnni"_hash,
-                                            x86::matmul::gemm_int8_vnni_12x32x4,
-                                            dt_int8, dt_int32, dt_uint8);
+MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
+        AlgoInt8x8x32Vnni, megdnn_x86_matmul_kern, "AlgoInt8x8x32Vnni"_hash,
+        x86::matmul::gemm_int8_vnni_12x32x4, dt_int8, dt_int32, dt_uint8,
+        AlgoDataType::QINT8X8X32, DEFAULT);
 #endif


/* ===================== Int8 mkldnn algo ===================== */ /* ===================== Int8 mkldnn algo ===================== */
@@ -397,7 +396,8 @@ size_t MatrixMulImpl::AlgoInt8x8x16AVX2::get_workspace(
 }
 MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
         AlgoInt8x8x16AVX2, megdnn_x86_matmul_kern, "AlgoInt8x8x16AVX2"_hash,
-        x86::matmul::gemm_avx2_s8s8s16_4x16x2, dt_int8, dt_int16, dt_int16);
+        x86::matmul::gemm_avx2_s8s8s16_4x16x2, dt_int8, dt_int16, dt_int16,
+        AlgoDataType::INT8X8X16, DEFAULT);

 /*************************AlgoInt8x8x16SSE********************/
 void MatrixMulImpl::AlgoInt8x8x16SSE::gemm_s8s8s16_sse_4x8x2(
@@ -474,7 +474,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x16SSE,
                                             megdnn_x86_matmul_kern,
                                             "AlgoInt8x8x16SSE"_hash,
                                             x86::matmul::gemm_sse_s8s8s16_4x8x2,
-                                            dt_int8, dt_int16, dt_int16);
+                                            dt_int8, dt_int16, dt_int16,
+                                            AlgoDataType::INT8X8X16, DEFAULT);

 /*************************AlgoInt8x8x32AVX2M4N16K2********************/
 MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_kern(
@@ -516,7 +517,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_workspace(
 MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
         AlgoInt8x8x32AVX2M4N16K2, megdnn_x86_matmul_kern,
         "AlgoInt8x8x32AVX2M4N16K2"_hash, x86::matmul::gemm_avx2_s8s8s32_4x16x2,
-        dt_int8, dt_int32, dt_int16);
+        dt_int8, dt_int32, dt_int16, AlgoDataType::QINT8X8X32, DEFAULT);

 MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_kern(
         const KernSizeParam&) const {
@@ -556,7 +557,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32AVX2M2N4K16,
                                      megdnn_x86_matmul_kern,
                                      "AlgoInt8x8x32AVX2M2N4K16"_hash,
                                      x86::matmul::gemm_avx2_s8s8s32_2x4x16,
-                                     dt_int8, dt_int32);
+                                     dt_int8, dt_int32,
+                                     AlgoDataType::QINT8X8X32, DEFAULT);

 /*************************AlgoInt8x8x32SSEM4N8K2********************/
 MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_kern(
@@ -596,7 +598,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32SSEM4N8K2,
                                             megdnn_x86_matmul_kern,
                                             "AlgoInt8x8x32SSEM4N8K2"_hash,
                                             x86::matmul::gemm_sse_s8s8s32_4x8x2,
-                                            dt_int8, dt_int32, dt_int16);
+                                            dt_int8, dt_int32, dt_int16,
+                                            AlgoDataType::QINT8X8X32, DEFAULT);

 /*************************AlgoF32MK8_8x8********************/
 MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK8_8x8::get_kern(
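
The pattern in the hunks above is uniform: every GEMM registration gains an AlgoDataType tag plus a pack-format tag (DEFAULT here; MK8 for block-packed layouts), so matmul-backed paths can ask a kernel what data it handles before committing to it. A compile-ready sketch of that idea — MatmulDescription and supports() are assumed names, and the real description struct in dnn/src/fallback carries more fields:

    #include <cstdint>

    enum class AlgoDataType : uint32_t {
        FLOAT32 = 1 << 0,
        QINT8X8X32 = 1 << 2,
        INT8X8X16 = 1 << 4,
    };
    enum class MatmulFormat { DEFAULT, MK8 };

    struct MatmulDescription {
        AlgoDataType data_type;
        MatmulFormat format;
        uint32_t block_m, block_n, block_k;
    };

    // E.g. AlgoInt8x8x16AVX2 would advertise {INT8X8X16, DEFAULT, 4, 16, 2},
    // judging by its gemm_avx2_s8s8s16_4x16x2 kernel name.
    inline bool supports(const MatmulDescription& desc, AlgoDataType wanted) {
        return (static_cast<uint32_t>(desc.data_type) &
                static_cast<uint32_t>(wanted)) != 0;
    }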


+ 4
- 4
dnn/src/x86/matrix_mul/algos.h View File

@@ -27,7 +27,7 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     void* type() const override { return sm_x86_algo_type; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT)
 };


#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
@@ -49,7 +49,7 @@ public:
     WorkspaceBundle get_bundle(const KernSizeParam& param) const override;
     InnerBlockSize get_inner_block_size() const override { return {8, 16, 1}; };

-    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT)
 };
 #endif


@@ -127,7 +127,7 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     void* type() const override { return sm_x86_algo_type; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 4)
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 4, AlgoDataType::FLOAT32, MK8)
 };


#if MEGDNN_X86_WITH_VNNI
@@ -153,7 +153,7 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     void* type() const override { return sm_x86_algo_type; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT)
 };
 #endif
 } // namespace x86
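
The two extra arguments to MEGDNN_OVERRIDE_MATMUL_DESC follow the same scheme. The real macro is defined in the fallback matrix-mul code and is not shown in this diff; the stand-in below is only a guess at its shape, written to show why call sites can pass a bare DEFAULT or MK8:

    #include <cstdint>

    enum class AlgoDataType : uint32_t { FLOAT32 = 1 << 0, QINT8X8X32 = 1 << 2 };
    enum class PackFormat { DEFAULT, MK8 };

    struct MatmulDesc {
        uint32_t m, n, k, packed_size;
        AlgoDataType dtype;
        PackFormat format;
    };

    // Hypothetical expansion: qualifying _fmt inside the macro lets call
    // sites write a bare DEFAULT / MK8, exactly as the hunks above do.
    #define MEGDNN_OVERRIDE_MATMUL_DESC(_m, _n, _k, _packed, _dtype, _fmt) \
        MatmulDesc matmul_description() const {                            \
            return {_m, _n, _k, _packed, _dtype, PackFormat::_fmt};        \
        }

    struct AlgoF32BlasLike {
        MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT)
    };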


+ 5
- 5
src/opr/impl/dnn/convolution.cpp View File

@@ -495,8 +495,9 @@ class AlgoChooser {
             }
         }
         mgb_assert(found,
-                   "algo got by heuristic not found in "
-                   "candidate list");
+                   "algo %s got by heuristic not found in "
+                   "candidate list",
+                   heu->name());
         return std::move(ret);
     }


@@ -628,7 +629,7 @@ public:
         auto algo = get_algo(ctx);
         size_t workspace = ctx.get_workspace_size_bytes(algo);
         mgb_log_debug(
-                "%s: input shapes (%s %s, %s %s) -> (%s %s): algo=%s "
+                "%s: tensor layouts (%s %s, %s %s) -> (%s %s): algo=%s "
                 "workspace=%.2fMiB reproducible=%d",
                 mgb_opr->dyn_typeinfo()->name,
                 layouts[0].to_string().c_str(),
@@ -636,8 +637,7 @@ public:
                 layouts[1].to_string().c_str(),
                 layouts[1].dtype.name(),
                 layouts[layouts.size() - 1].to_string().c_str(),
-                layouts[layouts.size() - 1].dtype.name(),
-                algo->name(),
+                layouts[layouts.size() - 1].dtype.name(), algo->name(),
                 workspace / (1024 * 1024.0), algo->is_reproducible());
         megdnn_opr->execution_policy() = {algo};
         return workspace;

