From 2a3f4d099a419296941532036f9fbde39273137b Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Mon, 3 Aug 2020 22:34:34 +0800
Subject: [PATCH] refactor(dnn/arm): refactor CPU heuristic algo selection

GitOrigin-RevId: 60d2646bb33316411caa18686eec724dc1f6c430
---
 dnn/include/megdnn/oprs/base.h                   |  12 +++
 dnn/src/aarch64/conv_bias/fp16/algos.h           |   4 +
 dnn/src/aarch64/conv_bias/fp32/algos.h           |   4 +
 dnn/src/aarch64/conv_bias/int8/algos.h           |   3 +
 dnn/src/aarch64/conv_bias/opr_impl.cpp           |   5 +-
 dnn/src/aarch64/conv_bias/quint8/algos.h         |   3 +
 dnn/src/aarch64/matrix_mul/algos.cpp             |  50 ++++++----
 dnn/src/aarch64/matrix_mul/algos.h               |   8 +-
 dnn/src/arm_common/conv_bias/f16/algos.h         |  15 ++-
 dnn/src/arm_common/conv_bias/fp32/algos.h        |  34 +++++--
 dnn/src/arm_common/conv_bias/int8/algos.h        |  38 +++++++-
 dnn/src/arm_common/conv_bias/int8x8x16/algos.h   |  18 ++++
 dnn/src/arm_common/conv_bias/opr_impl.cpp        | 107 ++++++++++++++++++---
 dnn/src/arm_common/conv_bias/opr_impl.h          |   5 +-
 dnn/src/arm_common/conv_bias/quint8/algos.h      |  12 +++
 dnn/src/arm_common/matrix_mul/algos.h            |  22 +++--
 dnn/src/arm_common/matrix_mul/opr_impl.cpp       |   7 +-
 dnn/src/armv7/conv_bias/int8/algos.h             |   3 +
 dnn/src/armv7/conv_bias/quint8/algos.h           |   4 +
 dnn/src/armv7/matrix_mul/algos.cpp               |  37 ++++---
 dnn/src/armv7/matrix_mul/algos.h                 |   6 +-
 dnn/src/armv7/matrix_mul/opr_impl.h              |   1 -
 dnn/src/common/utils.h                           |  28 ++++++
 dnn/src/fallback/conv_bias/algos.h               |  26 +++++
 dnn/src/fallback/conv_bias/common.h              |   5 +-
 dnn/src/fallback/conv_bias/conv1x1/algos.cpp     |   3 +-
 dnn/src/fallback/conv_bias/conv1x1/algos.h       |   5 +
 .../conv_bias/conv1x1/algos_conv1x1_gemv.h       |  10 ++
 dnn/src/fallback/conv_bias/im2col/algos.h        |  26 +++--
 dnn/src/fallback/conv_bias/opr_impl.cpp          | 104 +++++++++++++++++---
 dnn/src/fallback/conv_bias/opr_impl.h            |  19 ++++
 dnn/src/fallback/convolution/algos.h             |  21 +++-
 dnn/src/fallback/convolution/opr_impl.cpp        |  98 ++++++++++++++++---
 dnn/src/fallback/convolution/opr_impl.h          |  31 ++++++
 dnn/src/fallback/matrix_mul/algos.cpp            |   2 +-
 dnn/src/fallback/matrix_mul/algos.h              |  10 +-
 dnn/src/fallback/matrix_mul/gemm_common.h        |  30 +++---
 dnn/src/fallback/matrix_mul/opr_impl.cpp         |  64 +++++++++++-
 dnn/src/fallback/matrix_mul/opr_impl.h           |  19 +++-
 dnn/src/x86/conv_bias/f32/algos.h                |  19 +++-
 .../x86/conv_bias/int8/algo_usable_preferred.cpp |   2 -
 dnn/src/x86/conv_bias/int8/algos.h               |  24 +++++
 dnn/src/x86/conv_bias/opr_impl.cpp               |  41 +++++++-
 dnn/src/x86/conv_bias/opr_impl.h                 |   2 +
 dnn/src/x86/matrix_mul/algos.cpp                 |  23 +++--
 dnn/src/x86/matrix_mul/algos.h                   |   8 +-
 src/opr/impl/dnn/convolution.cpp                 |  10 +-
 47 files changed, 856 insertions(+), 172 deletions(-)

diff --git a/dnn/include/megdnn/oprs/base.h b/dnn/include/megdnn/oprs/base.h
index d758c6d0..8f2197af 100644
--- a/dnn/include/megdnn/oprs/base.h
+++ b/dnn/include/megdnn/oprs/base.h
@@ -76,6 +76,18 @@ enum class AlgoSelectionStrategy {
     FULL_RUN = 2,
 };
 
+/**
+ * \brief separate algos by data type for matmul and conv
+ */
+enum class AlgoDataType : uint32_t {
+    FLOAT32 = 1 << 0,
+    FLOAT16 = 1 << 1,
+    QINT8X8X32 = 1 << 2,
+    QUINT8X8X32 = 1 << 3,
+    INT8X8X16 = 1 << 4,
+    INT16X16X32 = 1 << 5,
+};
+
 /*!
 * \brief Abstract representation of an algorithm for implementing
 *        the operator

diff --git a/dnn/src/aarch64/conv_bias/fp16/algos.h b/dnn/src/aarch64/conv_bias/fp16/algos.h
index 77ab5d76..3b36bfd3 100644
--- a/dnn/src/aarch64/conv_bias/fp16/algos.h
+++ b/dnn/src/aarch64/conv_bias/fp16/algos.h
@@ -27,6 +27,10 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     SmallVector dispatch_kerns(const NCBKernSizeParam&) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT};
+    }
 };
 } // namespace aarch64
 } // namespace megdnn

diff --git a/dnn/src/aarch64/conv_bias/fp32/algos.h b/dnn/src/aarch64/conv_bias/fp32/algos.h
index 6ae1bf00..3340c726 100644
--- a/dnn/src/aarch64/conv_bias/fp32/algos.h
+++ b/dnn/src/aarch64/conv_bias/fp32/algos.h
@@ -32,6 +32,10 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     SmallVector dispatch_kerns(const NCBKernSizeParam&) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
+    }
 };
 } // namespace aarch64

diff --git a/dnn/src/aarch64/conv_bias/int8/algos.h b/dnn/src/aarch64/conv_bias/int8/algos.h
index 7d79f0cd..afac5922 100644
--- a/dnn/src/aarch64/conv_bias/int8/algos.h
+++ b/dnn/src/aarch64/conv_bias/int8/algos.h
@@ -45,6 +45,9 @@ public:
         return static_cast(conv_bias_opr)
                 ->is_matmul_quantized_prefer(param);
     }
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL};
+    }
 };
 } // namespace aarch64

diff --git a/dnn/src/aarch64/conv_bias/opr_impl.cpp b/dnn/src/aarch64/conv_bias/opr_impl.cpp
index e65a96a5..65ec4809 100644
--- a/dnn/src/aarch64/conv_bias/opr_impl.cpp
+++ b/dnn/src/aarch64/conv_bias/opr_impl.cpp
@@ -50,10 +50,9 @@ SmallVector ConvBiasImpl::algo_pack() {
     auto&& algos = arm_common::ConvBiasImpl::algo_pack();
     algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(),
                  sl_algo_pack.direct_algos.end());
-    //! We put matmul algos at the end. Because matmul will get privilege when
+    //! We put matmul algos at the beginning. Because matmul will get privilege when
     //! prefer return true. See
-    //! fallback::ConvolutionImpl::ncb_1g_get_all_algorithms for more details.
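The hunks above show the core pattern of this refactor: every conv-bias algo now classifies itself with a ConvAlgoTypePack {AlgoDataType, AlgoCategory} so the opr can filter candidates by type instead of scanning one flat list. A minimal standalone sketch of that pattern (the types below are simplified stand-ins, not the real megdnn declarations):

    #include <cstdint>

    // Simplified stand-ins for the megdnn types; names follow the patch.
    enum class AlgoDataType : uint32_t { FLOAT32 = 1 << 0, FLOAT16 = 1 << 1 };
    enum class AlgoCategory { DIRECT, IM2COL, WINOGRAD, NAIVE };
    struct ConvAlgoTypePack {
        AlgoDataType data_type;
        AlgoCategory algo_category;
    };

    struct AlgoBase {
        virtual ~AlgoBase() = default;
        // Pure virtual in the patch: every concrete algo must classify itself.
        virtual ConvAlgoTypePack get_algo_type() const = 0;
    };

    // A hypothetical fp32 direct algo classifying itself, in the same shape
    // as the aarch64 overrides above.
    struct AlgoExampleF32Direct final : AlgoBase {
        ConvAlgoTypePack get_algo_type() const override {
            return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
        }
    };

    int main() {
        AlgoExampleF32Direct algo;
        return algo.get_algo_type().algo_category == AlgoCategory::DIRECT ? 0 : 1;
    }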
- algos.insert(algos.end(), sl_algo_pack.matmul_algos.begin(), + algos.insert(algos.begin(), sl_algo_pack.matmul_algos.begin(), sl_algo_pack.matmul_algos.end()); return std::move(algos); } diff --git a/dnn/src/aarch64/conv_bias/quint8/algos.h b/dnn/src/aarch64/conv_bias/quint8/algos.h index 9f99b0d0..a55ee568 100644 --- a/dnn/src/aarch64/conv_bias/quint8/algos.h +++ b/dnn/src/aarch64/conv_bias/quint8/algos.h @@ -45,6 +45,9 @@ public: return static_cast(conv_bias_opr) ->is_matmul_quantized_prefer(param); } + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL}; + } }; } // namespace aarch64 } // namespace megdnn diff --git a/dnn/src/aarch64/matrix_mul/algos.cpp b/dnn/src/aarch64/matrix_mul/algos.cpp index 89aea82c..e360ced9 100644 --- a/dnn/src/aarch64/matrix_mul/algos.cpp +++ b/dnn/src/aarch64/matrix_mul/algos.cpp @@ -89,7 +89,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern( } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_aarch64_matmul_kern, "AlgoF32K8x12x1Impl"_hash, - aarch64::matmul::sgemm_8x12, float, float); + aarch64::matmul::sgemm_8x12, float, float, + AlgoDataType::FLOAT32, DEFAULT); /* ===================== F32_MK4_8X12X1 algo ===================== */ bool MatrixMulImpl::AlgoF32MK4_8x12x1::usable( @@ -151,7 +152,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4_8x12x1, megdnn_aarch64_matmul_kern, "AlgoF32MK4_8x12x1Impl"_hash, aarch64::matmul::sgemm_mk4_8x12, float, - float); + float, AlgoDataType::FLOAT32, MK4); /* ===================== F32K4X16X1 algo ===================== */ @@ -210,7 +211,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K4x16x1::get_kern( } MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K4x16x1, megdnn_aarch64_matmul_kern, "AlgoF32K4x16x1Impl"_hash, - aarch64::matmul::sgemm_4x16, float, float); + aarch64::matmul::sgemm_4x16, float, float, + AlgoDataType::FLOAT32, MK4); /* ===================== F32MK4_4x16 algo ===================== */ @@ -328,7 +330,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K8x24x1::get_kern( MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF16K8x24x1, megdnn_aarch64_matmul_kern, "AlogF16K8x24x1Impl"_hash, aarch64::matmul::hgemm_8x24, dt_float16, - dt_float16); + dt_float16, AlgoDataType::FLOAT16, + DEFAULT); /* ===================== F16_MK8_8x8 algo ===================== */ bool MatrixMulImpl::AlgoF16MK8_8x8::usable( @@ -449,7 +452,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x12x4DotProd, megdnn_aarch64_matmul_kern, "AlgoInt8x8x32K8x12x4DotProdImpl"_hash, aarch64::matmul::gemm_s8_8x12, int8_t, - int32_t); + int32_t, AlgoDataType::QINT8X8X32, + DEFAULT); /* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */ namespace { @@ -520,7 +524,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x12x4DotProd, megdnn_aarch64_matmul_kern, "AlgoInt8x8x32MK4_8x12x4DotProdImpl"_hash, aarch64::matmul::gemm_mk4_s8_8x12, int8_t, - int32_t); + int32_t, AlgoDataType::QINT8X8X32, + MK4_DOT); #else /* ===================== Int8x8x32 MK4 4x4x16 algo ===================== */ @@ -593,7 +598,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_4x4x16, megdnn_aarch64_matmul_kern, "AlgoInt8x8x32MK4_4x4x16Impl"_hash, aarch64::matmul::gemm_mk4_s8_4x4, int8_t, - int32_t); + int32_t, AlgoDataType::QINT8X8X32, + MK4); /* ===================== Int8x8x32 K4x4x16 algo ===================== */ namespace { @@ -656,7 +662,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x4x16, megdnn_aarch64_matmul_kern, "AlgoInt8x8x32K4x4x16Impl"_hash, 
aarch64::matmul::gemm_s8_4x4, int8_t, - int32_t); + int32_t, AlgoDataType::QINT8X8X32, + DEFAULT); /* ===================== Int8x8x32 K8x8x8 algo ===================== */ namespace { void int8x8x32_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) { @@ -717,7 +724,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x8x8, megdnn_aarch64_matmul_kern, "AlgoInt8x8x32K8x8x8Impl"_hash, aarch64::matmul::gemm_s8_8x8, int8_t, - int32_t); + int32_t, AlgoDataType::QINT8X8X32, + DEFAULT); #endif /* ===================== Int8x8x16 K8x8x8 algo ===================== */ @@ -785,7 +793,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K8x8x8, megdnn_aarch64_matmul_kern, "AlgoInt8x8x16K8x8x8Impl"_hash, aarch64::matmul::gemm_s8x8x16_8x8, int8_t, - int16_t); + int16_t, AlgoDataType::INT8X8X16, DEFAULT); /* ===================== Int8x8x16 K4x4x16 algo ===================== */ namespace { void int8x8x16_k4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) { @@ -852,7 +860,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x4x16, megdnn_aarch64_matmul_kern, "AlgoInt8x8x16K4x4x16Impl"_hash, aarch64::matmul::gemm_s8x8x16_4x4, int8_t, - int16_t); + int16_t, AlgoDataType::INT8X8X16, DEFAULT); /* ===================== Int8x8x16 K16x12x4 algo ===================== */ namespace { @@ -929,7 +937,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_kern( MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( AlgoInt8x8x16MK4_16x12x4, megdnn_aarch64_matmul_kern, "AlgoInt8x8x16MK4_16x12x4Impl"_hash, - aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53, int8_t, int16_t, int16_t); + aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53, int8_t, int16_t, int16_t, + AlgoDataType::INT8X8X16, MK4); /* ===================== Int8x8x16 MK4 4x4x8 algo ===================== */ namespace { @@ -1007,7 +1016,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_4x4x8, megdnn_aarch64_matmul_kern, "AlgoInt8x8x16MK4_4x4x8_Impl"_hash, aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72, - int8_t, int16_t); + int8_t, int16_t, AlgoDataType::INT8X8X16, + MK4); /* ===================== Int16x16x32 K12x8x1 algo ===================== */ namespace { @@ -1078,7 +1088,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x8x1, megdnn_aarch64_matmul_kern, "AlgoInt16x16x32K12x8x1Impl"_hash, aarch64::matmul::gemm_s16_12x8x1, int16_t, - int32_t); + int32_t, AlgoDataType::INT16X16X32, + DEFAULT); /* ===================== Int16x16x32MK8_8x8 algo ===================== */ @@ -1201,7 +1212,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x4DotProd, megdnn_aarch64_matmul_kern, "AlgoQuint8K8x8x4DotProdImpl"_hash, aarch64::matmul::gemm_u8_8x8, uint8_t, - int32_t); + int32_t, AlgoDataType::QUINT8X8X32, + DEFAULT); /* ===================== Quint8 Gemv DotProd algo ===================== */ namespace { void quint8_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { @@ -1307,7 +1319,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x8, megdnn_aarch64_matmul_kern, "AlgoQuint8K8x8x8Impl"_hash, aarch64::matmul::gemm_u8_8x8, uint8_t, - int32_t); + int32_t, AlgoDataType::QUINT8X8X32, + DEFAULT); #endif /* ===================== Int8x8x16 K8x8x8 algo ===================== */ @@ -1378,6 +1391,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_kern( MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_K8x8x8, megdnn_aarch64_matmul_kern, "AlgoInt8x8x16MK4_K8x8x8Impl"_hash, - aarch64::matmul::gemm_s8x8x16_mk4_8x8x8, int8_t, - int16_t); + aarch64::matmul::gemm_s8x8x16_mk4_8x8x8, + int8_t, int16_t, 
AlgoDataType::INT8X8X16, + MK4); // vim: syntax=cpp.doxygen diff --git a/dnn/src/aarch64/matrix_mul/algos.h b/dnn/src/aarch64/matrix_mul/algos.h index 46b7df25..b08eeb7d 100644 --- a/dnn/src/aarch64/matrix_mul/algos.h +++ b/dnn/src/aarch64/matrix_mul/algos.h @@ -61,7 +61,7 @@ public: kern_t get_kern(const KernSizeParam&) const override; void* type() const override { return sm_arm_common_algo_type; } PackMode packmode() const override { return PackMode::NO_PACK; } - MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4) + MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4, AlgoDataType::FLOAT32, MK4) }; class MatrixMulImpl::AlgoF32Gemv final @@ -88,7 +88,7 @@ public: kern_t get_kern(const KernSizeParam&) const override; void* type() const override { return sm_arm_common_algo_type; } PackMode packmode() const override { return PackMode::NO_PACK; } - MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2) + MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::FLOAT16, MK8) }; #endif @@ -253,7 +253,7 @@ public: kern_t get_kern(const KernSizeParam&) const override; void* type() const override { return sm_arm_common_algo_type; } PackMode packmode() const override { return PackMode::NO_PACK; } - MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2) + MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) }; #if __ARM_FEATURE_DOTPROD @@ -281,7 +281,7 @@ public: void* type() const override { return sm_arm_common_algo_type; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } - MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) + MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QUINT8X8X32, DEFAULT) }; #else diff --git a/dnn/src/arm_common/conv_bias/f16/algos.h b/dnn/src/arm_common/conv_bias/f16/algos.h index 0e9d5c32..0f985651 100644 --- a/dnn/src/arm_common/conv_bias/f16/algos.h +++ b/dnn/src/arm_common/conv_bias/f16/algos.h @@ -29,7 +29,7 @@ public: } return m_name.c_str(); } - MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); + MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); }; class ConvBiasImpl::AlgoFP16WinogradF45 final : public AlgoBase { @@ -44,7 +44,7 @@ public: } return m_name.c_str(); } - MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); + MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); }; class ConvBiasImpl::AlgoFP16WinogradF63 final : public AlgoBase { @@ -60,7 +60,7 @@ public: return m_name.c_str(); } - MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); + MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); }; class ConvBiasImpl::AlgoFP16WinogradF23_8x8 final : public AlgoBase { public: @@ -74,7 +74,7 @@ public: } return m_name.c_str(); } - MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); + MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); }; class ConvBiasImpl::AlgoF16Direct final : public AlgoBase { @@ -90,6 +90,10 @@ public: virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; + + ConvAlgoTypePack get_algo_type() const override{ + return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; + } }; class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase { @@ -103,6 +107,9 @@ public: size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; + } }; } // namespace arm_common diff --git a/dnn/src/arm_common/conv_bias/fp32/algos.h b/dnn/src/arm_common/conv_bias/fp32/algos.h index 9cf5fb99..e65fe8bd 100644 --- 
a/dnn/src/arm_common/conv_bias/fp32/algos.h +++ b/dnn/src/arm_common/conv_bias/fp32/algos.h @@ -29,7 +29,7 @@ public: } return m_name.c_str(); } - MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); + MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); }; class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase { @@ -44,7 +44,7 @@ public: } return m_name.c_str(); } - MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); + MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); }; class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase { @@ -59,7 +59,7 @@ public: } return m_name.c_str(); } - MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); + MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); }; class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase { @@ -74,7 +74,7 @@ public: } return m_name.c_str(); } - MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); + MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); }; class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase { @@ -89,7 +89,7 @@ public: } return m_name.c_str(); } - MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); + MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); }; //===================== NCHW44 Winograd Support =====================// @@ -106,7 +106,7 @@ public: } return m_name.c_str(); } - MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); + MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); }; class ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44 final : public AlgoBase { @@ -122,7 +122,7 @@ public: } return m_name.c_str(); } - MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); + MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); }; class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase { @@ -138,7 +138,7 @@ public: } return m_name.c_str(); } - MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); + MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); }; // ================================================================= // @@ -154,6 +154,9 @@ public: size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; + } }; class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { @@ -168,6 +171,9 @@ public: size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; + } }; class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { @@ -182,6 +188,9 @@ public: size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; + } }; class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase { @@ -197,6 +206,9 @@ public: size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; + } }; class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase { @@ -212,6 +224,9 @@ public: size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const 
override { + return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; + } }; class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase { @@ -226,6 +241,9 @@ public: size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; + } }; } // namespace arm_common diff --git a/dnn/src/arm_common/conv_bias/int8/algos.h b/dnn/src/arm_common/conv_bias/int8/algos.h index 196584f7..f9611372 100644 --- a/dnn/src/arm_common/conv_bias/int8/algos.h +++ b/dnn/src/arm_common/conv_bias/int8/algos.h @@ -29,6 +29,10 @@ public: const NCBKernSizeParam& param) const override; bool is_preferred(const NCBKernSizeParam& param) const override; + + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; + } }; class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { @@ -42,6 +46,9 @@ public: size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; + } }; class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase { @@ -55,6 +62,9 @@ public: virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; bool is_preferred(const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; + } }; class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase { @@ -68,6 +78,9 @@ public: virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; bool is_preferred(const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; + } }; class ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44 final : public AlgoBase { @@ -79,6 +92,9 @@ public: size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; + } }; class ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44 final : public AlgoBase { @@ -90,6 +106,9 @@ public: size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; + } }; #if __ARM_FEATURE_DOTPROD @@ -104,6 +123,9 @@ public: size_t get_workspace(const NCBKernSizeParam&) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; + } }; class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { @@ -117,6 +139,9 @@ public: size_t get_workspace(const NCBKernSizeParam&) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; + } }; class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase { @@ -131,6 +156,9 @@ public: size_t 
get_workspace(const NCBKernSizeParam&) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; + } }; class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase { @@ -148,6 +176,10 @@ public: const NCBKernSizeParam& param) const override; bool is_preferred(const NCBKernSizeParam& param) const override; + + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; + } }; #endif @@ -163,7 +195,7 @@ public: } return m_name.c_str(); } - MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); + MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); }; //=======================input int8 compute fp32 output int8============ @@ -180,7 +212,7 @@ public: } return m_name.c_str(); } - MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); + MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); }; //=======================input int8 compute int16 output int8============ @@ -198,7 +230,7 @@ public: return m_name.c_str(); } - MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); + MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); }; } // namespace arm_common diff --git a/dnn/src/arm_common/conv_bias/int8x8x16/algos.h b/dnn/src/arm_common/conv_bias/int8x8x16/algos.h index 26e9c5d1..4591a278 100644 --- a/dnn/src/arm_common/conv_bias/int8x8x16/algos.h +++ b/dnn/src/arm_common/conv_bias/int8x8x16/algos.h @@ -36,6 +36,9 @@ public: size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; + } }; class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase { @@ -48,6 +51,9 @@ public: size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; + } }; class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase { @@ -71,6 +77,9 @@ public: size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; + } }; class ConvBiasImpl::AlgoI8x8x16Stride2Filter2 final : public AlgoBase { @@ -84,6 +93,9 @@ public: size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; + } }; class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final : public AlgoBase { @@ -96,6 +108,9 @@ public: const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; + } }; class ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44 final : public AlgoBase { @@ -111,6 +126,9 @@ public: size_t get_workspace(const NCBKernSizeParam& param) const override; virtual SmallVector dispatch_kerns( const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override { + return 
{AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; + } }; } // namespace arm_common diff --git a/dnn/src/arm_common/conv_bias/opr_impl.cpp b/dnn/src/arm_common/conv_bias/opr_impl.cpp index 374bd031..052d1430 100644 --- a/dnn/src/arm_common/conv_bias/opr_impl.cpp +++ b/dnn/src/arm_common/conv_bias/opr_impl.cpp @@ -10,6 +10,7 @@ * implied. */ +#include "megdnn/opr_param_defs.h" #include "src/arm_common/conv_bias/int8/algos.h" #include "src/arm_common/conv_bias/int8x8x16/algos.h" #include "src/arm_common/conv_bias/quint8/algos.h" @@ -122,9 +123,11 @@ public: static CpuOprDelegationStorage<2> storage; auto matmul_opr = storage.get(); + using MatmulFormat = param::MatrixMul::Format; auto&& matmul_algos = static_cast(matmul_opr) - ->algo_pack(); + ->select_algo_type( + {AlgoDataType::FLOAT32, MatmulFormat::MK4}); for (auto&& algo : matmul_algos) { if (algo->type() == nullptr) continue; @@ -133,38 +136,62 @@ public: static_cast(algo), tile_size)); winograd_algos.emplace_back(refhold.back().get()); - refhold.emplace_back(new AlgoFP32WinogradF63( + refhold.emplace_back(new AlgoFP32WinogradF63_4x4( static_cast(algo), tile_size)); winograd_algos.emplace_back(refhold.back().get()); - refhold.emplace_back(new AlgoFP32WinogradF63_4x4( + refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44( static_cast(algo), tile_size)); winograd_algos.emplace_back(refhold.back().get()); - refhold.emplace_back(new AlgoFP32WinogradF54( + refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44( static_cast(algo), tile_size)); winograd_algos.emplace_back(refhold.back().get()); - refhold.emplace_back(new AlgoFP32WinogradF45( +//! uncomment this when low precision mode is done +#if 0 + refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44( static_cast(algo), tile_size)); winograd_algos.emplace_back(refhold.back().get()); - refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44( +#endif + //! Qint8x8x32 winograd compute with fp32 + refhold.emplace_back(new AlgoS8CF32WinogradF23_4x4_NCHW44( static_cast(algo), tile_size)); winograd_algos.emplace_back(refhold.back().get()); - refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44( + } + } + matmul_algos = static_cast(matmul_opr) + ->select_algo_type({AlgoDataType::FLOAT32, + MatmulFormat::DEFAULT}); + for (auto&& algo : matmul_algos) { + if (algo->type() == nullptr) + continue; + for (uint32_t tile_size : {16, 8, 24, 32}) { + refhold.emplace_back(new AlgoFP32WinogradF63( static_cast(algo), tile_size)); winograd_algos.emplace_back(refhold.back().get()); -//! 
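The AlgoPack hunk above carries the structural change: instead of wrapping every matmul algo in every winograd variant, the pack now asks the matmul opr for one {AlgoDataType, MatrixMul::Format} bucket at a time via select_algo_type() and only builds the winograd wrappers that fit that bucket. A compilable sketch of the same grouping with simplified stand-in types (names below are illustrative, not megdnn's API):

    #include <cstdio>
    #include <vector>

    enum class DataType { FLOAT32, FLOAT16, INT16X16X32 };
    enum class Format { DEFAULT, MK4, MK8 };

    struct MatmulAlgo {
        const char* name;
        DataType dtype;
        Format format;
    };

    // Analogue of MatrixMulImpl::select_algo_type(): keep only the matmul
    // algos whose {data type, format} match the requested bucket.
    std::vector<MatmulAlgo> select_algo_type(const std::vector<MatmulAlgo>& all,
                                             DataType dtype, Format format) {
        std::vector<MatmulAlgo> out;
        for (const auto& a : all)
            if (a.dtype == dtype && a.format == format)
                out.push_back(a);
        return out;
    }

    int main() {
        std::vector<MatmulAlgo> all = {
                {"sgemm_8x12", DataType::FLOAT32, Format::DEFAULT},
                {"sgemm_mk4_8x12", DataType::FLOAT32, Format::MK4},
                {"hgemm_8x24", DataType::FLOAT16, Format::DEFAULT},
        };
        // MK4 matmuls back the 4x4/NCHW44 winograd variants...
        for (const auto& a : select_algo_type(all, DataType::FLOAT32, Format::MK4))
            std::printf("WINOGRAD_F63_4x4 on %s\n", a.name);
        // ...while DEFAULT matmuls back the plain F63/F54/F45 variants.
        for (const auto& a :
             select_algo_type(all, DataType::FLOAT32, Format::DEFAULT))
            std::printf("WINOGRAD_F63 on %s\n", a.name);
        return 0;
    }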
uncomment this when low precision mode is done
-#if 0
-            refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44(
+            refhold.emplace_back(new AlgoFP32WinogradF54(
                     static_cast(algo),
                     tile_size));
             winograd_algos.emplace_back(refhold.back().get());
-#endif
+            refhold.emplace_back(new AlgoFP32WinogradF45(
+                    static_cast(algo),
+                    tile_size));
+            winograd_algos.emplace_back(refhold.back().get());
+        }
+    }
+
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
+    matmul_algos = static_cast(matmul_opr)
+                           ->select_algo_type({AlgoDataType::FLOAT16,
+                                               MatmulFormat::DEFAULT});
+    for (auto&& algo : matmul_algos) {
+        if (algo->type() == nullptr)
+            continue;
+        for (uint32_t tile_size : {16, 8, 24, 32}) {
             refhold.emplace_back(new AlgoFP16WinogradF23(
                     static_cast(algo),
                     tile_size));
@@ -177,19 +204,33 @@
                     static_cast(algo),
                     tile_size));
             winograd_algos.emplace_back(refhold.back().get());
+        }
+    }
+    matmul_algos = static_cast(matmul_opr)
+                           ->select_algo_type({AlgoDataType::FLOAT16,
+                                               MatmulFormat::MK8});
+    for (auto&& algo : matmul_algos) {
+        if (algo->type() == nullptr)
+            continue;
+        for (uint32_t tile_size : {16, 8, 24, 32}) {
             refhold.emplace_back(new AlgoFP16WinogradF23_8x8(
                     static_cast(algo),
                     tile_size));
             winograd_algos.emplace_back(refhold.back().get());
+        }
+    }
 #endif
+    matmul_algos = static_cast(matmul_opr)
+                           ->select_algo_type({AlgoDataType::INT16X16X32,
+                                               MatmulFormat::MK8});
+    for (auto&& algo : matmul_algos) {
+        if (algo->type() == nullptr)
+            continue;
+        for (uint32_t tile_size : {16, 8, 24, 32}) {
             refhold.emplace_back(new AlgoS8WinogradF23_8x8(
                     static_cast(algo),
                     tile_size));
             winograd_algos.emplace_back(refhold.back().get());
-            refhold.emplace_back(new AlgoS8CF32WinogradF23_4x4_NCHW44(
-                    static_cast(algo),
-                    tile_size));
-            winograd_algos.emplace_back(refhold.back().get());
             refhold.emplace_back(new AlgoS8WinogradF23_8x8_NCHW44(
                     static_cast(algo),
                     tile_size));
@@ -240,6 +281,42 @@ bool ConvBiasImpl::is_matmul_quantized_prefer(
     return conv_direct_unusable;
 }
 
+SmallVector ConvBiasImpl::suggest_algo_category_order(
+        const NCBKernSizeParam& param) const {
+    auto IC = param.filter_meta.icpg;
+    auto OC = param.filter_meta.ocpg;
+    auto FH = param.filter_meta.spatial[0];
+    auto FW = param.filter_meta.spatial[1];
+    //! TODO: for now winograd only supports fast-run
+    if (param.filter_meta.format == param::ConvBias::Format::NCHW_WINOGRAD ||
+        param.filter_meta.format == param::ConvBias::Format::NCHW44_WINOGRAD ||
+        param.filter_meta.format == param::ConvBias::Format::NCHW88_WINOGRAD) {
+        return {AlgoCategory::WINOGRAD};
+    }
+    //! im2col
+    bool im2col_prefer = (IC >= 32 || OC >= 32);
+    //! quantized algos use matmul when the direct algo is unusable
+    if (param.src_type.category() == DTypeCategory::QUANTIZED) {
+        im2col_prefer = is_matmul_quantized_prefer(param);
+    }
+    //! conv1x1
+    im2col_prefer |= (FH == 1 && FW == 1);
+    //! nchw44 and nchw44-dot hybrid mode is direct
+    if (param.filter_meta.format == param::ConvBias::Format::NCHW44 ||
+        param.filter_meta.format == param::ConvBias::Format::NCHW44_DOT) {
+        if (IC < 4) {
+            im2col_prefer = false;
+        }
+    }
+    if (im2col_prefer) {
+        return {AlgoCategory::IM2COL, AlgoCategory::DIRECT,
+                AlgoCategory::NAIVE};
+    } else {
+        return {AlgoCategory::DIRECT, AlgoCategory::IM2COL,
+                AlgoCategory::NAIVE};
+    }
+}
+
 const char* ConvBiasImpl::get_algorithm_set_name() const {
     // arm common version 0
     return "AC0";

diff --git a/dnn/src/arm_common/conv_bias/opr_impl.h b/dnn/src/arm_common/conv_bias/opr_impl.h
index 0176a9af..80fe1c3e 100644
--- a/dnn/src/arm_common/conv_bias/opr_impl.h
+++ b/dnn/src/arm_common/conv_bias/opr_impl.h
@@ -28,6 +28,9 @@ public:
     bool is_matmul_quantized_prefer(
             const ConvBiasImpl::NCBKernSizeParam& ncb_param) const override;
+
+    SmallVector suggest_algo_category_order(
+            const NCBKernSizeParam& param) const override;
 
     class AlgoPack;
 
 protected:
@@ -90,7 +93,7 @@ private:
     class AlgoF16Direct;
     class AlgoF16DirectStride1;
 #endif
-};
+ };
 
 } // namespace arm_common
 } // namespace megdnn

diff --git a/dnn/src/arm_common/conv_bias/quint8/algos.h b/dnn/src/arm_common/conv_bias/quint8/algos.h
index 2f0de9a2..bda2412f 100644
--- a/dnn/src/arm_common/conv_bias/quint8/algos.h
+++ b/dnn/src/arm_common/conv_bias/quint8/algos.h
@@ -29,6 +29,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase {
@@ -42,6 +45,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT};
+    }
 };
 #if __ARM_FEATURE_DOTPROD
 class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase {
@@ -56,6 +62,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT};
+    }
 };
 
 class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase {
@@ -69,6 +78,9 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     virtual SmallVector dispatch_kerns(
             const NCBKernSizeParam& param) const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT};
+    }
 };
 #endif
 } // namespace arm_common

diff --git a/dnn/src/arm_common/matrix_mul/algos.h b/dnn/src/arm_common/matrix_mul/algos.h
index 5a7f665d..ef85bc73 100644
--- a/dnn/src/arm_common/matrix_mul/algos.h
+++ b/dnn/src/arm_common/matrix_mul/algos.h
@@ -26,7 +26,7 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     void* type() const override { return sm_arm_common_algo_type; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::INT8X8X16, DEFAULT)
 };
 
 class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase {
@@ -40,7 +40,7 @@ public:
     void* type() const override { return sm_arm_common_algo_type; }
     AlgoSet algoset() const override { return
AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } - MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) + MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT) }; class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase { @@ -54,7 +54,7 @@ public: void* type() const override { return sm_arm_common_algo_type; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } - MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) + MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4) }; #if __ARM_FEATURE_DOTPROD @@ -69,7 +69,7 @@ public: void* type() const override { return sm_arm_common_algo_type; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } - MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) + MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4_DOT) }; #endif @@ -87,7 +87,7 @@ public: void* type() const override { return sm_arm_common_algo_type; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } - MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) + MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) }; class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase { @@ -101,7 +101,7 @@ public: void* type() const override { return sm_arm_common_algo_type; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } - MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4) + MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4) }; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -116,7 +116,7 @@ public: void* type() const override { return sm_arm_common_algo_type; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } - MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) + MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::FLOAT16, DEFAULT) }; #endif @@ -131,7 +131,13 @@ public: void* type() const override { return sm_arm_common_algo_type; } AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } PackMode packmode() const override { return PackMode::NO_PACK; } - MEGDNN_OVERRIDE_MATMUL_DESC(1, 1, 1, 4) + MEGDNN_OVERRIDE_MATMUL_DESC( + 1, 1, 1, 4, + static_cast( + static_cast(AlgoDataType::FLOAT16) | + static_cast(AlgoDataType::FLOAT32) | + static_cast(AlgoDataType::QINT8X8X32)), + DEFAULT) }; } // namespace arm_common diff --git a/dnn/src/arm_common/matrix_mul/opr_impl.cpp b/dnn/src/arm_common/matrix_mul/opr_impl.cpp index 128da56f..66d93084 100644 --- a/dnn/src/arm_common/matrix_mul/opr_impl.cpp +++ b/dnn/src/arm_common/matrix_mul/opr_impl.cpp @@ -25,7 +25,7 @@ void* const MatrixMulImpl::sm_arm_common_algo_type = class MatrixMulImpl::AlgoPack : NonCopyableObj { AlgoInt8x8x16 int8x8x16; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC - AlgoF16Gemv f16gemv; + AlgoF16Gemv f16gemv; #endif AlgoInt8x8x32Gemv int8x8x32_gemv; AlgoInt8x8x32GemvMK4 int8x8x32_gemv_mk4; @@ -34,10 +34,11 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { #endif AlgoGevm gevm; AlgoF32GemvMK4 f32_gemv_mk4; + public: AlgoPack() { all_algos.emplace_back(&int8x8x16); -#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC +#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC all_algos.emplace_back(&f16gemv); #endif #if __ARM_FEATURE_DOTPROD @@ -47,7 +48,7 @@ public: 
all_algos.emplace_back(&int8x8x32_gemv_mk4); all_algos.emplace_back(&f32_gemv_mk4); all_algos.emplace_back(&gevm); - } + } SmallVector all_algos; }; diff --git a/dnn/src/armv7/conv_bias/int8/algos.h b/dnn/src/armv7/conv_bias/int8/algos.h index 8dc948bb..748e92d1 100644 --- a/dnn/src/armv7/conv_bias/int8/algos.h +++ b/dnn/src/armv7/conv_bias/int8/algos.h @@ -37,6 +37,9 @@ public: size_t group = param.filter_meta.group; return {{kimpl, {group, 1_z, 1_z}}}; } + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; + } }; } // namespace armv7 diff --git a/dnn/src/armv7/conv_bias/quint8/algos.h b/dnn/src/armv7/conv_bias/quint8/algos.h index b3e2c131..cd6ed708 100644 --- a/dnn/src/armv7/conv_bias/quint8/algos.h +++ b/dnn/src/armv7/conv_bias/quint8/algos.h @@ -38,6 +38,10 @@ public: size_t group = param.filter_meta.group; return {{kimpl, {group, 1_z, 1_z}}}; } + + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL}; + } }; } // namespace armv7 diff --git a/dnn/src/armv7/matrix_mul/algos.cpp b/dnn/src/armv7/matrix_mul/algos.cpp index 6add7ee5..20bf00ae 100644 --- a/dnn/src/armv7/matrix_mul/algos.cpp +++ b/dnn/src/armv7/matrix_mul/algos.cpp @@ -85,7 +85,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32::get_kern( MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32, megdnn_armv7_matmul_kern, "AlgoF32Impl"_hash, - armv7::matmul::sgemm_4x12, float, float); + armv7::matmul::sgemm_4x12, float, float, + AlgoDataType::FLOAT32, DEFAULT); /* ===================== F32 algo mk4 K4x12 ===================== */ @@ -154,7 +155,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4Pack4x12, megdnn_armv7_matmul_kern, "AlgoF32MK4Pack4x12"_hash, armv7::matmul::sgemm_mk4_pack_4x12, float, - float); + float, AlgoDataType::FLOAT32, MK4); #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC /* ===================== F16 K4x16x1 algo ===================== */ @@ -215,7 +216,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K4x16x1::get_kern( MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF16K4x16x1, megdnn_armv7_matmul_kern, "AlgoF16K4x16x1"_hash, armv7::matmul::hgemm_4x16, dt_float16, - dt_float16); + dt_float16, AlgoDataType::FLOAT16, + DEFAULT); #endif @@ -280,7 +282,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x2x16, megdnn_armv7_matmul_kern, "AlgoInt8x8x32K4x2x16"_hash, armv7::matmul::gemm_s8_4x2, int8_t, - int32_t); + int32_t, AlgoDataType::QINT8X8X32, + DEFAULT); /* ===================== Int8x8x32 Kernel 4x8x8 algo ===================== */ namespace { @@ -342,7 +345,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x8x8, megdnn_armv7_matmul_kern, "AlgoInt8x8x32K4x8x8"_hash, armv7::matmul::gemm_s8_4x8, int8_t, - int32_t); + int32_t, AlgoDataType::QINT8X8X32, + DEFAULT); /* ===================== Quint8 Kernel 4x8x8 algo ===================== */ namespace { @@ -402,7 +406,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K4x8x8::get_kern( MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K4x8x8, megdnn_armv7_matmul_kern, "AlgoQuint8K4x8x8"_hash, armv7::matmul::gemm_u8_4x8, uint8_t, - int32_t); + int32_t, AlgoDataType::QUINT8X8X32, + DEFAULT); /* ===================== Int8x8x16 Kernel 2x4x16 algo ===================== */ namespace { @@ -468,7 +473,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x2x16, megdnn_armv7_matmul_kern, "AlgoInt8x8x16K4x2x16"_hash, armv7::matmul::gemm_s8x8x16_4x2, int8_t, - int16_t); + int16_t, AlgoDataType::INT8X8X16, DEFAULT); /* ===================== Int8x8x16 Kernel 4x8x8 algo 
===================== */ namespace { @@ -534,7 +539,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x8x8, megdnn_armv7_matmul_kern, "AlgoInt8x8x16K4x8x8"_hash, armv7::matmul::gemm_s8x8x16_4x8, int8_t, - int16_t); + int16_t, AlgoDataType::INT8X8X16, DEFAULT); /* =================== Int8x8x16 Kernel MK4 8x8x4 algo ===================*/ @@ -602,7 +607,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x16MK4_8x8x4, megdnn_armv7_matmul_kern, "AlgoInt8x8x16MK4_8x8x4"_hash, armv7::matmul::gemm_s8x8x16_mk4_8x8, - int8_t, int16_t, int16_t); + int8_t, int16_t, int16_t, + AlgoDataType::INT8X8X16, MK4); /* ===================== Int16x16x32 Kernel 12x4x1 algo ===================== */ @@ -668,7 +674,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x4x1, megdnn_armv7_matmul_kern, "AlgoInt16x16x32K12x4x1"_hash, armv7::matmul::gemm_s16x16x32_12x4, - int16_t, int32_t); + int16_t, int32_t, + AlgoDataType::INT16X16X32, DEFAULT); #if __ARM_FEATURE_DOTPROD /* ===================== Int8 K6x8x4 algo ===================== */ namespace { @@ -724,7 +731,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K6x8x4, megdnn_armv7_matmul_kern, "AlgoInt8x8x32K6x8x4"_hash, armv7::matmul::gemm_dots8_6x8, int8_t, - int32_t); + int32_t, AlgoDataType::QINT8X8X32, + DEFAULT); /* ===================== Quint8 K4x8x4 algo ===================== */ namespace { void quint8_dot_k4x8x4_kern(const MatrixMulImpl::KernParam& kern_param) { @@ -786,7 +794,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8DotK4x8x4, megdnn_armv7_matmul_kern, "AlgoQuint8DotK4x8x4"_hash, armv7::matmul::gemm_dot_quint8_4x8, - uint8_t, int32_t); + uint8_t, int32_t, + AlgoDataType::QUINT8X8X32, DEFAULT); /* ======================== Int8 MK4 8x4x4 dot algo ======================== */ namespace { @@ -854,7 +863,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x4x4DotProd, megdnn_armv7_matmul_kern, "AlgoInt8x8x32MK4_8x4x4DotProd"_hash, armv7::matmul::gemm_mk4_dots8_8x4, int8_t, - int32_t); + int32_t, AlgoDataType::QINT8X8X32, MK4_DOT); #endif /* ===================== F32 algo K4x8 ===================== */ @@ -1099,6 +1108,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_4x2x16, megdnn_armv7_matmul_kern, "AlgoInt8x8x32MK4_4x2x16"_hash, armv7::matmul::gemm_mk4_s8_4x2, int8_t, - int32_t); + int32_t, AlgoDataType::QINT8X8X32, MK4); // vim: syntax=cpp.doxygen diff --git a/dnn/src/armv7/matrix_mul/algos.h b/dnn/src/armv7/matrix_mul/algos.h index 60e35f1b..481662fd 100644 --- a/dnn/src/armv7/matrix_mul/algos.h +++ b/dnn/src/armv7/matrix_mul/algos.h @@ -50,7 +50,7 @@ public: kern_t get_kern(const KernSizeParam&) const override; void* type() const override { return sm_arm_common_algo_type; } PackMode packmode() const override { return PackMode::NO_PACK; } - MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4) + MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4) }; #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC @@ -73,7 +73,7 @@ public: kern_t get_kern(const KernSizeParam&) const override; void* type() const override { return sm_arm_common_algo_type; } PackMode packmode() const override { return PackMode::NO_PACK; } - MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2) + MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::FLOAT16, MK8) }; #endif #if __ARM_FEATURE_DOTPROD @@ -205,7 +205,7 @@ public: kern_t get_kern(const KernSizeParam&) const override; void* type() const override { return sm_arm_common_algo_type; } PackMode packmode() const override { return PackMode::NO_PACK; } - MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2) + 
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::INT16X16X32, MK8)
 };
 
 class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase {

diff --git a/dnn/src/armv7/matrix_mul/opr_impl.h b/dnn/src/armv7/matrix_mul/opr_impl.h
index 7e573d07..701cef4e 100644
--- a/dnn/src/armv7/matrix_mul/opr_impl.h
+++ b/dnn/src/armv7/matrix_mul/opr_impl.h
@@ -18,7 +18,6 @@ namespace armv7 {
 class MatrixMulImpl : public arm_common::MatrixMulImpl {
 public:
     using arm_common::MatrixMulImpl::MatrixMulImpl;
-    SmallVector algo_pack() override;
 
 private:

diff --git a/dnn/src/common/utils.h b/dnn/src/common/utils.h
index 66807c96..ec04802f 100644
--- a/dnn/src/common/utils.h
+++ b/dnn/src/common/utils.h
@@ -110,6 +110,11 @@ void __log__(LogLevel level, const char* file, const char* func, int line,
     } while (0)
 #endif // megdnn_ENABLE_LOGGING
 
+template
+constexpr int32_t cast_int(T data) {
+    return static_cast(data);
+}
+
 /* helper functions */
 /**
  * \brief Get the next `stride' index lexicographically.
@@ -187,6 +192,29 @@ std::unique_ptr make_unique(Args&&... args) {
     return std::unique_ptr(new T(std::forward(args)...));
 }
 
+/*!
+ * \brief check whether the source enum contains the target data type enum
+ */
+bool inline contain_data_type(detail::AlgoDataType source,
+                              detail::AlgoDataType target) {
+    return static_cast(static_cast(source) &
+                       static_cast(target));
+}
+
+/*!
+ * \brief get the number of data types contained in the source enum
+ */
+template
+size_t nr_type_contain(T index) {
+    uint32_t sr_index = static_cast(index);
+    size_t nr_type = 0;
+    while (sr_index != 0) {
+        nr_type++;
+        sr_index &= (sr_index - 1);
+    }
+    return nr_type;
+}
+
 /**
  * \brief Aligned workspace bundle.
 *

diff --git a/dnn/src/fallback/conv_bias/algos.h b/dnn/src/fallback/conv_bias/algos.h
index c0f95521..e70959bb 100644
--- a/dnn/src/fallback/conv_bias/algos.h
+++ b/dnn/src/fallback/conv_bias/algos.h
@@ -26,6 +26,16 @@ public:
                 AlgoSelectionStrategy algo_selection_strategy) const override;
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     SmallVector dispatch_kerns(const NCBKernSizeParam&) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        auto support_data_type = static_cast(
+                static_cast(AlgoDataType::FLOAT16) |
+                static_cast(AlgoDataType::FLOAT32) |
+                static_cast(AlgoDataType::INT8X8X16) |
+                static_cast(AlgoDataType::QINT8X8X32) |
+                static_cast(AlgoDataType::QUINT8X8X32));
+        return {support_data_type, AlgoCategory::NAIVE};
+    }
 };
 
 class ConvBiasImpl::AlgoWinogradF32 final : public AlgoBase {
@@ -46,6 +56,10 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     SmallVector dispatch_kerns(const NCBKernSizeParam&) const override;
 
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD};
+    }
+
 private:
     MatrixMulImpl::AlgoBase* m_matmul_algo;
     mutable std::string m_name;
@@ -70,6 +84,10 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     SmallVector dispatch_kerns(const NCBKernSizeParam&) const override;
 
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD};
+    }
+
 private:
     MatrixMulImpl::AlgoBase* m_matmul_algo;
     mutable std::string m_name;
@@ -94,6 +112,10 @@ public:
     size_t get_workspace(const NCBKernSizeParam& param) const override;
     SmallVector dispatch_kerns(const NCBKernSizeParam&) const override;
 
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD};
+    }
+
 private:
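The two utils.h helpers above drive all of the filtering: contain_data_type() tests whether two AlgoDataType bit sets intersect, and nr_type_contain() counts set bits so a caller can require that a query names exactly one data type. A standalone demo (enum values copied from the base.h hunk; the helpers mirror utils.h minus the megdnn namespace and the stripped template/cast arguments):

    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    // Copied from the megdnn/oprs/base.h hunk above.
    enum class AlgoDataType : uint32_t {
        FLOAT32 = 1 << 0,
        FLOAT16 = 1 << 1,
        QINT8X8X32 = 1 << 2,
        QUINT8X8X32 = 1 << 3,
        INT8X8X16 = 1 << 4,
        INT16X16X32 = 1 << 5,
    };

    // True when the source and target bit sets intersect.
    inline bool contain_data_type(AlgoDataType source, AlgoDataType target) {
        return static_cast<bool>(static_cast<uint32_t>(source) &
                                 static_cast<uint32_t>(target));
    }

    // Count the set bits (Kernighan's loop): how many data types a
    // multi-type algo such as AlgoNaive advertises.
    template <typename T>
    size_t nr_type_contain(T index) {
        uint32_t sr_index = static_cast<uint32_t>(index);
        size_t nr_type = 0;
        while (sr_index != 0) {
            nr_type++;
            sr_index &= (sr_index - 1);
        }
        return nr_type;
    }

    int main() {
        // Compose a multi-type mask the way the naive/gemv algos in this
        // patch do, by OR-ing the underlying integers.
        auto mask = static_cast<AlgoDataType>(
                static_cast<uint32_t>(AlgoDataType::FLOAT32) |
                static_cast<uint32_t>(AlgoDataType::FLOAT16));
        assert(contain_data_type(mask, AlgoDataType::FLOAT32));
        assert(!contain_data_type(mask, AlgoDataType::INT8X8X16));
        assert(nr_type_contain(mask) == 2);  // exactly two data types set
        return 0;
    }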
MatrixMulImpl::AlgoBase* m_matmul_algo; mutable std::string m_name; @@ -118,6 +140,10 @@ public: size_t get_workspace(const NCBKernSizeParam& param) const override; SmallVector dispatch_kerns(const NCBKernSizeParam&) const override; + ConvAlgoTypePack get_algo_type() const override { + return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD}; + } + private: MatrixMulImpl::AlgoBase* m_matmul_algo; mutable std::string m_name; diff --git a/dnn/src/fallback/conv_bias/common.h b/dnn/src/fallback/conv_bias/common.h index 75cc155c..fb9caf2e 100644 --- a/dnn/src/fallback/conv_bias/common.h +++ b/dnn/src/fallback/conv_bias/common.h @@ -140,7 +140,7 @@ using BiasMode = ConvBiasForward::BiasMode; break; \ } -#define MEGDNN_WINOGRAD_ALGO_FUN_DECLARE() \ +#define MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(_algo_data_type) \ bool is_reproducible() const override { return true; } \ bool usable(const NCBKernSizeParam& param, \ AlgoSelectionStrategy algo_selection_strategy) const override; \ @@ -153,6 +153,9 @@ using BiasMode = ConvBiasForward::BiasMode; const override; \ virtual SmallVector dispatch_preprocess_kerns( \ const NCBKernSizeParam& param) const override; \ + ConvAlgoTypePack get_algo_type() const override { \ + return {_algo_data_type, AlgoCategory::WINOGRAD}; \ + } \ \ private: \ fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; \ diff --git a/dnn/src/fallback/conv_bias/conv1x1/algos.cpp b/dnn/src/fallback/conv_bias/conv1x1/algos.cpp index 0ff7852b..01bbd194 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/algos.cpp +++ b/dnn/src/fallback/conv_bias/conv1x1/algos.cpp @@ -288,7 +288,8 @@ bool ConvBiasImpl::AlgoConv1x1::is_preferred( size_t OH = param.osz[0]; size_t OW = param.osz[1]; if (OH * OW != 1) { - return true; + return m_matmul_algo->algoset() != + MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV; } else { #if (MEGDNN_ARMV7 || MEGDNN_AARCH64) if (param.src_type.enumv() == DTypeEnum::Int8 && diff --git a/dnn/src/fallback/conv_bias/conv1x1/algos.h b/dnn/src/fallback/conv_bias/conv1x1/algos.h index 6c7f5bf0..6c3bc4ef 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/algos.h +++ b/dnn/src/fallback/conv_bias/conv1x1/algos.h @@ -56,6 +56,11 @@ public: SmallVector dispatch_preprocess_kerns( const NCBKernSizeParam& param) const override; + ConvAlgoTypePack get_algo_type() const override{ + return {m_matmul_algo->matmul_description().algo_type.data_type, + AlgoCategory::IM2COL}; + } + protected: size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const; diff --git a/dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h b/dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h index 3f266a60..b56bb138 100644 --- a/dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h +++ b/dnn/src/fallback/conv_bias/conv1x1/algos_conv1x1_gemv.h @@ -34,6 +34,16 @@ public: bool is_preferred(const NCBKernSizeParam&) const override; + ConvAlgoTypePack get_algo_type() const override { + auto support_data_type = static_cast( + static_cast(AlgoDataType::FLOAT16) | + static_cast(AlgoDataType::FLOAT32) | + static_cast(AlgoDataType::INT8X8X16) | + static_cast(AlgoDataType::QINT8X8X32) | + static_cast(AlgoDataType::QUINT8X8X32)); + return {support_data_type, AlgoCategory::IM2COL}; + } + protected: size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const; }; diff --git a/dnn/src/fallback/conv_bias/im2col/algos.h b/dnn/src/fallback/conv_bias/im2col/algos.h index b699f571..919ae250 100644 --- a/dnn/src/fallback/conv_bias/im2col/algos.h +++ b/dnn/src/fallback/conv_bias/im2col/algos.h @@ -48,15 +48,25 @@ 
public:
     SmallVector dispatch_preprocess_kerns(
             const NCBKernSizeParam& param) const override;
 
     bool is_preferred(const NCBKernSizeParam& param) const override {
-        if (param.src_type.category() == DTypeCategory::QUANTIZED) {
-            static CpuOprDelegationStorage<1> storage;
-            auto conv_bias_opr = storage.get();
-            return static_cast(conv_bias_opr)
-                    ->is_matmul_quantized_prefer(param);
+        size_t OH = param.osz[0];
+        size_t OW = param.osz[1];
+        //! gemm with oh * ow > 1 is preferred
+        //! gemv with oh * ow == 1 is preferred
+        if ((m_matmul_algo->algoset() !=
+                     MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV &&
+             OH * OW > 1) ||
+            (m_matmul_algo->algoset() ==
+                     MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV &&
+             OH * OW == 1)) {
+            return true;
+        } else {
+            return false;
         }
-        auto&& fm = param.filter_meta;
-        auto OC = fm.ocpg, IC = fm.icpg;
-        return OC >= 32 || IC >= 32;
+    }
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {m_matmul_algo->matmul_description().algo_type.data_type,
+                AlgoCategory::IM2COL};
     }
 
 private:

diff --git a/dnn/src/fallback/conv_bias/opr_impl.cpp b/dnn/src/fallback/conv_bias/opr_impl.cpp
index 45b09bc4..e8be6c08 100644
--- a/dnn/src/fallback/conv_bias/opr_impl.cpp
+++ b/dnn/src/fallback/conv_bias/opr_impl.cpp
@@ -48,11 +48,26 @@ void incr_ptr(T*& dst, ptrdiff_t delta) {
 
 } // namespace
 
+#if MEGDNN_X86
+#define SKIP_GEMV()
+//! As we don't have direct conv for int8x8x16 yet, if we disable gemv here, it may
+//! fall back to the naive implementation, which may cause very low performance, so
+//! here we just enable im2col for gemv in the x86 backend.
+//! FIXME: remove it when we add direct conv support for int8x8x16
+#else
+#define SKIP_GEMV()                                                            \
+    if (algo->algoset() == MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { \
+        continue;                                                              \
+    }
+#endif
+
+
 class ConvBiasImpl::AlgoPack : NonCopyableObj {
     AlgoNaive algo_naive;
     SmallVector> refhold;
 
 public:
+
     AlgoPack() {
         refhold.emplace_back(new AlgoConv1x1Gemv());
         all_algos.emplace_back(refhold.back().get());
@@ -110,8 +125,6 @@ public:
             all_algos.emplace_back(refhold.back().get());
 #endif
         }
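The rewritten is_preferred() above states one rule for matmul-backed conv algos: a GEMM-backed algo is preferred when the output is spatial (OH * OW > 1), a GEMV-backed one when there is a single output pixel (OH * OW == 1). The same predicate, distilled into a testable sketch (function name is illustrative):

    #include <cassert>
    #include <cstddef>

    // is_gemv corresponds to algoset() == AlgoSet::ALGO_TYPE_GEMV in the patch.
    bool matmul_algo_is_preferred(bool is_gemv, size_t OH, size_t OW) {
        // gemm with oh * ow > 1 is preferred; gemv with oh * ow == 1 is preferred.
        return is_gemv ? (OH * OW == 1) : (OH * OW > 1);
    }

    int main() {
        assert(matmul_algo_is_preferred(false, 7, 7));   // gemm on spatial output
        assert(!matmul_algo_is_preferred(true, 7, 7));   // gemv loses there
        assert(matmul_algo_is_preferred(true, 1, 1));    // gemv on 1x1 output
        assert(!matmul_algo_is_preferred(false, 1, 1));  // gemm loses there
        return 0;
    }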
reverse matmul algo, when the algo is_prefer can be selected first - std::reverse(all_algos.begin(), all_algos.end()); all_algos.emplace_back(&algo_naive); } SmallVector all_algos; @@ -121,6 +134,22 @@ SmallVector ConvBiasImpl::algo_pack() { static AlgoPack sl_algo_pack; return sl_algo_pack.all_algos; } + +SmallVector ConvBiasImpl::select_algo_type( + ConvAlgoTypePack target_type) { + megdnn_assert(nr_type_contain(target_type.data_type), + "ConvBias algo selection only support one type"); + SmallVector algos; + for (auto&& algo : algo_pack()) { + auto algo_type = algo->get_algo_type(); + if (contain_data_type(algo_type.data_type, target_type.data_type) && + algo_type.algo_category == target_type.algo_category) { + algos.push_back(algo); + } + } + return algos; +} + bool ConvBiasImpl::is_naive_algo(ConvBiasImpl::Algorithm* algo) { return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0; } @@ -248,12 +277,32 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic( ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb( const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, bool reproducible) { - for (auto i : get_all_algorithms_with_ncb(param)) { - if (static_cast(i)->usable_reproducible( - param, AlgoSelectionStrategy::HEURISTIC, reproducible) && - NCB_ALGO_FUNC(get_workspace, i, param) <= - workspace_limit_in_bytes) { - return i; + auto algo_data_type = param.deduce_algo_data_type(); + auto suggest_category_order = suggest_algo_category_order(param); + for (auto category : suggest_category_order) { + auto&& origin_algos = select_algo_type({algo_data_type, category}); + ConvBiasImpl::Algorithm* heuristic_algo = nullptr; + for (auto i : origin_algos) { + bool usable_reproducible = + static_cast(i)->usable_reproducible( + param, AlgoSelectionStrategy::HEURISTIC, + reproducible); + if (usable_reproducible && + static_cast(i)->get_workspace(param) <= + workspace_limit_in_bytes) { + //! store the first usable algo if no prefer algo, choose it as + //! the target algo + if (!heuristic_algo) { + heuristic_algo = i; + } + //! 
@@ -248,12 +277,32 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic(
 ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb(
         const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
         bool reproducible) {
-    for (auto i : get_all_algorithms_with_ncb(param)) {
-        if (static_cast<AlgoBase*>(i)->usable_reproducible(
-                    param, AlgoSelectionStrategy::HEURISTIC, reproducible) &&
-            NCB_ALGO_FUNC(get_workspace, i, param) <=
-                    workspace_limit_in_bytes) {
-            return i;
+    auto algo_data_type = param.deduce_algo_data_type();
+    auto suggest_category_order = suggest_algo_category_order(param);
+    for (auto category : suggest_category_order) {
+        auto&& origin_algos = select_algo_type({algo_data_type, category});
+        ConvBiasImpl::Algorithm* heuristic_algo = nullptr;
+        for (auto i : origin_algos) {
+            bool usable_reproducible =
+                    static_cast<AlgoBase*>(i)->usable_reproducible(
+                            param, AlgoSelectionStrategy::HEURISTIC,
+                            reproducible);
+            if (usable_reproducible &&
+                static_cast<AlgoBase*>(i)->get_workspace(param) <=
+                        workspace_limit_in_bytes) {
+                //! remember the first usable algo; it is the fallback if no
+                //! preferred algo is found in this category
+                if (!heuristic_algo) {
+                    heuristic_algo = i;
+                }
+                //! return the first preferred algo directly
+                if (i->is_preferred(param)) {
+                    return i;
+                }
+            }
+        }
+        if (heuristic_algo) {
+            return heuristic_algo;
         }
     }
     return nullptr;
@@ -300,9 +349,8 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param(
                           sizeof(ConvolutionImpl::CanonizedFilterMeta),
                   "sizeof CanonizedFilterMeta in convolution and conv_bias "
                   "should be equal");
-    CanonizedFilterMeta fm = check_layout_fwd(src, filter, dst);
-    ConvolutionImpl::CanonizedFilterMeta conv_fm;
-    conv_fm.copy_from(fm);
+    auto&& fm = check_layout_fwd(src, filter, dst);
+    auto& conv_fm = reinterpret_cast<ConvolutionImpl::CanonizedFilterMeta&>(fm);

     param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT;
     if (param().format == Param::Format::NCHW_WINOGRAD ||
@@ -367,7 +415,7 @@ ConvBiasImpl::NCBKernParam ConvBiasImpl::make_ncb_kern_param(

 void ConvBiasImpl::exec_with_ncb_kern(const NCBKernParam& param,
                                       ConvBiasImpl::Algorithm* algo) {
-    auto ncb_kerns = NCB_ALGO_FUNC(dispatch_kerns, algo, param);
+    auto&& ncb_kerns = NCB_ALGO_FUNC(dispatch_kerns, algo, param);
     for (auto&& kernel : ncb_kerns) {
         auto run = [kernel, param](size_t index, size_t thread_id) {
             CpuNDRange ndrange_id(kernel.global_size, index);
@@ -380,7 +428,7 @@ void ConvBiasImpl::exec_with_ncb_kern(const NCBKernParam& param,

 void ConvBiasImpl::exec_preprocess_with_ncb_kern(
         const NCBKernParam& param, ConvBiasImpl::Algorithm* algo) {
-    auto ncb_kerns = NCB_ALGO_FUNC(dispatch_preprocess_kerns, algo, param);
+    auto&& ncb_kerns = NCB_ALGO_FUNC(dispatch_preprocess_kerns, algo, param);
     for (auto&& kernel : ncb_kerns) {
         auto run = [kernel, param](size_t index, size_t thread_id) {
             CpuNDRange ndrange_id(kernel.global_size, index);
@@ -405,7 +453,6 @@ std::vector ConvBiasImpl::get_all_algorithms_with_ncb(
             }
         }
     }
-    std::reverse(prefer_algos.begin(), prefer_algos.end());
     //! Prefer algo inserted from begin
     algos.insert(algos.begin(), prefer_algos.begin(), prefer_algos.end());
     return algos;
@@ -425,6 +472,35 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm(
     return m_prev_selected_algo;
 }

+SmallVector ConvBiasImpl::suggest_algo_category_order(
+        const NCBKernSizeParam& param) const {
+    auto IC = param.filter_meta.icpg;
+    auto OC = param.filter_meta.ocpg;
+    auto FH = param.filter_meta.spatial[0];
+    auto FW = param.filter_meta.spatial[1];
+    //! TODO: winograd is currently only selected via fast-run
+    if (param.filter_meta.format == param::ConvBias::Format::NCHW_WINOGRAD ||
+        param.filter_meta.format == param::ConvBias::Format::NCHW44_WINOGRAD ||
+        param.filter_meta.format == param::ConvBias::Format::NCHW88_WINOGRAD) {
+        return {AlgoCategory::WINOGRAD};
+    }
+    //! im2col + matmul
+    bool im2col_prefer = (IC >= 32 || OC >= 32);
+    //! quantized algos use matmul when the direct algo is unusable
+    if (param.src_type.category() == DTypeCategory::QUANTIZED) {
+        im2col_prefer = is_matmul_quantized_prefer(param);
+    }
+    //! conv1x1
+    im2col_prefer |= (FH == 1 && FW == 1);
+    if (im2col_prefer) {
+        return {AlgoCategory::IM2COL, AlgoCategory::DIRECT,
+                AlgoCategory::NAIVE};
+    } else {
+        return {AlgoCategory::DIRECT, AlgoCategory::IM2COL,
+                AlgoCategory::NAIVE};
+    }
+}
+
 const char* ConvBiasImpl::get_algorithm_set_name() const {
     // fallback version 0
     return "F0";
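To make the new category ordering concrete, here is a toy, self-contained
walkthrough of the non-quantized branch of suggest_algo_category_order()
(the quantized case defers to is_matmul_quantized_prefer() and is omitted;
mock types, not megdnn code):

    #include <cstdio>

    enum class AlgoCategory { DIRECT, IM2COL, WINOGRAD, NAIVE };

    //! mirrors `im2col_prefer = (IC >= 32 || OC >= 32)` plus the conv1x1 rule
    static AlgoCategory first_category(unsigned IC, unsigned OC, unsigned FH,
                                       unsigned FW) {
        bool im2col_prefer = (IC >= 32 || OC >= 32);
        im2col_prefer |= (FH == 1 && FW == 1);
        return im2col_prefer ? AlgoCategory::IM2COL : AlgoCategory::DIRECT;
    }

    int main() {
        // 3x3 conv, few channels: direct kernels are tried first.
        std::printf("%d\n", first_category(8, 16, 3, 3) == AlgoCategory::DIRECT);
        // 1x1 conv: im2col/matmul is tried first even with few channels.
        std::printf("%d\n", first_category(8, 16, 1, 1) == AlgoCategory::IM2COL);
        // Wide conv (OC = 64): im2col/matmul is tried first.
        std::printf("%d\n", first_category(3, 64, 3, 3) == AlgoCategory::IM2COL);
    }

Within the chosen category, the heuristic loop above still picks the first
preferred algo, falling back to the first merely usable one.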
diff --git a/dnn/src/fallback/conv_bias/opr_impl.h b/dnn/src/fallback/conv_bias/opr_impl.h
index 082bb2e5..86717452 100644
--- a/dnn/src/fallback/conv_bias/opr_impl.h
+++ b/dnn/src/fallback/conv_bias/opr_impl.h
@@ -18,6 +18,8 @@
 #include "src/fallback/matrix_mul/opr_impl.h"
 #include "src/naive/conv_bias/opr_impl.h"

+#include
+
 namespace megdnn {
 namespace fallback {

@@ -44,6 +46,7 @@ class ConvBiasImpl : public naive::ConvBiasForwardImpl {
 public:
     using naive::ConvBiasForwardImpl::ConvBiasForwardImpl;
     using AlgoSelectionStrategy = detail::AlgoSelectionStrategy;
+    using AlgoDataType = detail::AlgoDataType;

     //! implemented by exec_with_ncb_kern()
     void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
@@ -94,6 +97,8 @@ public:
                                   size_t workspace_limit_in_bytes,
                                   bool reproducible) override;
+
+
     //! size param for kernels with non-contiguous batch
     struct NCBKernSizeParam : ConvolutionImpl::NCBKernSizeParam {
         NCBKernSizeParam() = default;
@@ -244,6 +249,9 @@ public:
             return (!reproducible || is_reproducible()) &&
                    usable(param, algo_selection_strategy);
         }
+
+        //! get the type of the algo
+        virtual ConvAlgoTypePack get_algo_type() const = 0;
     };

     /**
      */
     virtual SmallVector algo_pack();

+    /**
+     * \brief select algos according to the input algo type
+     */
+    SmallVector select_algo_type(ConvAlgoTypePack algo_type);
+
+    /**
+     * \brief suggest the order of algo categories according to the param
+     */
+    virtual SmallVector suggest_algo_category_order(
+            const NCBKernSizeParam& param) const;
+
 protected:
     virtual void exec_with_ncb_kern(const NCBKernParam& param,
                                     ConvBiasImpl::Algorithm* algo);
diff --git a/dnn/src/fallback/convolution/algos.h b/dnn/src/fallback/convolution/algos.h
index 591feb0f..f8f73fdd 100644
--- a/dnn/src/fallback/convolution/algos.h
+++ b/dnn/src/fallback/convolution/algos.h
@@ -83,6 +83,10 @@ public:

     SmallVector dispatch_kern(
             const NCBKernSizeParam& /*param*/) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::NAIVE};
+    }
 };

 class ConvolutionImpl::AlgoNaive final : public AlgoBase {
@@ -96,11 +100,17 @@ public:

     SmallVector dispatch_kern(
             const NCBKernSizeParam& /*param*/) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        auto support_data_type = static_cast<AlgoDataType>(
+                static_cast<uint32_t>(AlgoDataType::INT8X8X16) |
+                static_cast<uint32_t>(AlgoDataType::QINT8X8X32) |
+                static_cast<uint32_t>(AlgoDataType::QUINT8X8X32));
+        return {support_data_type, AlgoCategory::NAIVE};
+    }
 };

 class ConvolutionImpl::AlgoDefault final : public AlgoBase {
-    static ConvBiasImpl::NCBKernSizeParam init_conv_bias_param(
-            const NCBKernSizeParam& param);
     WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const;
     static SmallVector get_kimpl(ConvBiasImpl::AlgoBase* algo,
                                  const NCBKernSizeParam& param);
@@ -136,6 +146,13 @@ public:
     //! select matmul to the highest preference
     bool is_preferred(const NCBKernSizeParam& param) const override;

+    static ConvBiasImpl::NCBKernSizeParam init_conv_bias_param(
+            const NCBKernSizeParam& param);
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return m_algorithm->get_algo_type();
+    }
+
 private:
     std::string m_name;
     ConvBiasImpl::AlgoBase* m_algorithm;
diff --git a/dnn/src/fallback/convolution/opr_impl.cpp b/dnn/src/fallback/convolution/opr_impl.cpp
index 8836fc74..21adb166 100644
--- a/dnn/src/fallback/convolution/opr_impl.cpp
+++ b/dnn/src/fallback/convolution/opr_impl.cpp
@@ -23,6 +23,7 @@
 #include "midout.h"

 #include
+#include

 MIDOUT_DECL(megdnn_fb_convbwd_float)

@@ -75,6 +76,22 @@ SmallVector ConvolutionImpl::algo_pack() {
     static AlgoPack sl_algo_pack;
     return sl_algo_pack.all_algos;
 }
+
+SmallVector ConvolutionImpl::select_algo_type(
+        ConvAlgoTypePack target_type) {
+    megdnn_assert(nr_type_contain(target_type.data_type),
+                  "Convolution algo selection only supports one type");
+    SmallVector algos;
+    for (auto&& algo : algo_pack()) {
+        auto algo_type = algo->get_algo_type();
+        if (contain_data_type(algo_type.data_type, target_type.data_type) &&
+            algo_type.algo_category == target_type.algo_category) {
+            algos.push_back(algo);
+        }
+    }
+    return algos;
+}
+
 bool ConvolutionImpl::is_naive_algo(ConvolutionImpl::Algorithm* algo) {
     return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0;
 }
@@ -249,9 +266,9 @@ ConvolutionImpl::NCBKernParam ConvolutionImpl::make_ncb_kern_param(

 void ConvolutionImpl::exec_preprocess_with_ncb_kern(const NCBKernParam& param,
                                                     Algorithm* algo) {
-    auto kerns = NCB_ALGO_FUNC(dispatch_preprocess_kern, algo, param);
-    auto fallback_handle = handle();
-    for (auto kernel : kerns) {
+    auto&& kerns = NCB_ALGO_FUNC(dispatch_preprocess_kern, algo, param);
+    auto&& fallback_handle = handle();
+    for (auto&& kernel : kerns) {
         megdnn_assert(
                 param.filter_meta.format == Param::Format::NCHW ||
                         param.filter_meta.format == Param::Format::NHWC ||
@@ -270,9 +287,9 @@ void ConvolutionImpl::exec_preprocess_with_ncb_kern(const NCBKernParam& param,

 void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param,
                                          Algorithm* algo) {
-    auto kerns = NCB_ALGO_FUNC(dispatch_kern, algo, param);
-    auto fallback_handle = handle();
-    for (auto kernel : kerns) {
+    auto&& kerns = NCB_ALGO_FUNC(dispatch_kern, algo, param);
+    auto&& fallback_handle = handle();
+    for (auto&& kernel : kerns) {
         megdnn_assert(
                 param.filter_meta.format == Param::Format::NCHW ||
                         param.filter_meta.format == Param::Format::NHWC ||
@@ -292,13 +309,32 @@ void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param,
 ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb(
         const NCBKernSizeParam& param, size_t workspace_limit_in_bytes,
         bool reproducible) {
-    for (auto i : get_all_algorithms_with_ncb(param)) {
-        bool usable_reproducible =
-                static_cast<AlgoBase*>(i)->usable_reproducible(
-                        param, AlgoSelectionStrategy::HEURISTIC, reproducible);
-        if (usable_reproducible && NCB_ALGO_FUNC(get_workspace, i, param) <=
-                                           workspace_limit_in_bytes) {
-            return i;
+    auto algo_data_type = param.deduce_algo_data_type();
+    auto suggest_category_order = suggest_algo_category_order(param);
+    for (auto category : suggest_category_order) {
+        auto&& origin_algos = select_algo_type({algo_data_type, category});
+        ConvolutionImpl::Algorithm* heuristic_algo = nullptr;
+        for (auto i : origin_algos) {
+            bool usable_reproducible =
+                    static_cast<AlgoBase*>(i)->usable_reproducible(
+                            param, AlgoSelectionStrategy::HEURISTIC,
+                            reproducible);
+            if (usable_reproducible &&
+                static_cast<AlgoBase*>(i)->get_workspace(param) <=
+                        workspace_limit_in_bytes) {
+                //! remember the first usable algo; it is the fallback if no
+                //! preferred algo is found in this category
+                if (!heuristic_algo) {
+                    heuristic_algo = i;
+                }
+                //! return the first preferred algo directly
+                if (i->is_preferred(param)) {
+                    return i;
+                }
+            }
+        }
+        if (heuristic_algo) {
+            return heuristic_algo;
         }
     }
     return nullptr;
@@ -317,8 +353,6 @@ ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) {
             }
         }
     }
-    std::reverse(prefer_algos.begin(), prefer_algos.end());
-    //! Prefer algo inserted from begin
     ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end());
     return ret;
 }
@@ -337,11 +371,45 @@ ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm(
     return m_prev_selected_algo;
 }

+SmallVector ConvolutionImpl::suggest_algo_category_order(
+        const NCBKernSizeParam& param) const {
+    static CpuOprDelegationStorage<1> storage;
+    auto conv_bias_opr = storage.get();
+    auto conv_bias_param =
+            ConvolutionImpl::AlgoDefault::init_conv_bias_param(param);
+    return static_cast<ConvBiasImpl*>(conv_bias_opr)
+            ->suggest_algo_category_order(conv_bias_param);
+}
+
 const char* ConvolutionImpl::get_algorithm_set_name() const {
     // fallback version 0
     return "F0";
 }

+ConvolutionImpl::AlgoDataType
+ConvolutionImpl::NCBKernSizeParam::deduce_algo_data_type() const {
+    if (src_type.enumv() == DTypeEnum::Float32) {
+        return ConvolutionImpl::AlgoDataType::FLOAT32;
+#if !MEGDNN_DISABLE_FLOAT16
+    } else if (src_type.enumv() == DTypeEnum::Float16) {
+        return ConvolutionImpl::AlgoDataType::FLOAT16;
+#endif
+    } else if (src_type.enumv() == DTypeEnum::Int8 ||
+               src_type.enumv() == DTypeEnum::QuantizedS8) {
+        if (dst_type.enumv() == DTypeEnum::Int16) {
+            return ConvolutionImpl::AlgoDataType::INT8X8X16;
+        } else {
+            return ConvolutionImpl::AlgoDataType::QINT8X8X32;
+        }
+    } else if (src_type.enumv() == DTypeEnum::Quantized8Asymm) {
+        return ConvolutionImpl::AlgoDataType::QUINT8X8X32;
+    } else {
+        megdnn_throw(ssprintf(
+                "megdnn does not support data type of %s * %s -> %s\n",
+                src_type.name(), filter_type.name(), dst_type.name()));
+    }
+}
+
 /* ===================== ConvolutionBackwardData ===================== */

 void* const ConvolutionBackwardDataImpl::sm_fallback_deconv_algo_type =
diff --git a/dnn/src/fallback/convolution/opr_impl.h b/dnn/src/fallback/convolution/opr_impl.h
index f2dbf198..62843d5a 100644
--- a/dnn/src/fallback/convolution/opr_impl.h
+++ b/dnn/src/fallback/convolution/opr_impl.h
@@ -10,11 +10,28 @@
  */
 #pragma once

+#include "megdnn/oprs/base.h"
 #include "src/common/utils.h"
 #include "src/fallback/handle.h"
 #include "src/naive/convolution/opr_impl.h"

 namespace megdnn {
+
+/**
+ * \brief Convolution algo category
+ */
+enum class AlgoCategory : int32_t {
+    DIRECT = 0,
+    IM2COL = 1,
+    WINOGRAD = 2,
+    NAIVE = 3,
+};
+
+struct ConvAlgoTypePack {
+    detail::AlgoDataType data_type : 32;
+    AlgoCategory algo_category : 32;
+};
+
 namespace fallback {

 /*!
@@ -33,6 +50,7 @@ class ConvolutionImpl : public naive::ConvolutionForwardImpl {
 public:
     using naive::ConvolutionForwardImpl::ConvolutionForwardImpl;
     using AlgoSelectionStrategy = detail::AlgoSelectionStrategy;
+    using AlgoDataType = detail::AlgoDataType;

     //! implemented by exec_with_ncb_kern()
     void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
@@ -86,6 +104,8 @@ public:
         size_t nr_threads;
         //! weight_preprocess info
         const PreprocessedFilter* preprocessed_filter;
+        //! get the data type category of the param, used for algo selection
+        AlgoDataType deduce_algo_data_type() const;
     };

     //! memory param for kernels with non-contiguous batch
@@ -211,6 +231,9 @@ public:
             return (!reproducible || is_reproducible()) &&
                    usable(param, algo_selection_strategy);
         }
+
+        //! get the type of the algo
+        virtual ConvAlgoTypePack get_algo_type() const = 0;
     };

     /**
      */
     virtual SmallVector algo_pack();

+    /**
+     * \brief select algos according to the input algo type
+     */
+    SmallVector select_algo_type(ConvAlgoTypePack algo_type);
+
 protected:
     virtual void exec_with_ncb_kern(const NCBKernParam& param, Algorithm* algo);

@@ -258,6 +286,9 @@ private:
                            _megdnn_tensor_out dst,
                            const PreprocessedFilter* preprocessed_filter,
                            _megdnn_workspace workspace);
+
+    SmallVector suggest_algo_category_order(
+            const NCBKernSizeParam& param) const;
 };

 class ConvolutionBackwardDataImpl : public naive::ConvolutionBackwardDataImpl {
diff --git a/dnn/src/fallback/matrix_mul/algos.cpp b/dnn/src/fallback/matrix_mul/algos.cpp
index 375c47b7..d3b78292 100644
--- a/dnn/src/fallback/matrix_mul/algos.cpp
+++ b/dnn/src/fallback/matrix_mul/algos.cpp
@@ -76,7 +76,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern(

 MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_fb_matmul_f32_kern,
                                      5, matmul::fallback::sgemm_8x12, float,
-                                     float);
+                                     float, AlgoDataType::FLOAT32, DEFAULT);

 /* ===================== gemv algo ===================== */
 bool MatrixMulImpl::AlgoGemv::usable(
diff --git a/dnn/src/fallback/matrix_mul/algos.h b/dnn/src/fallback/matrix_mul/algos.h
index 670f116a..8abb64bc 100644
--- a/dnn/src/fallback/matrix_mul/algos.h
+++ b/dnn/src/fallback/matrix_mul/algos.h
@@ -37,7 +37,15 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
+    MEGDNN_OVERRIDE_MATMUL_DESC(
+            8, 16, 1, 4,
+            static_cast<AlgoDataType>(
+                    static_cast<uint32_t>(AlgoDataType::FLOAT16) |
+                    static_cast<uint32_t>(AlgoDataType::FLOAT32) |
+                    static_cast<uint32_t>(AlgoDataType::INT8X8X16) |
+                    static_cast<uint32_t>(AlgoDataType::QINT8X8X32) |
+                    static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)),
+            DEFAULT)
 };

 } // namespace fallback
diff --git a/dnn/src/fallback/matrix_mul/gemm_common.h b/dnn/src/fallback/matrix_mul/gemm_common.h
index ab63d374..c45b3f89 100644
--- a/dnn/src/fallback/matrix_mul/gemm_common.h
+++ b/dnn/src/fallback/matrix_mul/gemm_common.h
@@ -352,13 +352,15 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
                DType dtype_c)                                             \
             : A_dtype(dtype_a), B_dtype(dtype_b), C_dtype(dtype_c) {}

-#define MEGDNN_OVERRIDE_MATMUL_DESC(_m, _n, _k, _packa_type_size)         \
-    MatmulDescription matmul_description() const override {               \
-        MatmulDescription mdesc;                                          \
-        mdesc.packmode = packmode();                                      \
-        mdesc.innerblocksize = {_m, _n, _k};                              \
-        mdesc.packa_type_size = _packa_type_size;                         \
-        return mdesc;                                                     \
+#define MEGDNN_OVERRIDE_MATMUL_DESC(_m, _n, _k, _packa_type_size, _data_type, \
+                                    _format)                                  \
+    MatmulDescription matmul_description() const override {                   \
+        MatmulDescription mdesc;                                              \
+        mdesc.packmode = packmode();                                          \
+        mdesc.innerblocksize = {_m, _n, _k};                                  \
+        mdesc.packa_type_size = _packa_type_size;                             \
+        mdesc.algo_type = {_data_type, Param::Format::_format};               \
+        return mdesc;                                                         \
     }
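With the two new macro parameters, a declaration such as
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT)
now expands (per the macro body above, modulo whitespace) to:

    MatmulDescription matmul_description() const override {
        MatmulDescription mdesc;
        mdesc.packmode = packmode();
        mdesc.innerblocksize = {8, 16, 1};
        mdesc.packa_type_size = 4;
        mdesc.algo_type = {AlgoDataType::FLOAT32, Param::Format::DEFAULT};
        return mdesc;
    }

so every matmul algo self-describes its supported data types and format,
which is exactly what select_algo_type() matches against.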

 #define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL()                                     \
@@ -373,7 +375,7 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K,
 #define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(                          \
         _algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type,    \
-        _packa_type)                                                          \
+        _packa_type, _support_data_type, _format)                             \
                                                                               \
     MatrixMulImpl::kern_naked_t MatrixMulImpl::_algo_name::get_kern_naked(    \
             const KernSizeParam&) const {                                     \
@@ -474,14 +476,16 @@
         mdesc.innerblocksize = {_strategy::KERNEL_H, _strategy::KERNEL_W,     \
                                 _strategy::UNROLL_K};                         \
         mdesc.packa_type_size = sizeof(_packa_type);                          \
+        mdesc.algo_type = {_support_data_type, Param::Format::_format};       \
         return mdesc;                                                         \
     }

-#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(                                 \
-        _algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type)    \
-    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(_algo_name, _midout_name,     \
-                                                _mid_index, _strategy,        \
-                                                _i_type, _c_type, _i_type)
+#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(                                  \
+        _algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type,     \
+        _support_data_type, _format)                                           \
+    MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(                               \
+            _algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \
+            _i_type, _support_data_type, _format)

 }  // namespace matmul
 }  // namespace megdnn
diff --git a/dnn/src/fallback/matrix_mul/opr_impl.cpp b/dnn/src/fallback/matrix_mul/opr_impl.cpp
index 9da3fdaf..6ba8c0b6 100644
--- a/dnn/src/fallback/matrix_mul/opr_impl.cpp
+++ b/dnn/src/fallback/matrix_mul/opr_impl.cpp
@@ -38,6 +38,22 @@ SmallVector MatrixMulImpl::algo_pack() {
     return s_algo_pack.all_algos;
 }

+SmallVector MatrixMulImpl::select_algo_type(
+        AlgoTypePack index) {
+    megdnn_assert(nr_type_contain(index.data_type),
+                  "Matmul algo selection only supports one type");
+    SmallVector algos;
+    for (auto&& algo : algo_pack()) {
+        auto algo_desc = algo->matmul_description();
+        if (contain_data_type(algo_desc.algo_type.data_type,
+                              index.data_type) &&
+            algo_desc.algo_type.format == index.format) {
+            algos.push_back(algo);
+        }
+    }
+    return algos;
+}
+
 std::vector MatrixMulImpl::get_all_algorithms(
         const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) {
     std::vector gemm_algos, gemv_algos;
@@ -71,17 +87,25 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic(
                 "require reproducible algorithm, but given algorithm is not "
                 "reproducible");
     }
-
-    auto algos = get_all_algorithms(A, B, C);
+    AlgoTypePack algo_type;
+    algo_type.data_type = kern_size_param.deduce_algo_data_type();
+    algo_type.format = kern_size_param.format;
+    auto algos = select_algo_type(algo_type);
+    Algorithm* heuristic_algo = nullptr;
     for (auto&& algo : algos) {
-        if (static_cast<AlgoBase*>(algo)->preferred_reproducible(
+        if (static_cast<AlgoBase*>(algo)->usable(kern_size_param) &&
+            static_cast<AlgoBase*>(algo)->preferred_reproducible(
                     kern_size_param, reproducible) &&
            static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param) <=
                    workspace_limit_in_bytes) {
-            return algo;
+            if (algo->algoset() == AlgoBase::AlgoSet::ALGO_TYPE_GEMV) {
+                return algo;
+            } else if (!heuristic_algo) {
+                heuristic_algo = algo;
+            }
         }
     }
-    return nullptr;
+    return heuristic_algo;
 }

 MatrixMulImpl::KernSizeParam MatrixMulImpl::make_kern_size_param(
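The gemv-first policy of the heuristic above can be summarized with a small
self-contained mock (hypothetical names, not megdnn code): a usable gemv algo
short-circuits the scan, otherwise the first usable gemm algo within the
workspace limit wins.

    #include <cstdio>
    #include <vector>

    enum class AlgoSet { ALGO_TYPE_GEMM, ALGO_TYPE_GEMV };
    struct MockAlgo {
        const char* name;
        AlgoSet algoset;
        bool usable;
    };

    static const MockAlgo* pick(const std::vector<MockAlgo>& algos) {
        const MockAlgo* heuristic_algo = nullptr;
        for (auto&& algo : algos) {
            if (!algo.usable)
                continue;
            if (algo.algoset == AlgoSet::ALGO_TYPE_GEMV)
                return &algo;  // gemv short-circuits
            if (!heuristic_algo)
                heuristic_algo = &algo;  // remember the first usable gemm
        }
        return heuristic_algo;
    }

    int main() {
        std::vector<MockAlgo> algos{
                {"gemm_8x12", AlgoSet::ALGO_TYPE_GEMM, true},
                {"gemv", AlgoSet::ALGO_TYPE_GEMV, true}};
        std::printf("%s\n", pick(algos)->name);  // "gemv", although listed last
    }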
@@ -150,4 +174,34 @@ void MatrixMulImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B,
     naive::MatrixMulForwardImpl::exec(A, B, C, workspace);
 }

+MatrixMulImpl::AlgoDataType
+MatrixMulImpl::KernSizeParam::deduce_algo_data_type() const {
+    megdnn_assert(A_type.enumv() == B_type.enumv(),
+                  "Matmul A and B must have the same dtype\n");
+    if (A_type.enumv() == DTypeEnum::Float32) {
+        return MatrixMulImpl::AlgoDataType::FLOAT32;
+#if !MEGDNN_DISABLE_FLOAT16
+    } else if (A_type.enumv() == DTypeEnum::Float16) {
+        return MatrixMulImpl::AlgoDataType::FLOAT16;
+#endif
+    } else if (A_type.enumv() == DTypeEnum::Int8 ||
+               A_type.enumv() == DTypeEnum::QuantizedS8) {
+        if (C_type.enumv() == DTypeEnum::Int16) {
+            return MatrixMulImpl::AlgoDataType::INT8X8X16;
+        } else {
+            megdnn_assert(C_type.enumv() == DTypeEnum::Int32 ||
+                          C_type.enumv() == DTypeEnum::QuantizedS32);
+            return MatrixMulImpl::AlgoDataType::QINT8X8X32;
+        }
+    } else if (A_type.enumv() == DTypeEnum::Quantized8Asymm) {
+        return MatrixMulImpl::AlgoDataType::QUINT8X8X32;
+    } else if (A_type.enumv() == DTypeEnum::Int16) {
+        return MatrixMulImpl::AlgoDataType::INT16X16X32;
+    } else {
+        megdnn_throw(ssprintf(
+                "megdnn matmul does not support data type of %s * %s -> %s\n",
+                A_type.name(), B_type.name(), C_type.name()));
+    }
+}
+
 // vim: syntax=cpp.doxygen
diff --git a/dnn/src/fallback/matrix_mul/opr_impl.h b/dnn/src/fallback/matrix_mul/opr_impl.h
index 3184f2c6..fd4089dc 100644
--- a/dnn/src/fallback/matrix_mul/opr_impl.h
+++ b/dnn/src/fallback/matrix_mul/opr_impl.h
@@ -10,14 +10,23 @@
  * implied.
  */
 #pragma once
+#include "megdnn/opr_param_defs.h"
 #include "src/common/utils.h"
 #include "src/naive/matrix_mul/opr_impl.h"

+#include
+
 namespace megdnn {

-namespace fallback {
+struct AlgoTypePack {
+    detail::AlgoDataType data_type : 32;
+    param::MatrixMul::Format format : 32;
+};
+
+namespace fallback {
 class MatrixMulImpl : public naive::MatrixMulForwardImpl {
 public:
     using naive::MatrixMulForwardImpl::MatrixMulForwardImpl;
+    using AlgoDataType = detail::AlgoDataType;

     bool is_thread_safe() const override { return true; }

@@ -34,6 +43,8 @@ public:
         bool trA, trB;
         Param::ComputeMode compute_mode;
         Param::Format format;
+        //! get the data type category of the param, used for algo selection
+        AlgoDataType deduce_algo_data_type() const;
     };

     struct KernParam : public KernSizeParam {
@@ -110,6 +121,7 @@ public:
     struct MatmulDescription {
         PackMode packmode;
         InnerBlockSize innerblocksize;
+        AlgoTypePack algo_type;
         size_t packa_type_size;
     };

@@ -146,6 +158,11 @@ public:
      */
     virtual SmallVector algo_pack();

+    /**
+     * \brief select algos according to the input algo type
+     */
+    SmallVector select_algo_type(AlgoTypePack algo_type);
+
 protected:
     KernSizeParam make_kern_size_param(const TensorLayout& A,
                                        const TensorLayout& B,
diff --git a/dnn/src/x86/conv_bias/f32/algos.h b/dnn/src/x86/conv_bias/f32/algos.h
index 66c751d7..b3565cec 100644
--- a/dnn/src/x86/conv_bias/f32/algos.h
+++ b/dnn/src/x86/conv_bias/f32/algos.h
@@ -48,6 +48,10 @@ public:
     }

     void* type() const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
+    }
 };

 /* ===================== direct-stride2 algo ===================== */
@@ -81,6 +85,10 @@ public:
     }

     void* type() const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
+    }
 };

 /* =========================== winograd ======================== */
 class ConvBiasImpl::AlgoFP32WinogradF63_8x8 final : public AlgoBase {
@@ -96,7 +104,7 @@ public:
         return m_name.c_str();
     }
     void* type() const override;
-    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
+    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
 };

 class ConvBiasImpl::AlgoFP32WinogradF23_8x8 final : public AlgoBase {
@@ -112,7 +120,7 @@ public:
         return m_name.c_str();
     }
     void* type() const override;
-    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE();
+    MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32);
 };

 /* ===================== matmul algo ===================== */
     }

     void* type() const override;
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::IM2COL};
+    }
 };

 #if MEGDNN_X86_WITH_MKL_DNN
@@ -192,6 +203,10 @@ public:
         return {{kern, {1_z, 1_z, 1_z}}};
     }
     void* type() const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT};
+    }
 };
 #endif

 // vim: syntax=cpp.doxygen
diff --git a/dnn/src/x86/conv_bias/int8/algo_usable_preferred.cpp b/dnn/src/x86/conv_bias/int8/algo_usable_preferred.cpp
index 65dc72e1..ac7f565a 100644
--- a/dnn/src/x86/conv_bias/int8/algo_usable_preferred.cpp
+++ b/dnn/src/x86/conv_bias/int8/algo_usable_preferred.cpp
@@ -224,8 +224,6 @@ bool mkldnn_matmul_qint8_preferred(
         const ConvBiasImpl::NCBKernSizeParam& param) {
     auto is_preferred = true;
     auto&& fm = param.filter_meta;
-    megdnn_assert_internal(fm.group == 1 && fm.dilation[0] == 1 &&
-                           fm.dilation[1] == 1);

     // single channel conv should never use matrix mul
     if (fm.ocpg == 1 || fm.icpg == 1)
diff --git a/dnn/src/x86/conv_bias/int8/algos.h b/dnn/src/x86/conv_bias/int8/algos.h
index 34717a42..e62dd8ff 100644
--- a/dnn/src/x86/conv_bias/int8/algos.h
+++ b/dnn/src/x86/conv_bias/int8/algos.h
@@ -34,6 +34,10 @@ public:
     }
     void* type() const override;
     bool is_preferred(const NCBKernSizeParam& param) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
+    }
 };

 /* ===================== avx2 stride2 chanwise algo ===================== */
@@ -55,6 +59,10 @@ public:
     }
     void* type() const override;
     bool is_preferred(const NCBKernSizeParam& param) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
+    }
 };

 /* ===================== avx2 stride1 direct algo ===================== */
@@ -76,6 +84,10 @@ public:
     }
     void* type() const override;
     bool is_preferred(const NCBKernSizeParam& param) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
+    }
 };

 /* ================== avx2 int8 direct conv stride2 algo ================== */
@@ -97,6 +109,10 @@ public:
     }
     void* type() const override;
     bool is_preferred(const NCBKernSizeParam& param) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
+    }
 };

 #if MEGDNN_X86_WITH_MKL_DNN
@@ -134,6 +150,10 @@ public:
     }
     void* type() const override;
     bool is_preferred(const NCBKernSizeParam& param) const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT};
+    }
 };

 /* ===================== mkldnn qint8 matmul algo ===================== */
 class ConvBiasImpl::AlgoMkldnnMatmulQint8 final : public AlgoBase {
@@ -160,6 +180,10 @@ public:
     bool is_preferred(const NCBKernSizeParam& param) const override;

     void* type() const override;
+
+    ConvAlgoTypePack get_algo_type() const override {
+        return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL};
+    }
 };
 #endif
diff --git a/dnn/src/x86/conv_bias/opr_impl.cpp b/dnn/src/x86/conv_bias/opr_impl.cpp
index 0de20e30..b4759ea1 100644
--- a/dnn/src/x86/conv_bias/opr_impl.cpp
+++ b/dnn/src/x86/conv_bias/opr_impl.cpp
@@ -103,10 +103,10 @@ public:
 #endif
         all_algos.emplace_back(&stride1_direct);
         all_algos.emplace_back(&stride2_direct);
-        all_algos.emplace_back(&avx2_stride1_direct_int8);
-        all_algos.emplace_back(&avx2_stride2_direct);
         all_algos.emplace_back(&avx2_stride1_chanwsie_qint8);
         all_algos.emplace_back(&avx2_stride2_chanwsie_qint8);
+        all_algos.emplace_back(&avx2_stride1_direct_int8);
+        all_algos.emplace_back(&avx2_stride2_direct);
         all_algos.emplace_back(&matmul);

         static CpuOprDelegationStorage<> storage;
@@ -182,4 +182,41 @@ bool ConvBiasImpl::is_matmul_quantized_prefer(
             !chanwise_avx2_stride2_qint8_usable_preferred(param));
 }

+SmallVector
+ConvBiasImpl::suggest_algo_category_order(const NCBKernSizeParam& param) const {
+    auto IC = param.filter_meta.icpg;
+    auto OC = param.filter_meta.ocpg;
+    auto FH = param.filter_meta.spatial[0];
+    auto FW = param.filter_meta.spatial[1];
+    //! TODO: winograd is currently only selected via fast-run
+    if (param.filter_meta.format == param::ConvBias::Format::NCHW_WINOGRAD ||
+        param.filter_meta.format == param::ConvBias::Format::NCHW44_WINOGRAD ||
+        param.filter_meta.format == param::ConvBias::Format::NCHW88_WINOGRAD) {
+        return {AlgoCategory::WINOGRAD};
+    }
+    //! nchw88 uses mkl-dnn, whose algos are direct
+    if (param.filter_meta.format == param::ConvBias::Format::NCHW88) {
+        return {AlgoCategory::DIRECT, AlgoCategory::IM2COL};
+    }
+    //! im2col + matmul
+    bool im2col_prefer = (IC >= 32 || OC >= 32);
+    //! quantized algos use matmul when the direct algo is unusable
+    if (param.src_type.category() == DTypeCategory::QUANTIZED) {
+        im2col_prefer = is_matmul_quantized_prefer(param);
+    }
+    //! conv1x1
+    im2col_prefer |= (FH == 1 && FW == 1);
+    //! x86 8x8x16 is not optimized yet, so it falls back to im2col + matmul
+    if (param.deduce_algo_data_type() == AlgoDataType::INT8X8X16) {
+        im2col_prefer = true;
+    }
+    if (im2col_prefer) {
+        return {AlgoCategory::IM2COL, AlgoCategory::DIRECT,
+                AlgoCategory::NAIVE};
+    } else {
+        return {AlgoCategory::DIRECT, AlgoCategory::IM2COL,
+                AlgoCategory::NAIVE};
+    }
+}
+
 // vim: syntax=cpp.doxygen
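A toy check (mock types, not megdnn code) of the x86-specific rule above:
int8x8x16 always routes to the im2col category first, even for shapes where
float32 would pick direct.

    #include <cstdio>

    enum class AlgoDataType { FLOAT32, INT8X8X16 };
    enum class AlgoCategory { DIRECT, IM2COL };

    static AlgoCategory x86_first_category(unsigned IC, unsigned OC,
                                           unsigned FH, unsigned FW,
                                           AlgoDataType dt) {
        bool im2col_prefer = (IC >= 32 || OC >= 32) || (FH == 1 && FW == 1);
        if (dt == AlgoDataType::INT8X8X16)
            im2col_prefer = true;  // no optimized x86 kernels besides matmul
        return im2col_prefer ? AlgoCategory::IM2COL : AlgoCategory::DIRECT;
    }

    int main() {
        // A small 3x3 float32 conv starts with the direct category...
        std::printf("%d\n",
                    x86_first_category(8, 8, 3, 3, AlgoDataType::FLOAT32) ==
                            AlgoCategory::DIRECT);
        // ...but the same shape in int8x8x16 still goes im2col + matmul.
        std::printf("%d\n",
                    x86_first_category(8, 8, 3, 3, AlgoDataType::INT8X8X16) ==
                            AlgoCategory::IM2COL);
    }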
diff --git a/dnn/src/x86/conv_bias/opr_impl.h b/dnn/src/x86/conv_bias/opr_impl.h
index 204cf38a..38dca3d3 100644
--- a/dnn/src/x86/conv_bias/opr_impl.h
+++ b/dnn/src/x86/conv_bias/opr_impl.h
@@ -24,6 +24,8 @@ public:
     bool is_thread_safe() const override { return true; }

     SmallVector algo_pack() override;
+    SmallVector suggest_algo_category_order(
+            const NCBKernSizeParam& param) const override;

     class AlgoDirect;
     class AlgoDirectStride2;
diff --git a/dnn/src/x86/matrix_mul/algos.cpp b/dnn/src/x86/matrix_mul/algos.cpp
index e07c1ee2..2562f641 100644
--- a/dnn/src/x86/matrix_mul/algos.cpp
+++ b/dnn/src/x86/matrix_mul/algos.cpp
@@ -184,11 +184,10 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Vnni::get_kern(
     return int8x8x32_kern_vnni;
 }

-MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32Vnni,
-                                            megdnn_x86_matmul_kern,
-                                            "AlgoInt8x8x32Vnni"_hash,
-                                            x86::matmul::gemm_int8_vnni_12x32x4,
-                                            dt_int8, dt_int32, dt_uint8);
+MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
+        AlgoInt8x8x32Vnni, megdnn_x86_matmul_kern, "AlgoInt8x8x32Vnni"_hash,
+        x86::matmul::gemm_int8_vnni_12x32x4, dt_int8, dt_int32, dt_uint8,
+        AlgoDataType::QINT8X8X32, DEFAULT);
 #endif

 /* ===================== Int8 mkldnn algo ===================== */
@@ -397,7 +396,8 @@ size_t MatrixMulImpl::AlgoInt8x8x16AVX2::get_workspace(
 }
 MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
         AlgoInt8x8x16AVX2, megdnn_x86_matmul_kern, "AlgoInt8x8x16AVX2"_hash,
-        x86::matmul::gemm_avx2_s8s8s16_4x16x2, dt_int8, dt_int16, dt_int16);
+        x86::matmul::gemm_avx2_s8s8s16_4x16x2, dt_int8, dt_int16, dt_int16,
+        AlgoDataType::INT8X8X16, DEFAULT);

 /*************************AlgoInt8x8x16SSE********************/
 void MatrixMulImpl::AlgoInt8x8x16SSE::gemm_s8s8s16_sse_4x8x2(
@@ -474,7 +474,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x16SSE,
                                             megdnn_x86_matmul_kern,
                                             "AlgoInt8x8x16SSE"_hash,
                                             x86::matmul::gemm_sse_s8s8s16_4x8x2,
-                                            dt_int8, dt_int16, dt_int16);
+                                            dt_int8, dt_int16, dt_int16,
+                                            AlgoDataType::INT8X8X16, DEFAULT);

 /*************************AlgoInt8x8x32AVX2M4N16K2********************/
 MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_kern(
@@ -516,7 +517,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_workspace(
 MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(
         AlgoInt8x8x32AVX2M4N16K2, megdnn_x86_matmul_kern,
         "AlgoInt8x8x32AVX2M4N16K2"_hash, x86::matmul::gemm_avx2_s8s8s32_4x16x2,
-        dt_int8, dt_int32, dt_int16);
+        dt_int8, dt_int32, dt_int16, AlgoDataType::QINT8X8X32, DEFAULT);

 MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_kern(
         const KernSizeParam&) const {
@@ -556,7 +557,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32AVX2M2N4K16,
                                      megdnn_x86_matmul_kern,
                                      "AlgoInt8x8x32AVX2M2N4K16"_hash,
                                      x86::matmul::gemm_avx2_s8s8s32_2x4x16,
-                                     dt_int8, dt_int32);
+                                     dt_int8, dt_int32,
+                                     AlgoDataType::QINT8X8X32, DEFAULT);

 /*************************AlgoInt8x8x32SSEM4N8K2********************/
 MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_kern(
@@ -596,7 +598,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32SSEM4N8K2,
                                             megdnn_x86_matmul_kern,
                                             "AlgoInt8x8x32SSEM4N8K2"_hash,
                                             x86::matmul::gemm_sse_s8s8s32_4x8x2,
-                                            dt_int8, dt_int32, dt_int16);
+                                            dt_int8, dt_int32, dt_int16,
+                                            AlgoDataType::QINT8X8X32, DEFAULT);

 /*************************AlgoF32MK8_8x8********************/
 MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK8_8x8::get_kern(
diff --git a/dnn/src/x86/matrix_mul/algos.h b/dnn/src/x86/matrix_mul/algos.h
index 6fd5d7e9..d93c9b5e 100644
--- a/dnn/src/x86/matrix_mul/algos.h
+++ b/dnn/src/x86/matrix_mul/algos.h
@@ -27,7 +27,7 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     void* type() const override { return sm_x86_algo_type; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT)
 };

 #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM
@@ -49,7 +49,7 @@ public:
     WorkspaceBundle get_bundle(const KernSizeParam& param) const override;
     InnerBlockSize get_inner_block_size() const override{ return {8, 16, 1}; };
-    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT)
 };
 #endif

@@ -127,7 +127,7 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     void* type() const override { return sm_x86_algo_type; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 4)
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 4, AlgoDataType::FLOAT32, MK8)
 };

 #if MEGDNN_X86_WITH_VNNI
@@ -153,7 +153,7 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     void* type() const override { return sm_x86_algo_type; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
-    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT)
 };
 #endif

 }  // namespace x86
diff --git a/src/opr/impl/dnn/convolution.cpp b/src/opr/impl/dnn/convolution.cpp
index e3469707..c04d09ba 100644
--- a/src/opr/impl/dnn/convolution.cpp
+++ b/src/opr/impl/dnn/convolution.cpp
@@ -495,8 +495,9 @@ class AlgoChooser {
                 }
             }
             mgb_assert(found,
-                       "algo got by heuristic not found in "
-                       "candidate list");
+                       "algo %s got by heuristic not found in "
+                       "candidate list",
+                       heu->name());
             return std::move(ret);
         }

@@ -628,7 +629,7 @@ public:
         auto algo = get_algo(ctx);
         size_t workspace = ctx.get_workspace_size_bytes(algo);
         mgb_log_debug(
-                "%s: input shapes (%s %s, %s %s) -> (%s %s): algo=%s "
+                "%s: tensor layouts (%s %s, %s %s) -> (%s %s): algo=%s "
                 "workspace=%.2fMiB reproducible=%d",
                 mgb_opr->dyn_typeinfo()->name,
                 layouts[0].to_string().c_str(), layouts[0].dtype.name(),
                 layouts[1].to_string().c_str(), layouts[1].dtype.name(),
                 layouts[layouts.size() - 1].to_string().c_str(),
-                layouts[layouts.size() - 1].dtype.name(),
-                algo->name(),
+                layouts[layouts.size() - 1].dtype.name(), algo->name(),
                 workspace / (1024 * 1024.0), algo->is_reproducible());
         megdnn_opr->execution_policy() = {algo};
         return workspace;