GitOrigin-RevId: 60d2646bb3
release-1.1
@@ -76,6 +76,18 @@ enum class AlgoSelectionStrategy { | |||||
FULL_RUN = 2, | FULL_RUN = 2, | ||||
}; | }; | ||||
/** | |||||
* \brief separate algo by datatype for Matmul and conv | |||||
*/ | |||||
enum class AlgoDataType : uint32_t { | |||||
FLOAT32 = 1 << 0, | |||||
FLOAT16 = 1 << 1, | |||||
QINT8X8X32 = 1 << 2, | |||||
QUINT8X8X32 = 1 << 3, | |||||
INT8X8X16 = 1 << 4, | |||||
INT16X16X32 = 1 << 5, | |||||
}; | |||||
/*! | /*! | ||||
* \brief Abstract representation of an algorithm for implementing | * \brief Abstract representation of an algorithm for implementing | ||||
* the operator | * the operator | ||||
@@ -27,6 +27,10 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override; | SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
} // namespace aarch64 | } // namespace aarch64 | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -32,6 +32,10 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override; | SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
} // namespace aarch64 | } // namespace aarch64 | ||||
@@ -45,6 +45,9 @@ public: | |||||
return static_cast<ConvBiasImpl*>(conv_bias_opr) | return static_cast<ConvBiasImpl*>(conv_bias_opr) | ||||
->is_matmul_quantized_prefer(param); | ->is_matmul_quantized_prefer(param); | ||||
} | } | ||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; | |||||
} | |||||
}; | }; | ||||
} // namespace aarch64 | } // namespace aarch64 | ||||
@@ -50,10 +50,9 @@ SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() { | |||||
auto&& algos = arm_common::ConvBiasImpl::algo_pack(); | auto&& algos = arm_common::ConvBiasImpl::algo_pack(); | ||||
algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(), | algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(), | ||||
sl_algo_pack.direct_algos.end()); | sl_algo_pack.direct_algos.end()); | ||||
//! We put matmul algos at the end. Because matmul will get privilege when | |||||
//! We put matmul algos at the begin. Because matmul will get privilege when | |||||
//! prefer return true. See | //! prefer return true. See | ||||
//! fallback::ConvolutionImpl::ncb_1g_get_all_algorithms for more details. | |||||
algos.insert(algos.end(), sl_algo_pack.matmul_algos.begin(), | |||||
algos.insert(algos.begin(), sl_algo_pack.matmul_algos.begin(), | |||||
sl_algo_pack.matmul_algos.end()); | sl_algo_pack.matmul_algos.end()); | ||||
return std::move(algos); | return std::move(algos); | ||||
} | } | ||||
@@ -45,6 +45,9 @@ public: | |||||
return static_cast<ConvBiasImpl*>(conv_bias_opr) | return static_cast<ConvBiasImpl*>(conv_bias_opr) | ||||
->is_matmul_quantized_prefer(param); | ->is_matmul_quantized_prefer(param); | ||||
} | } | ||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL}; | |||||
} | |||||
}; | }; | ||||
} // namespace aarch64 | } // namespace aarch64 | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -89,7 +89,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern( | |||||
} | } | ||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_aarch64_matmul_kern, | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_aarch64_matmul_kern, | ||||
"AlgoF32K8x12x1Impl"_hash, | "AlgoF32K8x12x1Impl"_hash, | ||||
aarch64::matmul::sgemm_8x12, float, float); | |||||
aarch64::matmul::sgemm_8x12, float, float, | |||||
AlgoDataType::FLOAT32, DEFAULT); | |||||
/* ===================== F32_MK4_8X12X1 algo ===================== */ | /* ===================== F32_MK4_8X12X1 algo ===================== */ | ||||
bool MatrixMulImpl::AlgoF32MK4_8x12x1::usable( | bool MatrixMulImpl::AlgoF32MK4_8x12x1::usable( | ||||
@@ -151,7 +152,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4_8x12x1, | |||||
megdnn_aarch64_matmul_kern, | megdnn_aarch64_matmul_kern, | ||||
"AlgoF32MK4_8x12x1Impl"_hash, | "AlgoF32MK4_8x12x1Impl"_hash, | ||||
aarch64::matmul::sgemm_mk4_8x12, float, | aarch64::matmul::sgemm_mk4_8x12, float, | ||||
float); | |||||
float, AlgoDataType::FLOAT32, MK4); | |||||
/* ===================== F32K4X16X1 algo ===================== */ | /* ===================== F32K4X16X1 algo ===================== */ | ||||
@@ -210,7 +211,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K4x16x1::get_kern( | |||||
} | } | ||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K4x16x1, megdnn_aarch64_matmul_kern, | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K4x16x1, megdnn_aarch64_matmul_kern, | ||||
"AlgoF32K4x16x1Impl"_hash, | "AlgoF32K4x16x1Impl"_hash, | ||||
aarch64::matmul::sgemm_4x16, float, float); | |||||
aarch64::matmul::sgemm_4x16, float, float, | |||||
AlgoDataType::FLOAT32, MK4); | |||||
/* ===================== F32MK4_4x16 algo ===================== */ | /* ===================== F32MK4_4x16 algo ===================== */ | ||||
@@ -328,7 +330,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K8x24x1::get_kern( | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF16K8x24x1, megdnn_aarch64_matmul_kern, | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF16K8x24x1, megdnn_aarch64_matmul_kern, | ||||
"AlogF16K8x24x1Impl"_hash, | "AlogF16K8x24x1Impl"_hash, | ||||
aarch64::matmul::hgemm_8x24, dt_float16, | aarch64::matmul::hgemm_8x24, dt_float16, | ||||
dt_float16); | |||||
dt_float16, AlgoDataType::FLOAT16, | |||||
DEFAULT); | |||||
/* ===================== F16_MK8_8x8 algo ===================== */ | /* ===================== F16_MK8_8x8 algo ===================== */ | ||||
bool MatrixMulImpl::AlgoF16MK8_8x8::usable( | bool MatrixMulImpl::AlgoF16MK8_8x8::usable( | ||||
@@ -449,7 +452,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x12x4DotProd, | |||||
megdnn_aarch64_matmul_kern, | megdnn_aarch64_matmul_kern, | ||||
"AlgoInt8x8x32K8x12x4DotProdImpl"_hash, | "AlgoInt8x8x32K8x12x4DotProdImpl"_hash, | ||||
aarch64::matmul::gemm_s8_8x12, int8_t, | aarch64::matmul::gemm_s8_8x12, int8_t, | ||||
int32_t); | |||||
int32_t, AlgoDataType::QINT8X8X32, | |||||
DEFAULT); | |||||
/* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */ | /* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */ | ||||
namespace { | namespace { | ||||
@@ -520,7 +524,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x12x4DotProd, | |||||
megdnn_aarch64_matmul_kern, | megdnn_aarch64_matmul_kern, | ||||
"AlgoInt8x8x32MK4_8x12x4DotProdImpl"_hash, | "AlgoInt8x8x32MK4_8x12x4DotProdImpl"_hash, | ||||
aarch64::matmul::gemm_mk4_s8_8x12, int8_t, | aarch64::matmul::gemm_mk4_s8_8x12, int8_t, | ||||
int32_t); | |||||
int32_t, AlgoDataType::QINT8X8X32, | |||||
MK4_DOT); | |||||
#else | #else | ||||
/* ===================== Int8x8x32 MK4 4x4x16 algo ===================== */ | /* ===================== Int8x8x32 MK4 4x4x16 algo ===================== */ | ||||
@@ -593,7 +598,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_4x4x16, | |||||
megdnn_aarch64_matmul_kern, | megdnn_aarch64_matmul_kern, | ||||
"AlgoInt8x8x32MK4_4x4x16Impl"_hash, | "AlgoInt8x8x32MK4_4x4x16Impl"_hash, | ||||
aarch64::matmul::gemm_mk4_s8_4x4, int8_t, | aarch64::matmul::gemm_mk4_s8_4x4, int8_t, | ||||
int32_t); | |||||
int32_t, AlgoDataType::QINT8X8X32, | |||||
MK4); | |||||
/* ===================== Int8x8x32 K4x4x16 algo ===================== */ | /* ===================== Int8x8x32 K4x4x16 algo ===================== */ | ||||
namespace { | namespace { | ||||
@@ -656,7 +662,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x4x16, | |||||
megdnn_aarch64_matmul_kern, | megdnn_aarch64_matmul_kern, | ||||
"AlgoInt8x8x32K4x4x16Impl"_hash, | "AlgoInt8x8x32K4x4x16Impl"_hash, | ||||
aarch64::matmul::gemm_s8_4x4, int8_t, | aarch64::matmul::gemm_s8_4x4, int8_t, | ||||
int32_t); | |||||
int32_t, AlgoDataType::QINT8X8X32, | |||||
DEFAULT); | |||||
/* ===================== Int8x8x32 K8x8x8 algo ===================== */ | /* ===================== Int8x8x32 K8x8x8 algo ===================== */ | ||||
namespace { | namespace { | ||||
void int8x8x32_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) { | void int8x8x32_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) { | ||||
@@ -717,7 +724,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x8x8, | |||||
megdnn_aarch64_matmul_kern, | megdnn_aarch64_matmul_kern, | ||||
"AlgoInt8x8x32K8x8x8Impl"_hash, | "AlgoInt8x8x32K8x8x8Impl"_hash, | ||||
aarch64::matmul::gemm_s8_8x8, int8_t, | aarch64::matmul::gemm_s8_8x8, int8_t, | ||||
int32_t); | |||||
int32_t, AlgoDataType::QINT8X8X32, | |||||
DEFAULT); | |||||
#endif | #endif | ||||
/* ===================== Int8x8x16 K8x8x8 algo ===================== */ | /* ===================== Int8x8x16 K8x8x8 algo ===================== */ | ||||
@@ -785,7 +793,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K8x8x8, | |||||
megdnn_aarch64_matmul_kern, | megdnn_aarch64_matmul_kern, | ||||
"AlgoInt8x8x16K8x8x8Impl"_hash, | "AlgoInt8x8x16K8x8x8Impl"_hash, | ||||
aarch64::matmul::gemm_s8x8x16_8x8, int8_t, | aarch64::matmul::gemm_s8x8x16_8x8, int8_t, | ||||
int16_t); | |||||
int16_t, AlgoDataType::INT8X8X16, DEFAULT); | |||||
/* ===================== Int8x8x16 K4x4x16 algo ===================== */ | /* ===================== Int8x8x16 K4x4x16 algo ===================== */ | ||||
namespace { | namespace { | ||||
void int8x8x16_k4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) { | void int8x8x16_k4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) { | ||||
@@ -852,7 +860,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x4x16, | |||||
megdnn_aarch64_matmul_kern, | megdnn_aarch64_matmul_kern, | ||||
"AlgoInt8x8x16K4x4x16Impl"_hash, | "AlgoInt8x8x16K4x4x16Impl"_hash, | ||||
aarch64::matmul::gemm_s8x8x16_4x4, int8_t, | aarch64::matmul::gemm_s8x8x16_4x4, int8_t, | ||||
int16_t); | |||||
int16_t, AlgoDataType::INT8X8X16, DEFAULT); | |||||
/* ===================== Int8x8x16 K16x12x4 algo ===================== */ | /* ===================== Int8x8x16 K16x12x4 algo ===================== */ | ||||
namespace { | namespace { | ||||
@@ -929,7 +937,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_kern( | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( | ||||
AlgoInt8x8x16MK4_16x12x4, megdnn_aarch64_matmul_kern, | AlgoInt8x8x16MK4_16x12x4, megdnn_aarch64_matmul_kern, | ||||
"AlgoInt8x8x16MK4_16x12x4Impl"_hash, | "AlgoInt8x8x16MK4_16x12x4Impl"_hash, | ||||
aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53, int8_t, int16_t, int16_t); | |||||
aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53, int8_t, int16_t, int16_t, | |||||
AlgoDataType::INT8X8X16, MK4); | |||||
/* ===================== Int8x8x16 MK4 4x4x8 algo ===================== */ | /* ===================== Int8x8x16 MK4 4x4x8 algo ===================== */ | ||||
namespace { | namespace { | ||||
@@ -1007,7 +1016,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_4x4x8, | |||||
megdnn_aarch64_matmul_kern, | megdnn_aarch64_matmul_kern, | ||||
"AlgoInt8x8x16MK4_4x4x8_Impl"_hash, | "AlgoInt8x8x16MK4_4x4x8_Impl"_hash, | ||||
aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72, | aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72, | ||||
int8_t, int16_t); | |||||
int8_t, int16_t, AlgoDataType::INT8X8X16, | |||||
MK4); | |||||
/* ===================== Int16x16x32 K12x8x1 algo ===================== */ | /* ===================== Int16x16x32 K12x8x1 algo ===================== */ | ||||
namespace { | namespace { | ||||
@@ -1078,7 +1088,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x8x1, | |||||
megdnn_aarch64_matmul_kern, | megdnn_aarch64_matmul_kern, | ||||
"AlgoInt16x16x32K12x8x1Impl"_hash, | "AlgoInt16x16x32K12x8x1Impl"_hash, | ||||
aarch64::matmul::gemm_s16_12x8x1, int16_t, | aarch64::matmul::gemm_s16_12x8x1, int16_t, | ||||
int32_t); | |||||
int32_t, AlgoDataType::INT16X16X32, | |||||
DEFAULT); | |||||
/* ===================== Int16x16x32MK8_8x8 algo ===================== */ | /* ===================== Int16x16x32MK8_8x8 algo ===================== */ | ||||
@@ -1201,7 +1212,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x4DotProd, | |||||
megdnn_aarch64_matmul_kern, | megdnn_aarch64_matmul_kern, | ||||
"AlgoQuint8K8x8x4DotProdImpl"_hash, | "AlgoQuint8K8x8x4DotProdImpl"_hash, | ||||
aarch64::matmul::gemm_u8_8x8, uint8_t, | aarch64::matmul::gemm_u8_8x8, uint8_t, | ||||
int32_t); | |||||
int32_t, AlgoDataType::QUINT8X8X32, | |||||
DEFAULT); | |||||
/* ===================== Quint8 Gemv DotProd algo ===================== */ | /* ===================== Quint8 Gemv DotProd algo ===================== */ | ||||
namespace { | namespace { | ||||
void quint8_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { | void quint8_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { | ||||
@@ -1307,7 +1319,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x8, | |||||
megdnn_aarch64_matmul_kern, | megdnn_aarch64_matmul_kern, | ||||
"AlgoQuint8K8x8x8Impl"_hash, | "AlgoQuint8K8x8x8Impl"_hash, | ||||
aarch64::matmul::gemm_u8_8x8, uint8_t, | aarch64::matmul::gemm_u8_8x8, uint8_t, | ||||
int32_t); | |||||
int32_t, AlgoDataType::QUINT8X8X32, | |||||
DEFAULT); | |||||
#endif | #endif | ||||
/* ===================== Int8x8x16 K8x8x8 algo ===================== */ | /* ===================== Int8x8x16 K8x8x8 algo ===================== */ | ||||
@@ -1378,6 +1391,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_kern( | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_K8x8x8, | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_K8x8x8, | ||||
megdnn_aarch64_matmul_kern, | megdnn_aarch64_matmul_kern, | ||||
"AlgoInt8x8x16MK4_K8x8x8Impl"_hash, | "AlgoInt8x8x16MK4_K8x8x8Impl"_hash, | ||||
aarch64::matmul::gemm_s8x8x16_mk4_8x8x8, int8_t, | |||||
int16_t); | |||||
aarch64::matmul::gemm_s8x8x16_mk4_8x8x8, | |||||
int8_t, int16_t, AlgoDataType::INT8X8X16, | |||||
MK4); | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -61,7 +61,7 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4) | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4, AlgoDataType::FLOAT32, MK4) | |||||
}; | }; | ||||
class MatrixMulImpl::AlgoF32Gemv final | class MatrixMulImpl::AlgoF32Gemv final | ||||
@@ -88,7 +88,7 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2) | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::FLOAT16, MK8) | |||||
}; | }; | ||||
#endif | #endif | ||||
@@ -253,7 +253,7 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2) | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) | |||||
}; | }; | ||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
@@ -281,7 +281,7 @@ public: | |||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QUINT8X8X32, DEFAULT) | |||||
}; | }; | ||||
#else | #else | ||||
@@ -29,7 +29,7 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoFP16WinogradF45 final : public AlgoBase { | class ConvBiasImpl::AlgoFP16WinogradF45 final : public AlgoBase { | ||||
@@ -44,7 +44,7 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoFP16WinogradF63 final : public AlgoBase { | class ConvBiasImpl::AlgoFP16WinogradF63 final : public AlgoBase { | ||||
@@ -60,7 +60,7 @@ public: | |||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoFP16WinogradF23_8x8 final : public AlgoBase { | class ConvBiasImpl::AlgoFP16WinogradF23_8x8 final : public AlgoBase { | ||||
public: | public: | ||||
@@ -74,7 +74,7 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoF16Direct final : public AlgoBase { | class ConvBiasImpl::AlgoF16Direct final : public AlgoBase { | ||||
@@ -90,6 +90,10 @@ public: | |||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
ConvAlgoTypePack get_algo_type() const override{ | |||||
return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase { | class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase { | ||||
@@ -103,6 +107,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
} // namespace arm_common | } // namespace arm_common | ||||
@@ -29,7 +29,7 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase { | ||||
@@ -44,7 +44,7 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase { | ||||
@@ -59,7 +59,7 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase { | ||||
@@ -74,7 +74,7 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase { | ||||
@@ -89,7 +89,7 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||||
}; | }; | ||||
//===================== NCHW44 Winograd Support =====================// | //===================== NCHW44 Winograd Support =====================// | ||||
@@ -106,7 +106,7 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44 final : public AlgoBase { | ||||
@@ -122,7 +122,7 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase { | ||||
@@ -138,7 +138,7 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||||
}; | }; | ||||
// ================================================================= // | // ================================================================= // | ||||
@@ -154,6 +154,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { | class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { | ||||
@@ -168,6 +171,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | ||||
@@ -182,6 +188,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase { | ||||
@@ -197,6 +206,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase { | ||||
@@ -212,6 +224,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase { | ||||
@@ -226,6 +241,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
} // namespace arm_common | } // namespace arm_common | ||||
@@ -29,6 +29,10 @@ public: | |||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
bool is_preferred(const NCBKernSizeParam& param) const override; | bool is_preferred(const NCBKernSizeParam& param) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { | ||||
@@ -42,6 +46,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase { | ||||
@@ -55,6 +62,9 @@ public: | |||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
bool is_preferred(const NCBKernSizeParam& param) const override; | bool is_preferred(const NCBKernSizeParam& param) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase { | ||||
@@ -68,6 +78,9 @@ public: | |||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
bool is_preferred(const NCBKernSizeParam& param) const override; | bool is_preferred(const NCBKernSizeParam& param) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44 final : public AlgoBase { | ||||
@@ -79,6 +92,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44 final : public AlgoBase { | ||||
@@ -90,6 +106,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
@@ -104,6 +123,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam&) const override; | size_t get_workspace(const NCBKernSizeParam&) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { | class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { | ||||
@@ -117,6 +139,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam&) const override; | size_t get_workspace(const NCBKernSizeParam&) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase { | ||||
@@ -131,6 +156,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam&) const override; | size_t get_workspace(const NCBKernSizeParam&) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
//! Returns this algorithm's type pack: data type QINT8X8X32, category DIRECT.
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase { | ||||
@@ -148,6 +176,10 @@ public: | |||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
bool is_preferred(const NCBKernSizeParam& param) const override; | bool is_preferred(const NCBKernSizeParam& param) const override; | ||||
//! Returns this algorithm's type pack: data type QINT8X8X32, category DIRECT.
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
#endif | #endif | ||||
@@ -163,7 +195,7 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | |||||
}; | }; | ||||
//=======================input int8 compute fp32 output int8============ | //=======================input int8 compute fp32 output int8============ | ||||
@@ -180,7 +212,7 @@ public: | |||||
} | } | ||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | |||||
}; | }; | ||||
//=======================input int8 compute int16 output int8============ | //=======================input int8 compute int16 output int8============ | ||||
@@ -198,7 +230,7 @@ public: | |||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | |||||
}; | }; | ||||
} // namespace arm_common | } // namespace arm_common | ||||
@@ -36,6 +36,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
//! Returns this algorithm's type pack: data type INT8X8X16, category DIRECT.
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase { | ||||
@@ -48,6 +51,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
//! Returns this algorithm's type pack: data type INT8X8X16, category DIRECT.
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase { | class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase { | ||||
@@ -71,6 +77,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
//! Returns this algorithm's type pack: data type INT8X8X16, category DIRECT.
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoI8x8x16Stride2Filter2 final : public AlgoBase { | class ConvBiasImpl::AlgoI8x8x16Stride2Filter2 final : public AlgoBase { | ||||
@@ -84,6 +93,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
//! Returns this algorithm's type pack: data type INT8X8X16, category DIRECT.
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final : public AlgoBase { | ||||
@@ -96,6 +108,9 @@ public: | |||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
//! Returns this algorithm's type pack: data type INT8X8X16, category DIRECT.
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44 final : public AlgoBase { | class ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44 final : public AlgoBase { | ||||
@@ -111,6 +126,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
//! Returns this algorithm's type pack: data type INT8X8X16, category DIRECT.
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
} // namespace arm_common | } // namespace arm_common | ||||
@@ -10,6 +10,7 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#include "megdnn/opr_param_defs.h" | |||||
#include "src/arm_common/conv_bias/int8/algos.h" | #include "src/arm_common/conv_bias/int8/algos.h" | ||||
#include "src/arm_common/conv_bias/int8x8x16/algos.h" | #include "src/arm_common/conv_bias/int8x8x16/algos.h" | ||||
#include "src/arm_common/conv_bias/quint8/algos.h" | #include "src/arm_common/conv_bias/quint8/algos.h" | ||||
@@ -122,9 +123,11 @@ public: | |||||
static CpuOprDelegationStorage<2> storage; | static CpuOprDelegationStorage<2> storage; | ||||
auto matmul_opr = storage.get<MatrixMul, 0>(); | auto matmul_opr = storage.get<MatrixMul, 0>(); | ||||
using MatmulFormat = param::MatrixMul::Format; | |||||
auto&& matmul_algos = | auto&& matmul_algos = | ||||
static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | ||||
->algo_pack(); | |||||
->select_algo_type( | |||||
{AlgoDataType::FLOAT32, MatmulFormat::MK4}); | |||||
for (auto&& algo : matmul_algos) { | for (auto&& algo : matmul_algos) { | ||||
if (algo->type() == nullptr) | if (algo->type() == nullptr) | ||||
continue; | continue; | ||||
@@ -133,38 +136,62 @@ public: | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
tile_size)); | tile_size)); | ||||
winograd_algos.emplace_back(refhold.back().get()); | winograd_algos.emplace_back(refhold.back().get()); | ||||
refhold.emplace_back(new AlgoFP32WinogradF63( | |||||
refhold.emplace_back(new AlgoFP32WinogradF63_4x4( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
tile_size)); | tile_size)); | ||||
winograd_algos.emplace_back(refhold.back().get()); | winograd_algos.emplace_back(refhold.back().get()); | ||||
refhold.emplace_back(new AlgoFP32WinogradF63_4x4( | |||||
refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
tile_size)); | tile_size)); | ||||
winograd_algos.emplace_back(refhold.back().get()); | winograd_algos.emplace_back(refhold.back().get()); | ||||
refhold.emplace_back(new AlgoFP32WinogradF54( | |||||
refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
tile_size)); | tile_size)); | ||||
winograd_algos.emplace_back(refhold.back().get()); | winograd_algos.emplace_back(refhold.back().get()); | ||||
refhold.emplace_back(new AlgoFP32WinogradF45( | |||||
//! uncomment this when low precision mode is done | |||||
#if 0 | |||||
refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
tile_size)); | tile_size)); | ||||
winograd_algos.emplace_back(refhold.back().get()); | winograd_algos.emplace_back(refhold.back().get()); | ||||
refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44( | |||||
#endif | |||||
//! Qint8x8x32 winograd compute with fp32 | |||||
refhold.emplace_back(new AlgoS8CF32WinogradF23_4x4_NCHW44( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
tile_size)); | tile_size)); | ||||
winograd_algos.emplace_back(refhold.back().get()); | winograd_algos.emplace_back(refhold.back().get()); | ||||
refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44( | |||||
} | |||||
} | |||||
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | |||||
->select_algo_type({AlgoDataType::FLOAT32, | |||||
MatmulFormat::DEFAULT}); | |||||
for (auto&& algo : matmul_algos) { | |||||
if (algo->type() == nullptr) | |||||
continue; | |||||
for (uint32_t tile_size : {16, 8, 24, 32}) { | |||||
refhold.emplace_back(new AlgoFP32WinogradF63( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
tile_size)); | tile_size)); | ||||
winograd_algos.emplace_back(refhold.back().get()); | winograd_algos.emplace_back(refhold.back().get()); | ||||
//! uncomment this when low precision mode is done | |||||
#if 0 | |||||
refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44( | |||||
refhold.emplace_back(new AlgoFP32WinogradF54( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
tile_size)); | tile_size)); | ||||
winograd_algos.emplace_back(refhold.back().get()); | winograd_algos.emplace_back(refhold.back().get()); | ||||
#endif | |||||
refhold.emplace_back(new AlgoFP32WinogradF45( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||||
tile_size)); | |||||
winograd_algos.emplace_back(refhold.back().get()); | |||||
} | |||||
} | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | |||||
->select_algo_type({AlgoDataType::FLOAT16, | |||||
MatmulFormat::DEFAULT}); | |||||
for (auto&& algo : matmul_algos) { | |||||
if (algo->type() == nullptr) | |||||
continue; | |||||
for (uint32_t tile_size : {16, 8, 24, 32}) { | |||||
refhold.emplace_back(new AlgoFP16WinogradF23( | refhold.emplace_back(new AlgoFP16WinogradF23( | ||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
tile_size)); | tile_size)); | ||||
@@ -177,19 +204,33 @@ public: | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
tile_size)); | tile_size)); | ||||
winograd_algos.emplace_back(refhold.back().get()); | winograd_algos.emplace_back(refhold.back().get()); | ||||
} | |||||
} | |||||
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | |||||
->select_algo_type({AlgoDataType::FLOAT16, | |||||
MatmulFormat::MK8}); | |||||
for (auto&& algo : matmul_algos) { | |||||
if (algo->type() == nullptr) | |||||
continue; | |||||
for (uint32_t tile_size : {16, 8, 24, 32}) { | |||||
refhold.emplace_back(new AlgoFP16WinogradF23_8x8( | refhold.emplace_back(new AlgoFP16WinogradF23_8x8( | ||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
tile_size)); | tile_size)); | ||||
winograd_algos.emplace_back(refhold.back().get()); | winograd_algos.emplace_back(refhold.back().get()); | ||||
} | |||||
} | |||||
#endif | #endif | ||||
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | |||||
->select_algo_type({AlgoDataType::INT16X16X32, | |||||
MatmulFormat::MK8}); | |||||
for (auto&& algo : matmul_algos) { | |||||
if (algo->type() == nullptr) | |||||
continue; | |||||
for (uint32_t tile_size : {16, 8, 24, 32}) { | |||||
refhold.emplace_back(new AlgoS8WinogradF23_8x8( | refhold.emplace_back(new AlgoS8WinogradF23_8x8( | ||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
tile_size)); | tile_size)); | ||||
winograd_algos.emplace_back(refhold.back().get()); | winograd_algos.emplace_back(refhold.back().get()); | ||||
refhold.emplace_back(new AlgoS8CF32WinogradF23_4x4_NCHW44( | |||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||||
tile_size)); | |||||
winograd_algos.emplace_back(refhold.back().get()); | |||||
refhold.emplace_back(new AlgoS8WinogradF23_8x8_NCHW44( | refhold.emplace_back(new AlgoS8WinogradF23_8x8_NCHW44( | ||||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | ||||
tile_size)); | tile_size)); | ||||
@@ -240,6 +281,42 @@ bool ConvBiasImpl::is_matmul_quantized_prefer( | |||||
return conv_direct_unusable; | return conv_direct_unusable; | ||||
} | } | ||||
/*!
 * \brief Suggest an ordering of algorithm categories for \p param.
 *
 * Winograd filter formats yield WINOGRAD only. Every other case yields
 * IM2COL, DIRECT and NAIVE: IM2COL leads when the matmul-based path is
 * preferred, DIRECT leads otherwise, and NAIVE is always last.
 *
 * \param param kernel size parameters; the filter meta (format, icpg, ocpg,
 *      spatial) and the source dtype category are inspected.
 * \return the category list described above.
 */
SmallVector<AlgoCategory> ConvBiasImpl::suggest_algo_category_order(
        const NCBKernSizeParam& param) const {
    using Format = param::ConvBias::Format;
    const auto& meta = param.filter_meta;
    const auto& fmt = meta.format;

    //! TODO: winograd is currently only supported via fast-run
    const bool winograd_format = fmt == Format::NCHW_WINOGRAD ||
                                 fmt == Format::NCHW44_WINOGRAD ||
                                 fmt == Format::NCHW88_WINOGRAD;
    if (winograd_format) {
        return {AlgoCategory::WINOGRAD};
    }

    //! non-quantized: prefer im2col once either per-group channel count
    //! reaches 32; quantized: use matmul when the direct algo is unusable
    const bool quantized =
            param.src_type.category() == DTypeCategory::QUANTIZED;
    bool prefer_matmul = quantized ? is_matmul_quantized_prefer(param)
                                   : (meta.icpg >= 32 || meta.ocpg >= 32);

    //! conv1x1 always leans towards the matmul path
    const bool filter_is_1x1 = meta.spatial[0] == 1 && meta.spatial[1] == 1;
    prefer_matmul = prefer_matmul || filter_is_1x1;

    //! nchw44 and nchw44-dot hybrid mode (fewer than 4 input channels per
    //! group) is handled by the direct algos
    const bool nchw44_family =
            fmt == Format::NCHW44 || fmt == Format::NCHW44_DOT;
    if (nchw44_family && meta.icpg < 4) {
        prefer_matmul = false;
    }

    if (!prefer_matmul) {
        return {AlgoCategory::DIRECT, AlgoCategory::IM2COL,
                AlgoCategory::NAIVE};
    }
    return {AlgoCategory::IM2COL, AlgoCategory::DIRECT, AlgoCategory::NAIVE};
}
const char* ConvBiasImpl::get_algorithm_set_name() const { | const char* ConvBiasImpl::get_algorithm_set_name() const { | ||||
// arm common version 0 | // arm common version 0 | ||||
return "AC0"; | return "AC0"; | ||||
@@ -28,6 +28,9 @@ public: | |||||
bool is_matmul_quantized_prefer( | bool is_matmul_quantized_prefer( | ||||
const ConvBiasImpl::NCBKernSizeParam& ncb_param) const override; | const ConvBiasImpl::NCBKernSizeParam& ncb_param) const override; | ||||
SmallVector<AlgoCategory> suggest_algo_category_order( | |||||
const NCBKernSizeParam& param) const override; | |||||
class AlgoPack; | class AlgoPack; | ||||
protected: | protected: | ||||
@@ -90,7 +93,7 @@ private: | |||||
class AlgoF16Direct; | class AlgoF16Direct; | ||||
class AlgoF16DirectStride1; | class AlgoF16DirectStride1; | ||||
#endif | #endif | ||||
}; | |||||
}; | |||||
} // namespace arm_common | } // namespace arm_common | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -29,6 +29,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
//! Returns this algorithm's type pack: data type QUINT8X8X32, category DIRECT.
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase { | ||||
@@ -42,6 +45,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
//! Returns this algorithm's type pack: data type QUINT8X8X32, category DIRECT.
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase { | class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase { | ||||
@@ -56,6 +62,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
//! Returns this algorithm's type pack: data type QUINT8X8X32, category DIRECT.
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase { | class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase { | ||||
@@ -69,6 +78,9 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
virtual SmallVector<NCBKern> dispatch_kerns( | virtual SmallVector<NCBKern> dispatch_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
//! Returns this algorithm's type pack: data type QUINT8X8X32, category DIRECT.
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
#endif | #endif | ||||
} // namespace arm_common | } // namespace arm_common | ||||
@@ -26,7 +26,7 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::INT8X8X16, DEFAULT) | |||||
}; | }; | ||||
class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { | ||||
@@ -40,7 +40,7 @@ public: | |||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT) | |||||
}; | }; | ||||
class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase { | ||||
@@ -54,7 +54,7 @@ public: | |||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4) | |||||
}; | }; | ||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
@@ -69,7 +69,7 @@ public: | |||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4_DOT) | |||||
}; | }; | ||||
#endif | #endif | ||||
@@ -87,7 +87,7 @@ public: | |||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) | |||||
}; | }; | ||||
class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase { | class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase { | ||||
@@ -101,7 +101,7 @@ public: | |||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4) | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4) | |||||
}; | }; | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
@@ -116,7 +116,7 @@ public: | |||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::FLOAT16, DEFAULT) | |||||
}; | }; | ||||
#endif | #endif | ||||
@@ -131,7 +131,13 @@ public: | |||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(1, 1, 1, 4) | |||||
MEGDNN_OVERRIDE_MATMUL_DESC( | |||||
1, 1, 1, 4, | |||||
static_cast<AlgoDataType>( | |||||
static_cast<uint32_t>(AlgoDataType::FLOAT16) | | |||||
static_cast<uint32_t>(AlgoDataType::FLOAT32) | | |||||
static_cast<uint32_t>(AlgoDataType::QINT8X8X32)), | |||||
DEFAULT) | |||||
}; | }; | ||||
} // namespace arm_common | } // namespace arm_common | ||||
@@ -25,7 +25,7 @@ void* const MatrixMulImpl::sm_arm_common_algo_type = | |||||
class MatrixMulImpl::AlgoPack : NonCopyableObj { | class MatrixMulImpl::AlgoPack : NonCopyableObj { | ||||
AlgoInt8x8x16 int8x8x16; | AlgoInt8x8x16 int8x8x16; | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
AlgoF16Gemv f16gemv; | |||||
AlgoF16Gemv f16gemv; | |||||
#endif | #endif | ||||
AlgoInt8x8x32Gemv int8x8x32_gemv; | AlgoInt8x8x32Gemv int8x8x32_gemv; | ||||
AlgoInt8x8x32GemvMK4 int8x8x32_gemv_mk4; | AlgoInt8x8x32GemvMK4 int8x8x32_gemv_mk4; | ||||
@@ -34,10 +34,11 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||||
#endif | #endif | ||||
AlgoGevm gevm; | AlgoGevm gevm; | ||||
AlgoF32GemvMK4 f32_gemv_mk4; | AlgoF32GemvMK4 f32_gemv_mk4; | ||||
public: | public: | ||||
AlgoPack() { | AlgoPack() { | ||||
all_algos.emplace_back(&int8x8x16); | all_algos.emplace_back(&int8x8x16); | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||||
all_algos.emplace_back(&f16gemv); | all_algos.emplace_back(&f16gemv); | ||||
#endif | #endif | ||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
@@ -47,7 +48,7 @@ public: | |||||
all_algos.emplace_back(&int8x8x32_gemv_mk4); | all_algos.emplace_back(&int8x8x32_gemv_mk4); | ||||
all_algos.emplace_back(&f32_gemv_mk4); | all_algos.emplace_back(&f32_gemv_mk4); | ||||
all_algos.emplace_back(&gevm); | all_algos.emplace_back(&gevm); | ||||
} | |||||
} | |||||
SmallVector<AlgoBase*> all_algos; | SmallVector<AlgoBase*> all_algos; | ||||
}; | }; | ||||
@@ -37,6 +37,9 @@ public: | |||||
size_t group = param.filter_meta.group; | size_t group = param.filter_meta.group; | ||||
return {{kimpl, {group, 1_z, 1_z}}}; | return {{kimpl, {group, 1_z, 1_z}}}; | ||||
} | } | ||||
//! Returns this algorithm's type pack: data type QINT8X8X32, category IM2COL.
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; | |||||
} | |||||
}; | }; | ||||
} // namespace armv7 | } // namespace armv7 | ||||
@@ -38,6 +38,10 @@ public: | |||||
size_t group = param.filter_meta.group; | size_t group = param.filter_meta.group; | ||||
return {{kimpl, {group, 1_z, 1_z}}}; | return {{kimpl, {group, 1_z, 1_z}}}; | ||||
} | } | ||||
//! Returns this algorithm's type pack: data type QUINT8X8X32, category IM2COL.
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL}; | |||||
} | |||||
}; | }; | ||||
} // namespace armv7 | } // namespace armv7 | ||||
@@ -85,7 +85,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32::get_kern( | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32, megdnn_armv7_matmul_kern, | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32, megdnn_armv7_matmul_kern, | ||||
"AlgoF32Impl"_hash, | "AlgoF32Impl"_hash, | ||||
armv7::matmul::sgemm_4x12, float, float); | |||||
armv7::matmul::sgemm_4x12, float, float, | |||||
AlgoDataType::FLOAT32, DEFAULT); | |||||
/* ===================== F32 algo mk4 K4x12 ===================== */ | /* ===================== F32 algo mk4 K4x12 ===================== */ | ||||
@@ -154,7 +155,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4Pack4x12, | |||||
megdnn_armv7_matmul_kern, | megdnn_armv7_matmul_kern, | ||||
"AlgoF32MK4Pack4x12"_hash, | "AlgoF32MK4Pack4x12"_hash, | ||||
armv7::matmul::sgemm_mk4_pack_4x12, float, | armv7::matmul::sgemm_mk4_pack_4x12, float, | ||||
float); | |||||
float, AlgoDataType::FLOAT32, MK4); | |||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
/* ===================== F16 K4x16x1 algo ===================== */ | /* ===================== F16 K4x16x1 algo ===================== */ | ||||
@@ -215,7 +216,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K4x16x1::get_kern( | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF16K4x16x1, megdnn_armv7_matmul_kern, | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF16K4x16x1, megdnn_armv7_matmul_kern, | ||||
"AlgoF16K4x16x1"_hash, | "AlgoF16K4x16x1"_hash, | ||||
armv7::matmul::hgemm_4x16, dt_float16, | armv7::matmul::hgemm_4x16, dt_float16, | ||||
dt_float16); | |||||
dt_float16, AlgoDataType::FLOAT16, | |||||
DEFAULT); | |||||
#endif | #endif | ||||
@@ -280,7 +282,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x2x16, | |||||
megdnn_armv7_matmul_kern, | megdnn_armv7_matmul_kern, | ||||
"AlgoInt8x8x32K4x2x16"_hash, | "AlgoInt8x8x32K4x2x16"_hash, | ||||
armv7::matmul::gemm_s8_4x2, int8_t, | armv7::matmul::gemm_s8_4x2, int8_t, | ||||
int32_t); | |||||
int32_t, AlgoDataType::QINT8X8X32, | |||||
DEFAULT); | |||||
/* ===================== Int8x8x32 Kernel 4x8x8 algo ===================== */ | /* ===================== Int8x8x32 Kernel 4x8x8 algo ===================== */ | ||||
namespace { | namespace { | ||||
@@ -342,7 +345,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x8x8, | |||||
megdnn_armv7_matmul_kern, | megdnn_armv7_matmul_kern, | ||||
"AlgoInt8x8x32K4x8x8"_hash, | "AlgoInt8x8x32K4x8x8"_hash, | ||||
armv7::matmul::gemm_s8_4x8, int8_t, | armv7::matmul::gemm_s8_4x8, int8_t, | ||||
int32_t); | |||||
int32_t, AlgoDataType::QINT8X8X32, | |||||
DEFAULT); | |||||
/* ===================== Quint8 Kernel 4x8x8 algo ===================== */ | /* ===================== Quint8 Kernel 4x8x8 algo ===================== */ | ||||
namespace { | namespace { | ||||
@@ -402,7 +406,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K4x8x8::get_kern( | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K4x8x8, megdnn_armv7_matmul_kern, | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K4x8x8, megdnn_armv7_matmul_kern, | ||||
"AlgoQuint8K4x8x8"_hash, | "AlgoQuint8K4x8x8"_hash, | ||||
armv7::matmul::gemm_u8_4x8, uint8_t, | armv7::matmul::gemm_u8_4x8, uint8_t, | ||||
int32_t); | |||||
int32_t, AlgoDataType::QUINT8X8X32, | |||||
DEFAULT); | |||||
/* ===================== Int8x8x16 Kernel 2x4x16 algo ===================== */ | /* ===================== Int8x8x16 Kernel 2x4x16 algo ===================== */ | ||||
namespace { | namespace { | ||||
@@ -468,7 +473,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x2x16, | |||||
megdnn_armv7_matmul_kern, | megdnn_armv7_matmul_kern, | ||||
"AlgoInt8x8x16K4x2x16"_hash, | "AlgoInt8x8x16K4x2x16"_hash, | ||||
armv7::matmul::gemm_s8x8x16_4x2, int8_t, | armv7::matmul::gemm_s8x8x16_4x2, int8_t, | ||||
int16_t); | |||||
int16_t, AlgoDataType::INT8X8X16, DEFAULT); | |||||
/* ===================== Int8x8x16 Kernel 4x8x8 algo ===================== */ | /* ===================== Int8x8x16 Kernel 4x8x8 algo ===================== */ | ||||
namespace { | namespace { | ||||
@@ -534,7 +539,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x8x8, | |||||
megdnn_armv7_matmul_kern, | megdnn_armv7_matmul_kern, | ||||
"AlgoInt8x8x16K4x8x8"_hash, | "AlgoInt8x8x16K4x8x8"_hash, | ||||
armv7::matmul::gemm_s8x8x16_4x8, int8_t, | armv7::matmul::gemm_s8x8x16_4x8, int8_t, | ||||
int16_t); | |||||
int16_t, AlgoDataType::INT8X8X16, DEFAULT); | |||||
/* =================== Int8x8x16 Kernel MK4 8x8x4 algo ===================*/ | /* =================== Int8x8x16 Kernel MK4 8x8x4 algo ===================*/ | ||||
@@ -602,7 +607,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x16MK4_8x8x4, | |||||
megdnn_armv7_matmul_kern, | megdnn_armv7_matmul_kern, | ||||
"AlgoInt8x8x16MK4_8x8x4"_hash, | "AlgoInt8x8x16MK4_8x8x4"_hash, | ||||
armv7::matmul::gemm_s8x8x16_mk4_8x8, | armv7::matmul::gemm_s8x8x16_mk4_8x8, | ||||
int8_t, int16_t, int16_t); | |||||
int8_t, int16_t, int16_t, | |||||
AlgoDataType::INT8X8X16, MK4); | |||||
/* ===================== Int16x16x32 Kernel 12x4x1 algo ===================== */ | /* ===================== Int16x16x32 Kernel 12x4x1 algo ===================== */ | ||||
@@ -668,7 +674,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x4x1, | |||||
megdnn_armv7_matmul_kern, | megdnn_armv7_matmul_kern, | ||||
"AlgoInt16x16x32K12x4x1"_hash, | "AlgoInt16x16x32K12x4x1"_hash, | ||||
armv7::matmul::gemm_s16x16x32_12x4, | armv7::matmul::gemm_s16x16x32_12x4, | ||||
int16_t, int32_t); | |||||
int16_t, int32_t, | |||||
AlgoDataType::INT16X16X32, DEFAULT); | |||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
/* ===================== Int8 K6x8x4 algo ===================== */ | /* ===================== Int8 K6x8x4 algo ===================== */ | ||||
namespace { | namespace { | ||||
@@ -724,7 +731,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K6x8x4, | |||||
megdnn_armv7_matmul_kern, | megdnn_armv7_matmul_kern, | ||||
"AlgoInt8x8x32K6x8x4"_hash, | "AlgoInt8x8x32K6x8x4"_hash, | ||||
armv7::matmul::gemm_dots8_6x8, int8_t, | armv7::matmul::gemm_dots8_6x8, int8_t, | ||||
int32_t); | |||||
int32_t, AlgoDataType::QINT8X8X32, | |||||
DEFAULT); | |||||
/* ===================== Quint8 K4x8x4 algo ===================== */ | /* ===================== Quint8 K4x8x4 algo ===================== */ | ||||
namespace { | namespace { | ||||
void quint8_dot_k4x8x4_kern(const MatrixMulImpl::KernParam& kern_param) { | void quint8_dot_k4x8x4_kern(const MatrixMulImpl::KernParam& kern_param) { | ||||
@@ -786,7 +794,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8DotK4x8x4, | |||||
megdnn_armv7_matmul_kern, | megdnn_armv7_matmul_kern, | ||||
"AlgoQuint8DotK4x8x4"_hash, | "AlgoQuint8DotK4x8x4"_hash, | ||||
armv7::matmul::gemm_dot_quint8_4x8, | armv7::matmul::gemm_dot_quint8_4x8, | ||||
uint8_t, int32_t); | |||||
uint8_t, int32_t, | |||||
AlgoDataType::QUINT8X8X32, DEFAULT); | |||||
/* ======================== Int8 MK4 8x4x4 dot algo ======================== */ | /* ======================== Int8 MK4 8x4x4 dot algo ======================== */ | ||||
namespace { | namespace { | ||||
@@ -854,7 +863,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x4x4DotProd, | |||||
megdnn_armv7_matmul_kern, | megdnn_armv7_matmul_kern, | ||||
"AlgoInt8x8x32MK4_8x4x4DotProd"_hash, | "AlgoInt8x8x32MK4_8x4x4DotProd"_hash, | ||||
armv7::matmul::gemm_mk4_dots8_8x4, int8_t, | armv7::matmul::gemm_mk4_dots8_8x4, int8_t, | ||||
int32_t); | |||||
int32_t, AlgoDataType::QINT8X8X32, MK4_DOT); | |||||
#endif | #endif | ||||
/* ===================== F32 algo K4x8 ===================== */ | /* ===================== F32 algo K4x8 ===================== */ | ||||
@@ -1099,6 +1108,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_4x2x16, | |||||
megdnn_armv7_matmul_kern, | megdnn_armv7_matmul_kern, | ||||
"AlgoInt8x8x32MK4_4x2x16"_hash, | "AlgoInt8x8x32MK4_4x2x16"_hash, | ||||
armv7::matmul::gemm_mk4_s8_4x2, int8_t, | armv7::matmul::gemm_mk4_s8_4x2, int8_t, | ||||
int32_t); | |||||
int32_t, AlgoDataType::QINT8X8X32, MK4); | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -50,7 +50,7 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4) | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4) | |||||
}; | }; | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
@@ -73,7 +73,7 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2) | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::FLOAT16, MK8) | |||||
}; | }; | ||||
#endif | #endif | ||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
@@ -205,7 +205,7 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2) | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) | |||||
}; | }; | ||||
class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { | ||||
@@ -18,7 +18,6 @@ namespace armv7 { | |||||
class MatrixMulImpl : public arm_common::MatrixMulImpl { | class MatrixMulImpl : public arm_common::MatrixMulImpl { | ||||
public: | public: | ||||
using arm_common::MatrixMulImpl::MatrixMulImpl; | using arm_common::MatrixMulImpl::MatrixMulImpl; | ||||
SmallVector<AlgoBase*> algo_pack() override; | SmallVector<AlgoBase*> algo_pack() override; | ||||
private: | private: | ||||
@@ -110,6 +110,11 @@ void __log__(LogLevel level, const char* file, const char* func, int line, | |||||
} while (0) | } while (0) | ||||
#endif // megdnn_ENABLE_LOGGING | #endif // megdnn_ENABLE_LOGGING | ||||
//! Convert \p data to int32_t with a static_cast.
//! Declared constexpr, so it can also be used in constant expressions.
template <typename T> | |||||
constexpr int32_t cast_int(T data) { | |||||
return static_cast<int32_t>(data); | |||||
} | |||||
/* helper functions */ | /* helper functions */ | ||||
/** | /** | ||||
* \brief Get the next `stride' index lexicographically. | * \brief Get the next `stride' index lexicographically. | ||||
@@ -187,6 +192,29 @@ std::unique_ptr<T> make_unique(Args&&... args) { | |||||
return std::unique_ptr<T>(new T(std::forward<Args>(args)...)); | return std::unique_ptr<T>(new T(std::forward<Args>(args)...)); | ||||
} | } | ||||
/*! | |||||
* \brief check whether the source enum contain the target data type enum | |||||
*/ | |||||
bool inline contain_data_type(detail::AlgoDataType source, | |||||
detail::AlgoDataType target) { | |||||
return static_cast<bool>(static_cast<uint32_t>(source) & | |||||
static_cast<uint32_t>(target)); | |||||
} | |||||
/*!
 * \brief count the bits that are set in \p index
 *
 * The value is first converted to uint32_t. For a bit-mask enum value this is
 * the number of single flags combined in it.
 *
 * \param index value to inspect
 * \return number of set bits, 0 when the converted value is zero
 */ | |||||
template<typename T> | |||||
size_t nr_type_contain(T index) { | |||||
uint32_t sr_index = static_cast<uint32_t>(index); | |||||
size_t nr_type = 0; | |||||
while (sr_index != 0) { | |||||
nr_type++; | |||||
//! clears the lowest set bit, so the loop body runs once per set bit
sr_index &= (sr_index - 1); | |||||
} | |||||
return nr_type; | |||||
} | |||||
/** | /** | ||||
* \brief Aligned workspace bundle. | * \brief Aligned workspace bundle. | ||||
* | * | ||||
@@ -26,6 +26,16 @@ public: | |||||
AlgoSelectionStrategy algo_selection_strategy) const override; | AlgoSelectionStrategy algo_selection_strategy) const override; | ||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override; | SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override; | ||||
//! report the data types (as a combined bit mask) and the category of this
//! algo: FLOAT16, FLOAT32, INT8X8X16, QINT8X8X32 and QUINT8X8X32, NAIVE
ConvAlgoTypePack get_algo_type() const override {
    const uint32_t supported_mask =
            static_cast<uint32_t>(AlgoDataType::FLOAT16) |
            static_cast<uint32_t>(AlgoDataType::FLOAT32) |
            static_cast<uint32_t>(AlgoDataType::INT8X8X16) |
            static_cast<uint32_t>(AlgoDataType::QINT8X8X32) |
            static_cast<uint32_t>(AlgoDataType::QUINT8X8X32);
    return {static_cast<AlgoDataType>(supported_mask), AlgoCategory::NAIVE};
}
}; | }; | ||||
class ConvBiasImpl::AlgoWinogradF32 final : public AlgoBase { | class ConvBiasImpl::AlgoWinogradF32 final : public AlgoBase { | ||||
@@ -46,6 +56,10 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override; | SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override; | ||||
//! this algo handles FLOAT32 data and belongs to the WINOGRAD category
ConvAlgoTypePack get_algo_type() const override {
    ConvAlgoTypePack type_pack = {AlgoDataType::FLOAT32,
                                  AlgoCategory::WINOGRAD};
    return type_pack;
}
private: | private: | ||||
MatrixMulImpl::AlgoBase* m_matmul_algo; | MatrixMulImpl::AlgoBase* m_matmul_algo; | ||||
mutable std::string m_name; | mutable std::string m_name; | ||||
@@ -70,6 +84,10 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override; | SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override; | ||||
//! this algo handles FLOAT32 data and belongs to the WINOGRAD category
ConvAlgoTypePack get_algo_type() const override {
    ConvAlgoTypePack type_pack = {AlgoDataType::FLOAT32,
                                  AlgoCategory::WINOGRAD};
    return type_pack;
}
private: | private: | ||||
MatrixMulImpl::AlgoBase* m_matmul_algo; | MatrixMulImpl::AlgoBase* m_matmul_algo; | ||||
mutable std::string m_name; | mutable std::string m_name; | ||||
@@ -94,6 +112,10 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override; | SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override; | ||||
//! this algo handles QINT8X8X32 data and belongs to the WINOGRAD category
ConvAlgoTypePack get_algo_type() const override {
    ConvAlgoTypePack type_pack = {AlgoDataType::QINT8X8X32,
                                  AlgoCategory::WINOGRAD};
    return type_pack;
}
private: | private: | ||||
MatrixMulImpl::AlgoBase* m_matmul_algo; | MatrixMulImpl::AlgoBase* m_matmul_algo; | ||||
mutable std::string m_name; | mutable std::string m_name; | ||||
@@ -118,6 +140,10 @@ public: | |||||
size_t get_workspace(const NCBKernSizeParam& param) const override; | size_t get_workspace(const NCBKernSizeParam& param) const override; | ||||
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override; | SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override; | ||||
//! this algo handles QINT8X8X32 data and belongs to the WINOGRAD category
ConvAlgoTypePack get_algo_type() const override {
    ConvAlgoTypePack type_pack = {AlgoDataType::QINT8X8X32,
                                  AlgoCategory::WINOGRAD};
    return type_pack;
}
private: | private: | ||||
MatrixMulImpl::AlgoBase* m_matmul_algo; | MatrixMulImpl::AlgoBase* m_matmul_algo; | ||||
mutable std::string m_name; | mutable std::string m_name; | ||||
@@ -140,7 +140,7 @@ using BiasMode = ConvBiasForward::BiasMode; | |||||
break; \ | break; \ | ||||
} | } | ||||
#define MEGDNN_WINOGRAD_ALGO_FUN_DECLARE() \ | |||||
#define MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(_algo_data_type) \ | |||||
bool is_reproducible() const override { return true; } \ | bool is_reproducible() const override { return true; } \ | ||||
bool usable(const NCBKernSizeParam& param, \ | bool usable(const NCBKernSizeParam& param, \ | ||||
AlgoSelectionStrategy algo_selection_strategy) const override; \ | AlgoSelectionStrategy algo_selection_strategy) const override; \ | ||||
@@ -153,6 +153,9 @@ using BiasMode = ConvBiasForward::BiasMode; | |||||
const override; \ | const override; \ | ||||
virtual SmallVector<NCBKern> dispatch_preprocess_kerns( \ | virtual SmallVector<NCBKern> dispatch_preprocess_kerns( \ | ||||
const NCBKernSizeParam& param) const override; \ | const NCBKernSizeParam& param) const override; \ | ||||
ConvAlgoTypePack get_algo_type() const override { \ | |||||
return {_algo_data_type, AlgoCategory::WINOGRAD}; \ | |||||
} \ | |||||
\ | \ | ||||
private: \ | private: \ | ||||
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; \ | fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; \ | ||||
@@ -288,7 +288,8 @@ bool ConvBiasImpl::AlgoConv1x1::is_preferred( | |||||
size_t OH = param.osz[0]; | size_t OH = param.osz[0]; | ||||
size_t OW = param.osz[1]; | size_t OW = param.osz[1]; | ||||
if (OH * OW != 1) { | if (OH * OW != 1) { | ||||
return true; | |||||
return m_matmul_algo->algoset() != | |||||
MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV; | |||||
} else { | } else { | ||||
#if (MEGDNN_ARMV7 || MEGDNN_AARCH64) | #if (MEGDNN_ARMV7 || MEGDNN_AARCH64) | ||||
if (param.src_type.enumv() == DTypeEnum::Int8 && | if (param.src_type.enumv() == DTypeEnum::Int8 && | ||||
@@ -56,6 +56,11 @@ public: | |||||
SmallVector<NCBKern> dispatch_preprocess_kerns( | SmallVector<NCBKern> dispatch_preprocess_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
//! the data type is taken from the description of the wrapped matmul algo,
//! the category is always IM2COL
ConvAlgoTypePack get_algo_type() const override {
    const auto matmul_data_type =
            m_matmul_algo->matmul_description().algo_type.data_type;
    return {matmul_data_type, AlgoCategory::IM2COL};
}
protected: | protected: | ||||
size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const; | size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const; | ||||
@@ -34,6 +34,16 @@ public: | |||||
bool is_preferred(const NCBKernSizeParam&) const override; | bool is_preferred(const NCBKernSizeParam&) const override; | ||||
//! report the data types (as a combined bit mask) and the category of this
//! algo: FLOAT16, FLOAT32, INT8X8X16, QINT8X8X32 and QUINT8X8X32, IM2COL
ConvAlgoTypePack get_algo_type() const override {
    const uint32_t supported_mask =
            static_cast<uint32_t>(AlgoDataType::FLOAT16) |
            static_cast<uint32_t>(AlgoDataType::FLOAT32) |
            static_cast<uint32_t>(AlgoDataType::INT8X8X16) |
            static_cast<uint32_t>(AlgoDataType::QINT8X8X32) |
            static_cast<uint32_t>(AlgoDataType::QUINT8X8X32);
    return {static_cast<AlgoDataType>(supported_mask), AlgoCategory::IM2COL};
}
protected: | protected: | ||||
size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const; | size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const; | ||||
}; | }; | ||||
@@ -48,15 +48,25 @@ public: | |||||
SmallVector<NCBKern> dispatch_preprocess_kerns( | SmallVector<NCBKern> dispatch_preprocess_kerns( | ||||
const NCBKernSizeParam& param) const override; | const NCBKernSizeParam& param) const override; | ||||
bool is_preferred(const NCBKernSizeParam& param) const override { | bool is_preferred(const NCBKernSizeParam& param) const override { | ||||
if (param.src_type.category() == DTypeCategory::QUANTIZED) { | |||||
static CpuOprDelegationStorage<1> storage; | |||||
auto conv_bias_opr = storage.get<ConvBias, 0>(); | |||||
return static_cast<ConvBiasImpl*>(conv_bias_opr) | |||||
->is_matmul_quantized_prefer(param); | |||||
size_t OH = param.osz[0]; | |||||
size_t OW = param.osz[1]; | |||||
//! a gemm matmul algo is preferred when oh * ow > 1 | |||||
//! a gemv matmul algo is preferred when oh * ow == 1 | |||||
if ((m_matmul_algo->algoset() != | |||||
MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV && | |||||
OH * OW > 1) || | |||||
(m_matmul_algo->algoset() == | |||||
MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV && | |||||
OH * OW == 1)) { | |||||
return true; | |||||
} else { | |||||
return false; | |||||
} | } | ||||
auto&& fm = param.filter_meta; | |||||
auto OC = fm.ocpg, IC = fm.icpg; | |||||
return OC >= 32 || IC >= 32; | |||||
} | |||||
//! the data type is taken from the description of the wrapped matmul algo,
//! the category is always IM2COL
ConvAlgoTypePack get_algo_type() const override {
    const auto matmul_data_type =
            m_matmul_algo->matmul_description().algo_type.data_type;
    return {matmul_data_type, AlgoCategory::IM2COL};
}
private: | private: | ||||
@@ -48,11 +48,26 @@ void incr_ptr(T*& dst, ptrdiff_t delta) { | |||||
} // namespace | } // namespace | ||||
#if MEGDNN_X86 | |||||
#define SKIP_GEMV() | |||||
//! There is no direct conv for int8x8x16 yet, so disabling gemv here could | |||||
//! make the operator fall back to the naive implementation, which is very | |||||
//! slow; therefore im2col with gemv stays enabled on the x86 backend. | |||||
//! FIXME: remove this once direct conv supports int8x8x16 | |||||
#else | |||||
#define SKIP_GEMV() \ | |||||
if (algo->algoset() == MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { \ | |||||
continue; \ | |||||
} | |||||
#endif | |||||
class ConvBiasImpl::AlgoPack : NonCopyableObj { | class ConvBiasImpl::AlgoPack : NonCopyableObj { | ||||
AlgoNaive algo_naive; | AlgoNaive algo_naive; | ||||
SmallVector<std::unique_ptr<AlgoBase>> refhold; | SmallVector<std::unique_ptr<AlgoBase>> refhold; | ||||
public: | public: | ||||
AlgoPack() { | AlgoPack() { | ||||
refhold.emplace_back(new AlgoConv1x1Gemv()); | refhold.emplace_back(new AlgoConv1x1Gemv()); | ||||
all_algos.emplace_back(refhold.back().get()); | all_algos.emplace_back(refhold.back().get()); | ||||
@@ -110,8 +125,6 @@ public: | |||||
all_algos.emplace_back(refhold.back().get()); | all_algos.emplace_back(refhold.back().get()); | ||||
#endif | #endif | ||||
} | } | ||||
//! reverse matmul algo, when the algo is_prefer can be selected first | |||||
std::reverse(all_algos.begin(), all_algos.end()); | |||||
all_algos.emplace_back(&algo_naive); | all_algos.emplace_back(&algo_naive); | ||||
} | } | ||||
SmallVector<AlgoBase*> all_algos; | SmallVector<AlgoBase*> all_algos; | ||||
@@ -121,6 +134,22 @@ SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() { | |||||
static AlgoPack sl_algo_pack; | static AlgoPack sl_algo_pack; | ||||
return sl_algo_pack.all_algos; | return sl_algo_pack.all_algos; | ||||
} | } | ||||
/*!
 * \brief collect the algos of algo_pack() that match \p target_type
 *
 * An algo is selected when its data type mask shares a bit with
 * target_type.data_type and its category equals target_type.algo_category.
 * The selected algos keep the order they have in algo_pack().
 *
 * \param target_type requested data type and category; data_type must have
 *      exactly one bit set
 * \return the matching algos, empty when nothing matches
 */
SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::select_algo_type(
        ConvAlgoTypePack target_type) {
    //! exactly one data type bit must be set: an empty mask matches nothing
    //! and a mask with several bits would match algos of unrelated types
    megdnn_assert(nr_type_contain(target_type.data_type) == 1,
                  "ConvBias algo selection only support one type");
    SmallVector<ConvBiasImpl::AlgoBase*> algos;
    for (auto&& algo : algo_pack()) {
        auto algo_type = algo->get_algo_type();
        if (contain_data_type(algo_type.data_type, target_type.data_type) &&
            algo_type.algo_category == target_type.algo_category) {
            algos.push_back(algo);
        }
    }
    return algos;
}
bool ConvBiasImpl::is_naive_algo(ConvBiasImpl::Algorithm* algo) { | bool ConvBiasImpl::is_naive_algo(ConvBiasImpl::Algorithm* algo) { | ||||
return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0; | return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0; | ||||
} | } | ||||
@@ -248,12 +277,32 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic( | |||||
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb( | ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb( | ||||
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | bool reproducible) { | ||||
for (auto i : get_all_algorithms_with_ncb(param)) { | |||||
if (static_cast<AlgoBase*>(i)->usable_reproducible( | |||||
param, AlgoSelectionStrategy::HEURISTIC, reproducible) && | |||||
NCB_ALGO_FUNC(get_workspace, i, param) <= | |||||
workspace_limit_in_bytes) { | |||||
return i; | |||||
auto algo_data_type = param.deduce_algo_data_type(); | |||||
auto suggest_category_order = suggest_algo_category_order(param); | |||||
for (auto category : suggest_category_order) { | |||||
auto&& origin_algos = select_algo_type({algo_data_type, category}); | |||||
ConvBiasImpl::Algorithm* heuristic_algo = nullptr; | |||||
for (auto i : origin_algos) { | |||||
bool usable_reproducible = | |||||
static_cast<AlgoBase*>(i)->usable_reproducible( | |||||
param, AlgoSelectionStrategy::HEURISTIC, | |||||
reproducible); | |||||
if (usable_reproducible && | |||||
static_cast<AlgoBase*>(i)->get_workspace(param) <= | |||||
workspace_limit_in_bytes) { | |||||
//! remember the first usable algo; it becomes the result when no | |||||
//! usable algo is preferred | |||||
if (!heuristic_algo) { | |||||
heuristic_algo = i; | |||||
} | |||||
//! return the first preferred algo | |||||
if (i->is_preferred(param)) { | |||||
return i; | |||||
} | |||||
} | |||||
} | |||||
if (heuristic_algo) { | |||||
return heuristic_algo; | |||||
} | } | ||||
} | } | ||||
return nullptr; | return nullptr; | ||||
@@ -300,9 +349,8 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param( | |||||
sizeof(ConvolutionImpl::CanonizedFilterMeta), | sizeof(ConvolutionImpl::CanonizedFilterMeta), | ||||
"sizeof CanonizedFilterMeta in convolution and conv_bias " | "sizeof CanonizedFilterMeta in convolution and conv_bias " | ||||
"should be equal"); | "should be equal"); | ||||
CanonizedFilterMeta fm = check_layout_fwd(src, filter, dst); | |||||
ConvolutionImpl::CanonizedFilterMeta conv_fm; | |||||
conv_fm.copy_from(fm); | |||||
auto&& fm = check_layout_fwd(src, filter, dst); | |||||
auto& conv_fm = reinterpret_cast<ConvolutionImpl::CanonizedFilterMeta&>(fm); | |||||
param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT; | param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT; | ||||
if (param().format == Param::Format::NCHW_WINOGRAD || | if (param().format == Param::Format::NCHW_WINOGRAD || | ||||
@@ -367,7 +415,7 @@ ConvBiasImpl::NCBKernParam ConvBiasImpl::make_ncb_kern_param( | |||||
void ConvBiasImpl::exec_with_ncb_kern(const NCBKernParam& param, | void ConvBiasImpl::exec_with_ncb_kern(const NCBKernParam& param, | ||||
ConvBiasImpl::Algorithm* algo) { | ConvBiasImpl::Algorithm* algo) { | ||||
auto ncb_kerns = NCB_ALGO_FUNC(dispatch_kerns, algo, param); | |||||
auto&& ncb_kerns = NCB_ALGO_FUNC(dispatch_kerns, algo, param); | |||||
for (auto&& kernel : ncb_kerns) { | for (auto&& kernel : ncb_kerns) { | ||||
auto run = [kernel, param](size_t index, size_t thread_id) { | auto run = [kernel, param](size_t index, size_t thread_id) { | ||||
CpuNDRange ndrange_id(kernel.global_size, index); | CpuNDRange ndrange_id(kernel.global_size, index); | ||||
@@ -380,7 +428,7 @@ void ConvBiasImpl::exec_with_ncb_kern(const NCBKernParam& param, | |||||
void ConvBiasImpl::exec_preprocess_with_ncb_kern( | void ConvBiasImpl::exec_preprocess_with_ncb_kern( | ||||
const NCBKernParam& param, ConvBiasImpl::Algorithm* algo) { | const NCBKernParam& param, ConvBiasImpl::Algorithm* algo) { | ||||
auto ncb_kerns = NCB_ALGO_FUNC(dispatch_preprocess_kerns, algo, param); | |||||
auto&& ncb_kerns = NCB_ALGO_FUNC(dispatch_preprocess_kerns, algo, param); | |||||
for (auto&& kernel : ncb_kerns) { | for (auto&& kernel : ncb_kerns) { | ||||
auto run = [kernel, param](size_t index, size_t thread_id) { | auto run = [kernel, param](size_t index, size_t thread_id) { | ||||
CpuNDRange ndrange_id(kernel.global_size, index); | CpuNDRange ndrange_id(kernel.global_size, index); | ||||
@@ -405,7 +453,6 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb( | |||||
} | } | ||||
} | } | ||||
} | } | ||||
std::reverse(prefer_algos.begin(), prefer_algos.end()); | |||||
//! Prefer algo inserted from begin | //! Prefer algo inserted from begin | ||||
algos.insert(algos.begin(), prefer_algos.begin(), prefer_algos.end()); | algos.insert(algos.begin(), prefer_algos.begin(), prefer_algos.end()); | ||||
return algos; | return algos; | ||||
@@ -425,6 +472,35 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm( | |||||
return m_prev_selected_algo; | return m_prev_selected_algo; | ||||
} | } | ||||
/*!
 * \brief suggest the order in which algo categories should be tried
 *
 * Winograd filter formats only get the WINOGRAD category. Otherwise IM2COL is
 * tried before DIRECT when im2col + matmul is preferred, DIRECT before IM2COL
 * when it is not; NAIVE always comes last.
 */
SmallVector<AlgoCategory> ConvBiasImpl::suggest_algo_category_order(
        const NCBKernSizeParam& param) const {
    const auto& meta = param.filter_meta;
    //! TODO: winograd is currently only supported in fast-run
    const bool winograd_format =
            meta.format == param::ConvBias::Format::NCHW_WINOGRAD ||
            meta.format == param::ConvBias::Format::NCHW44_WINOGRAD ||
            meta.format == param::ConvBias::Format::NCHW88_WINOGRAD;
    if (winograd_format) {
        return {AlgoCategory::WINOGRAD};
    }

    //! im2col + matmul: quantized input asks is_matmul_quantized_prefer(),
    //! every other input prefers it for at least 32 channels per group
    bool prefer_im2col =
            param.src_type.category() == DTypeCategory::QUANTIZED
                    ? is_matmul_quantized_prefer(param)
                    : (meta.icpg >= 32 || meta.ocpg >= 32);
    //! conv1x1 always prefers im2col
    if (meta.spatial[0] == 1 && meta.spatial[1] == 1) {
        prefer_im2col = true;
    }

    if (!prefer_im2col) {
        return {AlgoCategory::DIRECT, AlgoCategory::IM2COL,
                AlgoCategory::NAIVE};
    }
    return {AlgoCategory::IM2COL, AlgoCategory::DIRECT, AlgoCategory::NAIVE};
}
const char* ConvBiasImpl::get_algorithm_set_name() const { | const char* ConvBiasImpl::get_algorithm_set_name() const { | ||||
// fallback version 0 | // fallback version 0 | ||||
return "F0"; | return "F0"; | ||||
@@ -18,6 +18,8 @@ | |||||
#include "src/fallback/matrix_mul/opr_impl.h" | #include "src/fallback/matrix_mul/opr_impl.h" | ||||
#include "src/naive/conv_bias/opr_impl.h" | #include "src/naive/conv_bias/opr_impl.h" | ||||
#include <unordered_map> | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace fallback { | namespace fallback { | ||||
@@ -44,6 +46,7 @@ class ConvBiasImpl : public naive::ConvBiasForwardImpl { | |||||
public: | public: | ||||
using naive::ConvBiasForwardImpl::ConvBiasForwardImpl; | using naive::ConvBiasForwardImpl::ConvBiasForwardImpl; | ||||
using AlgoSelectionStrategy = detail::AlgoSelectionStrategy; | using AlgoSelectionStrategy = detail::AlgoSelectionStrategy; | ||||
using AlgoDataType = detail::AlgoDataType; | |||||
//! implemented by exec_with_ncb_kern() | //! implemented by exec_with_ncb_kern() | ||||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | ||||
@@ -94,6 +97,8 @@ public: | |||||
size_t workspace_limit_in_bytes, | size_t workspace_limit_in_bytes, | ||||
bool reproducible) override; | bool reproducible) override; | ||||
//! size param for kernels with non-contiguous batch | //! size param for kernels with non-contiguous batch | ||||
struct NCBKernSizeParam : ConvolutionImpl::NCBKernSizeParam { | struct NCBKernSizeParam : ConvolutionImpl::NCBKernSizeParam { | ||||
NCBKernSizeParam() = default; | NCBKernSizeParam() = default; | ||||
@@ -244,6 +249,9 @@ public: | |||||
return (!reproducible || is_reproducible()) && | return (!reproducible || is_reproducible()) && | ||||
usable(param, algo_selection_strategy); | usable(param, algo_selection_strategy); | ||||
} | } | ||||
//! get the type of the algo | |||||
virtual ConvAlgoTypePack get_algo_type() const = 0; | |||||
}; | }; | ||||
/** | /** | ||||
@@ -251,6 +259,17 @@ public: | |||||
*/ | */ | ||||
virtual SmallVector<AlgoBase*> algo_pack(); | virtual SmallVector<AlgoBase*> algo_pack(); | ||||
/** | |||||
* \brief select algo according to input algo type | |||||
*/ | |||||
SmallVector<AlgoBase*> select_algo_type(ConvAlgoTypePack algo_type); | |||||
/** | |||||
* \brief suggest algo category according to the param | |||||
*/ | |||||
virtual SmallVector<AlgoCategory> suggest_algo_category_order( | |||||
const NCBKernSizeParam& param) const; | |||||
protected: | protected: | ||||
virtual void exec_with_ncb_kern(const NCBKernParam& param, | virtual void exec_with_ncb_kern(const NCBKernParam& param, | ||||
ConvBiasImpl::Algorithm* algo); | ConvBiasImpl::Algorithm* algo); | ||||
@@ -83,6 +83,10 @@ public: | |||||
SmallVector<NCBKern> dispatch_kern( | SmallVector<NCBKern> dispatch_kern( | ||||
const NCBKernSizeParam& /*param*/) const override; | const NCBKernSizeParam& /*param*/) const override; | ||||
//! this algo handles FLOAT32 data and belongs to the NAIVE category
ConvAlgoTypePack get_algo_type() const override {
    ConvAlgoTypePack type_pack = {AlgoDataType::FLOAT32, AlgoCategory::NAIVE};
    return type_pack;
}
}; | }; | ||||
class ConvolutionImpl::AlgoNaive final : public AlgoBase { | class ConvolutionImpl::AlgoNaive final : public AlgoBase { | ||||
@@ -96,11 +100,17 @@ public: | |||||
SmallVector<NCBKern> dispatch_kern( | SmallVector<NCBKern> dispatch_kern( | ||||
const NCBKernSizeParam& /*param*/) const override; | const NCBKernSizeParam& /*param*/) const override; | ||||
//! report the data types (as a combined bit mask) and the category of this
//! algo: INT8X8X16, QINT8X8X32 and QUINT8X8X32, NAIVE
ConvAlgoTypePack get_algo_type() const override {
    const uint32_t supported_mask =
            static_cast<uint32_t>(AlgoDataType::INT8X8X16) |
            static_cast<uint32_t>(AlgoDataType::QINT8X8X32) |
            static_cast<uint32_t>(AlgoDataType::QUINT8X8X32);
    return {static_cast<AlgoDataType>(supported_mask), AlgoCategory::NAIVE};
}
}; | }; | ||||
class ConvolutionImpl::AlgoDefault final : public AlgoBase { | class ConvolutionImpl::AlgoDefault final : public AlgoBase { | ||||
static ConvBiasImpl::NCBKernSizeParam init_conv_bias_param( | |||||
const NCBKernSizeParam& param); | |||||
WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; | WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; | ||||
static SmallVector<NCBKern> get_kimpl(ConvBiasImpl::AlgoBase* algo, | static SmallVector<NCBKern> get_kimpl(ConvBiasImpl::AlgoBase* algo, | ||||
const NCBKernSizeParam& param); | const NCBKernSizeParam& param); | ||||
@@ -136,6 +146,13 @@ public: | |||||
//! select matmul to the highest preference | //! select matmul to the highest preference | ||||
bool is_preferred(const NCBKernSizeParam& param) const override; | bool is_preferred(const NCBKernSizeParam& param) const override; | ||||
static ConvBiasImpl::NCBKernSizeParam init_conv_bias_param( | |||||
const NCBKernSizeParam& param); | |||||
//! forward the query to the wrapped m_algorithm and return its answer
ConvAlgoTypePack get_algo_type() const override {
    ConvAlgoTypePack wrapped_type = m_algorithm->get_algo_type();
    return wrapped_type;
}
private: | private: | ||||
std::string m_name; | std::string m_name; | ||||
ConvBiasImpl::AlgoBase* m_algorithm; | ConvBiasImpl::AlgoBase* m_algorithm; | ||||
@@ -23,6 +23,7 @@ | |||||
#include "midout.h" | #include "midout.h" | ||||
#include <cstring> | #include <cstring> | ||||
#include <unordered_map> | |||||
MIDOUT_DECL(megdnn_fb_convbwd_float) | MIDOUT_DECL(megdnn_fb_convbwd_float) | ||||
@@ -75,6 +76,22 @@ SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::algo_pack() { | |||||
static AlgoPack sl_algo_pack; | static AlgoPack sl_algo_pack; | ||||
return sl_algo_pack.all_algos; | return sl_algo_pack.all_algos; | ||||
} | } | ||||
/*!
 * \brief collect the algos of algo_pack() that match \p target_type
 *
 * An algo is selected when its data type mask shares a bit with
 * target_type.data_type and its category equals target_type.algo_category.
 * The selected algos keep the order they have in algo_pack().
 *
 * \param target_type requested data type and category; data_type must have
 *      exactly one bit set
 * \return the matching algos, empty when nothing matches
 */
SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::select_algo_type(
        ConvAlgoTypePack target_type) {
    //! exactly one data type bit must be set: an empty mask matches nothing
    //! and a mask with several bits would match algos of unrelated types
    megdnn_assert(nr_type_contain(target_type.data_type) == 1,
                  "Convolution algo selection only support one type");
    SmallVector<ConvolutionImpl::AlgoBase*> algos;
    for (auto&& algo : algo_pack()) {
        auto algo_type = algo->get_algo_type();
        if (contain_data_type(algo_type.data_type, target_type.data_type) &&
            algo_type.algo_category == target_type.algo_category) {
            algos.push_back(algo);
        }
    }
    return algos;
}
bool ConvolutionImpl::is_naive_algo(ConvolutionImpl::Algorithm* algo) { | bool ConvolutionImpl::is_naive_algo(ConvolutionImpl::Algorithm* algo) { | ||||
return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0; | return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0; | ||||
} | } | ||||
@@ -249,9 +266,9 @@ ConvolutionImpl::NCBKernParam ConvolutionImpl::make_ncb_kern_param( | |||||
void ConvolutionImpl::exec_preprocess_with_ncb_kern(const NCBKernParam& param, | void ConvolutionImpl::exec_preprocess_with_ncb_kern(const NCBKernParam& param, | ||||
Algorithm* algo) { | Algorithm* algo) { | ||||
auto kerns = NCB_ALGO_FUNC(dispatch_preprocess_kern, algo, param); | |||||
auto fallback_handle = handle(); | |||||
for (auto kernel : kerns) { | |||||
auto&& kerns = NCB_ALGO_FUNC(dispatch_preprocess_kern, algo, param); | |||||
auto&& fallback_handle = handle(); | |||||
for (auto&& kernel : kerns) { | |||||
megdnn_assert( | megdnn_assert( | ||||
param.filter_meta.format == Param::Format::NCHW || | param.filter_meta.format == Param::Format::NCHW || | ||||
param.filter_meta.format == Param::Format::NHWC || | param.filter_meta.format == Param::Format::NHWC || | ||||
@@ -270,9 +287,9 @@ void ConvolutionImpl::exec_preprocess_with_ncb_kern(const NCBKernParam& param, | |||||
void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param, | void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param, | ||||
Algorithm* algo) { | Algorithm* algo) { | ||||
auto kerns = NCB_ALGO_FUNC(dispatch_kern, algo, param); | |||||
auto fallback_handle = handle(); | |||||
for (auto kernel : kerns) { | |||||
auto&& kerns = NCB_ALGO_FUNC(dispatch_kern, algo, param); | |||||
auto&& fallback_handle = handle(); | |||||
for (auto&& kernel : kerns) { | |||||
megdnn_assert( | megdnn_assert( | ||||
param.filter_meta.format == Param::Format::NCHW || | param.filter_meta.format == Param::Format::NCHW || | ||||
param.filter_meta.format == Param::Format::NHWC || | param.filter_meta.format == Param::Format::NHWC || | ||||
@@ -292,13 +309,32 @@ void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param, | |||||
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb( | ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb( | ||||
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | ||||
bool reproducible) { | bool reproducible) { | ||||
for (auto i : get_all_algorithms_with_ncb(param)) { | |||||
bool usable_reproducible = | |||||
static_cast<AlgoBase*>(i)->usable_reproducible( | |||||
param, AlgoSelectionStrategy::HEURISTIC, reproducible); | |||||
if (usable_reproducible && NCB_ALGO_FUNC(get_workspace, i, param) <= | |||||
workspace_limit_in_bytes) { | |||||
return i; | |||||
auto algo_data_type = param.deduce_algo_data_type(); | |||||
auto suggest_category_order = suggest_algo_category_order(param); | |||||
for (auto category : suggest_category_order) { | |||||
auto&& origin_algos = select_algo_type({algo_data_type, category}); | |||||
ConvolutionImpl::Algorithm* heuristic_algo = nullptr; | |||||
for (auto i : origin_algos) { | |||||
bool usable_reproducible = | |||||
static_cast<AlgoBase*>(i)->usable_reproducible( | |||||
param, AlgoSelectionStrategy::HEURISTIC, | |||||
reproducible); | |||||
if (usable_reproducible && | |||||
static_cast<AlgoBase*>(i)->get_workspace(param) <= | |||||
workspace_limit_in_bytes) { | |||||
//! store the first usable algo if no prefer algo, choose it as | |||||
//! the target algo | |||||
if (!heuristic_algo) { | |||||
heuristic_algo = i; | |||||
} | |||||
//! choose the first prefer algo | |||||
if (i->is_preferred(param)) { | |||||
return i; | |||||
} | |||||
} | |||||
} | |||||
if (heuristic_algo) { | |||||
return heuristic_algo; | |||||
} | } | ||||
} | } | ||||
return nullptr; | return nullptr; | ||||
@@ -317,8 +353,6 @@ ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) { | |||||
} | } | ||||
} | } | ||||
} | } | ||||
std::reverse(prefer_algos.begin(), prefer_algos.end()); | |||||
//! Prefer algo inserted from begin | |||||
ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end()); | ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end()); | ||||
return ret; | return ret; | ||||
} | } | ||||
@@ -337,11 +371,45 @@ ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm( | |||||
return m_prev_selected_algo; | return m_prev_selected_algo; | ||||
} | } | ||||
//! Return the order in which algo categories should be tried for \p param. | |||||
//! | |||||
//! The decision is delegated to ConvBias: \p param is converted with | |||||
//! AlgoDefault::init_conv_bias_param() and forwarded to | |||||
//! ConvBiasImpl::suggest_algo_category_order() on the ConvBias operator | |||||
//! obtained from a function-local static CpuOprDelegationStorage<1>. | |||||
SmallVector<AlgoCategory> ConvolutionImpl::suggest_algo_category_order( | |||||
const NCBKernSizeParam& param) const { | |||||
static CpuOprDelegationStorage<1> storage; | |||||
auto conv_bias_opr = storage.get<ConvBias, 0>(); | |||||
auto conv_bias_param = | |||||
ConvolutionImpl::AlgoDefault::init_conv_bias_param(param); | |||||
return static_cast<ConvBiasImpl*>(conv_bias_opr) | |||||
->suggest_algo_category_order(conv_bias_param); | |||||
} | |||||
//! Name of the algorithm set; this implementation always returns "F0". | |||||
const char* ConvolutionImpl::get_algorithm_set_name() const { | const char* ConvolutionImpl::get_algorithm_set_name() const { | ||||
// fallback version 0 | // fallback version 0 | ||||
return "F0"; | return "F0"; | ||||
} | } | ||||
//! Map the dtypes of this param to the AlgoDataType used for algo selection. | |||||
//! | |||||
//! Only src_type and dst_type are inspected: | |||||
//! - Float32 src -> FLOAT32 | |||||
//! - Float16 src -> FLOAT16 (only when MEGDNN_DISABLE_FLOAT16 is not set) | |||||
//! - Int8 / QuantizedS8 src -> INT8X8X16 if dst is Int16, else QINT8X8X32 | |||||
//! - Quantized8Asymm src -> QUINT8X8X32 | |||||
//! Any other src dtype goes through megdnn_throw with a message naming the | |||||
//! src, filter and dst dtypes. | |||||
ConvolutionImpl::AlgoDataType | |||||
ConvolutionImpl::NCBKernSizeParam::deduce_algo_data_type() const { | |||||
if (src_type.enumv() == DTypeEnum::Float32) { | |||||
return ConvolutionImpl::AlgoDataType::FLOAT32; | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
} else if (src_type.enumv() == DTypeEnum::Float16) { | |||||
return ConvolutionImpl::AlgoDataType::FLOAT16; | |||||
#endif | |||||
} else if (src_type.enumv() == DTypeEnum::Int8 || | |||||
src_type.enumv() == DTypeEnum::QuantizedS8) { | |||||
//! 8-bit signed src: the dst dtype decides between the two variants | |||||
if (dst_type.enumv() == DTypeEnum::Int16) { | |||||
return ConvolutionImpl::AlgoDataType::INT8X8X16; | |||||
} else { | |||||
return ConvolutionImpl::AlgoDataType::QINT8X8X32; | |||||
} | |||||
} else if (src_type.enumv() == DTypeEnum::Quantized8Asymm) { | |||||
return ConvolutionImpl::AlgoDataType::QUINT8X8X32; | |||||
} else { | |||||
megdnn_throw(ssprintf("megdnn not support data type of %s * %s -> %s\n", | |||||
src_type.name(), filter_type.name(), | |||||
dst_type.name())); | |||||
} | |||||
} | |||||
/* ===================== ConvolutionBackwardData ===================== */ | /* ===================== ConvolutionBackwardData ===================== */ | ||||
void* const ConvolutionBackwardDataImpl::sm_fallback_deconv_algo_type = | void* const ConvolutionBackwardDataImpl::sm_fallback_deconv_algo_type = | ||||
@@ -10,11 +10,28 @@ | |||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "megdnn/oprs/base.h" | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/fallback/handle.h" | #include "src/fallback/handle.h" | ||||
#include "src/naive/convolution/opr_impl.h" | #include "src/naive/convolution/opr_impl.h" | ||||
namespace megdnn { | namespace megdnn { | ||||
/** | |||||
* \brief Convolution algo category | |||||
*/ | |||||
enum class AlgoCategory : int32_t { | |||||
DIRECT = 0, | |||||
IM2COL = 1, | |||||
WINOGRAD = 2, | |||||
NAIVE = 3, | |||||
}; | |||||
//! Type tag of a convolution algo: the data type(s) it handles plus its | |||||
//! category. Both members are declared as 32-bit bit-fields. | |||||
struct ConvAlgoTypePack { | |||||
//! AlgoDataType value (a bit mask, see detail::AlgoDataType) | |||||
detail::AlgoDataType data_type : 32; | |||||
//! one of the AlgoCategory enumerators | |||||
AlgoCategory algo_category : 32; | |||||
}; | |||||
namespace fallback { | namespace fallback { | ||||
/*! | /*! | ||||
@@ -33,6 +50,7 @@ class ConvolutionImpl : public naive::ConvolutionForwardImpl { | |||||
public: | public: | ||||
using naive::ConvolutionForwardImpl::ConvolutionForwardImpl; | using naive::ConvolutionForwardImpl::ConvolutionForwardImpl; | ||||
using AlgoSelectionStrategy = detail::AlgoSelectionStrategy; | using AlgoSelectionStrategy = detail::AlgoSelectionStrategy; | ||||
using AlgoDataType = detail::AlgoDataType; | |||||
//! implemented by exec_with_ncb_kern() | //! implemented by exec_with_ncb_kern() | ||||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | ||||
@@ -86,6 +104,8 @@ public: | |||||
size_t nr_threads; | size_t nr_threads; | ||||
//! weight_preprocess info | //! weight_preprocess info | ||||
const PreprocessedFilter* preprocessed_filter; | const PreprocessedFilter* preprocessed_filter; | ||||
//! deduce the data type category of this param, used to select the algo | |||||
AlgoDataType deduce_algo_data_type() const; | |||||
}; | }; | ||||
//! memory param for kernels with non-contiguous batch | //! memory param for kernels with non-contiguous batch | ||||
@@ -211,6 +231,9 @@ public: | |||||
return (!reproducible || is_reproducible()) && | return (!reproducible || is_reproducible()) && | ||||
usable(param, algo_selection_strategy); | usable(param, algo_selection_strategy); | ||||
} | } | ||||
//! get the type of the algo | |||||
virtual ConvAlgoTypePack get_algo_type() const = 0; | |||||
}; | }; | ||||
/** | /** | ||||
@@ -218,6 +241,11 @@ public: | |||||
*/ | */ | ||||
virtual SmallVector<AlgoBase*> algo_pack(); | virtual SmallVector<AlgoBase*> algo_pack(); | ||||
/** | |||||
* \brief select algo according to input algo type | |||||
*/ | |||||
SmallVector<AlgoBase*> select_algo_type(ConvAlgoTypePack algo_type); | |||||
protected: | protected: | ||||
virtual void exec_with_ncb_kern(const NCBKernParam& param, Algorithm* algo); | virtual void exec_with_ncb_kern(const NCBKernParam& param, Algorithm* algo); | ||||
@@ -258,6 +286,9 @@ private: | |||||
_megdnn_tensor_out dst, | _megdnn_tensor_out dst, | ||||
const PreprocessedFilter* preprocessed_filter, | const PreprocessedFilter* preprocessed_filter, | ||||
_megdnn_workspace workspace); | _megdnn_workspace workspace); | ||||
SmallVector<AlgoCategory> suggest_algo_category_order( | |||||
const NCBKernSizeParam& param) const; | |||||
}; | }; | ||||
class ConvolutionBackwardDataImpl : public naive::ConvolutionBackwardDataImpl { | class ConvolutionBackwardDataImpl : public naive::ConvolutionBackwardDataImpl { | ||||
@@ -76,7 +76,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern( | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_fb_matmul_f32_kern, | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_fb_matmul_f32_kern, | ||||
5, matmul::fallback::sgemm_8x12, float, | 5, matmul::fallback::sgemm_8x12, float, | ||||
float); | |||||
float, AlgoDataType::FLOAT32, DEFAULT); | |||||
/* ===================== gemv algo ===================== */ | /* ===================== gemv algo ===================== */ | ||||
bool MatrixMulImpl::AlgoGemv::usable( | bool MatrixMulImpl::AlgoGemv::usable( | ||||
@@ -37,7 +37,15 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||||
MEGDNN_OVERRIDE_MATMUL_DESC( | |||||
8, 16, 1, 4, | |||||
static_cast<AlgoDataType>( | |||||
static_cast<uint32_t>(AlgoDataType::FLOAT16) | | |||||
static_cast<uint32_t>(AlgoDataType::FLOAT32) | | |||||
static_cast<uint32_t>(AlgoDataType::INT8X8X16) | | |||||
static_cast<uint32_t>(AlgoDataType::QINT8X8X32) | | |||||
static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)), | |||||
DEFAULT) | |||||
}; | }; | ||||
} // namespace fallback | } // namespace fallback | ||||
@@ -352,13 +352,15 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K, | |||||
DType dtype_c) \ | DType dtype_c) \ | ||||
: A_dtype(dtype_a), B_dtype(dtype_b), C_dtype(dtype_c) {} | : A_dtype(dtype_a), B_dtype(dtype_b), C_dtype(dtype_c) {} | ||||
#define MEGDNN_OVERRIDE_MATMUL_DESC(_m, _n, _k, _packa_type_size) \ | |||||
MatmulDescription matmul_description() const override { \ | |||||
MatmulDescription mdesc; \ | |||||
mdesc.packmode = packmode(); \ | |||||
mdesc.innerblocksize = {_m, _n, _k}; \ | |||||
mdesc.packa_type_size = _packa_type_size; \ | |||||
return mdesc; \ | |||||
//! Define the matmul_description() override of an algo class. | |||||
//! The returned MatmulDescription takes packmode from packmode(), the inner | |||||
//! block size from {_m, _n, _k}, packa_type_size from _packa_type_size and | |||||
//! algo_type from {_data_type, Param::Format::_format}. | |||||
#define MEGDNN_OVERRIDE_MATMUL_DESC(_m, _n, _k, _packa_type_size, _data_type, \ | |||||
_format) \ | |||||
MatmulDescription matmul_description() const override { \ | |||||
MatmulDescription mdesc; \ | |||||
mdesc.packmode = packmode(); \ | |||||
mdesc.innerblocksize = {_m, _n, _k}; \ | |||||
mdesc.packa_type_size = _packa_type_size; \ | |||||
mdesc.algo_type = {_data_type, Param::Format::_format}; \ | |||||
return mdesc; \ | |||||
} | } | ||||
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL() \ | #define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL() \ | ||||
@@ -373,7 +375,7 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K, | |||||
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( \ | #define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( \ | ||||
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \ | _algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \ | ||||
_packa_type) \ | |||||
_packa_type, _support_data_type, _format) \ | |||||
\ | \ | ||||
MatrixMulImpl::kern_naked_t MatrixMulImpl::_algo_name::get_kern_naked( \ | MatrixMulImpl::kern_naked_t MatrixMulImpl::_algo_name::get_kern_naked( \ | ||||
const KernSizeParam&) const { \ | const KernSizeParam&) const { \ | ||||
@@ -474,14 +476,16 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K, | |||||
mdesc.innerblocksize = {_strategy::KERNEL_H, _strategy::KERNEL_W, \ | mdesc.innerblocksize = {_strategy::KERNEL_H, _strategy::KERNEL_W, \ | ||||
_strategy::UNROLL_K}; \ | _strategy::UNROLL_K}; \ | ||||
mdesc.packa_type_size = sizeof(_packa_type); \ | mdesc.packa_type_size = sizeof(_packa_type); \ | ||||
mdesc.algo_type = {_support_data_type, Param::Format::_format}; \ | |||||
return mdesc; \ | return mdesc; \ | ||||
} | } | ||||
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( \ | |||||
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type) \ | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(_algo_name, _midout_name, \ | |||||
_mid_index, _strategy, \ | |||||
_i_type, _c_type, _i_type) | |||||
//! Convenience form of MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL that | |||||
//! passes _i_type as the _packa_type argument and forwards the supported | |||||
//! data type and format unchanged. | |||||
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( \ | |||||
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \ | |||||
_support_data_type, _format) \ | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( \ | |||||
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \ | |||||
_i_type, _support_data_type, _format) | |||||
} // namespace matmul | } // namespace matmul | ||||
} // namespace megdnn | } // namespace megdnn | ||||
@@ -38,6 +38,22 @@ SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() { | |||||
return s_algo_pack.all_algos; | return s_algo_pack.all_algos; | ||||
} | } | ||||
//! Collect the matmul algos from algo_pack() that match \p index. | |||||
//! | |||||
//! An algo is kept when contain_data_type() accepts the data type of its | |||||
//! matmul_description().algo_type against index.data_type and the format in | |||||
//! that description equals index.format. Algos are appended in algo_pack() | |||||
//! order, so the result may be empty. | |||||
//! | |||||
//! \param index data type and format to look for; data_type is checked with | |||||
//! nr_type_contain() and, per the assert message, is expected to name a | |||||
//! single data type. | |||||
//! \return the matching algos | |||||
SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::select_algo_type( | |||||
AlgoTypePack index) { | |||||
megdnn_assert(nr_type_contain(index.data_type), | |||||
"Matmul algo selection only support one type"); | |||||
SmallVector<MatrixMulImpl::AlgoBase*> algos; | |||||
for (auto&& algo : algo_pack()) { | |||||
auto algo_desc = algo->matmul_description(); | |||||
//! both the data type and the format have to match | |||||
if (contain_data_type(algo_desc.algo_type.data_type, | |||||
index.data_type) && | |||||
algo_desc.algo_type.format == index.format) { | |||||
algos.push_back(algo); | |||||
} | |||||
} | |||||
return algos; | |||||
} | |||||
std::vector<MatrixMul::Algorithm*> MatrixMulImpl::get_all_algorithms( | std::vector<MatrixMul::Algorithm*> MatrixMulImpl::get_all_algorithms( | ||||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) { | const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) { | ||||
std::vector<Algorithm*> gemm_algos, gemv_algos; | std::vector<Algorithm*> gemm_algos, gemv_algos; | ||||
@@ -71,17 +87,25 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic( | |||||
"require reproducible algorithm, but given algorithm is not " | "require reproducible algorithm, but given algorithm is not " | ||||
"reproducible"); | "reproducible"); | ||||
} | } | ||||
auto algos = get_all_algorithms(A, B, C); | |||||
AlgoTypePack algo_type; | |||||
algo_type.data_type = kern_size_param.deduce_algo_data_type(); | |||||
algo_type.format = kern_size_param.format; | |||||
auto algos = select_algo_type(algo_type); | |||||
Algorithm *heuristic_algo = nullptr; | |||||
for (auto&& algo : algos) { | for (auto&& algo : algos) { | ||||
if (static_cast<AlgoBase*>(algo)->preferred_reproducible( | |||||
if (static_cast<AlgoBase*>(algo)->usable(kern_size_param) && | |||||
static_cast<AlgoBase*>(algo)->preferred_reproducible( | |||||
kern_size_param, reproducible) && | kern_size_param, reproducible) && | ||||
static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param) <= | static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param) <= | ||||
workspace_limit_in_bytes) { | workspace_limit_in_bytes) { | ||||
return algo; | |||||
if (algo->algoset() == AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { | |||||
return algo; | |||||
} else if (!heuristic_algo) { | |||||
heuristic_algo = algo; | |||||
} | |||||
} | } | ||||
} | } | ||||
return nullptr; | |||||
return heuristic_algo; | |||||
} | } | ||||
MatrixMulImpl::KernSizeParam MatrixMulImpl::make_kern_size_param( | MatrixMulImpl::KernSizeParam MatrixMulImpl::make_kern_size_param( | ||||
@@ -150,4 +174,34 @@ void MatrixMulImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||||
naive::MatrixMulForwardImpl::exec(A, B, C, workspace); | naive::MatrixMulForwardImpl::exec(A, B, C, workspace); | ||||
} | } | ||||
//! Map the dtypes of this param to the AlgoDataType used for algo selection. | |||||
//! | |||||
//! A_type and B_type must have the same enumv (asserted). Then, keyed on | |||||
//! A_type: | |||||
//! - Float32 -> FLOAT32 | |||||
//! - Float16 -> FLOAT16 (only when MEGDNN_DISABLE_FLOAT16 is not set) | |||||
//! - Int8 / QuantizedS8 -> INT8X8X16 if C_type is Int16; otherwise C_type | |||||
//! is asserted to be Int32 or QuantizedS32 and the result is QINT8X8X32 | |||||
//! - Quantized8Asymm -> QUINT8X8X32 | |||||
//! - Int16 -> INT16X16X32 | |||||
//! Any other A_type goes through megdnn_throw with a message naming the | |||||
//! A, B and C dtypes. | |||||
MatrixMulImpl::AlgoDataType | |||||
MatrixMulImpl::KernSizeParam::deduce_algo_data_type() const { | |||||
megdnn_assert(A_type.enumv() == B_type.enumv(), | |||||
"Matmul A type and B type of different ctype\n"); | |||||
if (A_type.enumv() == DTypeEnum::Float32) { | |||||
return MatrixMulImpl::AlgoDataType::FLOAT32; | |||||
#if !MEGDNN_DISABLE_FLOAT16 | |||||
} else if (A_type.enumv() == DTypeEnum::Float16) { | |||||
return MatrixMulImpl::AlgoDataType::FLOAT16; | |||||
#endif | |||||
} else if (A_type.enumv() == DTypeEnum::Int8 || | |||||
A_type.enumv() == DTypeEnum::QuantizedS8) { | |||||
//! 8-bit signed inputs: the output dtype decides between the variants | |||||
if (C_type.enumv() == DTypeEnum::Int16) { | |||||
return MatrixMulImpl::AlgoDataType::INT8X8X16; | |||||
} else { | |||||
megdnn_assert(C_type.enumv() == DTypeEnum::Int32 || | |||||
C_type.enumv() == DTypeEnum::QuantizedS32); | |||||
return MatrixMulImpl::AlgoDataType::QINT8X8X32; | |||||
} | |||||
} else if (A_type.enumv() == DTypeEnum::Quantized8Asymm) { | |||||
return MatrixMulImpl::AlgoDataType::QUINT8X8X32; | |||||
} else if (A_type.enumv() == DTypeEnum::Int16) { | |||||
return MatrixMulImpl::AlgoDataType::INT16X16X32; | |||||
} else { | |||||
megdnn_throw(ssprintf( | |||||
"megdnn matmul not support data type of %s * %s -> %s\n", | |||||
A_type.name(), B_type.name(), C_type.name())); | |||||
} | |||||
} | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -10,14 +10,23 @@ | |||||
* implied. | * implied. | ||||
*/ | */ | ||||
#pragma once | #pragma once | ||||
#include "megdnn/opr_param_defs.h" | |||||
#include "src/common/utils.h" | #include "src/common/utils.h" | ||||
#include "src/naive/matrix_mul/opr_impl.h" | #include "src/naive/matrix_mul/opr_impl.h" | ||||
#include <unordered_map> | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace fallback { | |||||
//! Type tag of a matmul algo: the data type(s) it handles plus the matrix | |||||
//! format. Both members are declared as 32-bit bit-fields. | |||||
struct AlgoTypePack { | |||||
//! AlgoDataType value (a bit mask, see detail::AlgoDataType) | |||||
detail::AlgoDataType data_type : 32; | |||||
//! param::MatrixMul::Format of the algo | |||||
param::MatrixMul::Format format : 32; | |||||
}; | |||||
namespace fallback { | |||||
class MatrixMulImpl : public naive::MatrixMulForwardImpl { | class MatrixMulImpl : public naive::MatrixMulForwardImpl { | ||||
public: | public: | ||||
using naive::MatrixMulForwardImpl::MatrixMulForwardImpl; | using naive::MatrixMulForwardImpl::MatrixMulForwardImpl; | ||||
using AlgoDataType = detail::AlgoDataType; | |||||
bool is_thread_safe() const override { return true; } | bool is_thread_safe() const override { return true; } | ||||
@@ -34,6 +43,8 @@ public: | |||||
bool trA, trB; | bool trA, trB; | ||||
Param::ComputeMode compute_mode; | Param::ComputeMode compute_mode; | ||||
Param::Format format; | Param::Format format; | ||||
//! deduce the data type category of this param, used to select the algo | |||||
AlgoDataType deduce_algo_data_type() const; | |||||
}; | }; | ||||
struct KernParam : public KernSizeParam { | struct KernParam : public KernSizeParam { | ||||
@@ -110,6 +121,7 @@ public: | |||||
//! Static description of a matmul algo, as returned by matmul_description(). | |||||
struct MatmulDescription { | struct MatmulDescription { | ||||
PackMode packmode; | PackMode packmode; | ||||
InnerBlockSize innerblocksize; | InnerBlockSize innerblocksize; | ||||
//! data type and format tag; select_algo_type() filters on this field | |||||
AlgoTypePack algo_type; | |||||
size_t packa_type_size; | size_t packa_type_size; | ||||
}; | }; | ||||
@@ -146,6 +158,11 @@ public: | |||||
*/ | */ | ||||
virtual SmallVector<AlgoBase*> algo_pack(); | virtual SmallVector<AlgoBase*> algo_pack(); | ||||
/** | |||||
* \brief select algo according to input algo type | |||||
*/ | |||||
SmallVector<AlgoBase*> select_algo_type(AlgoTypePack algo_type); | |||||
protected: | protected: | ||||
KernSizeParam make_kern_size_param(const TensorLayout& A, | KernSizeParam make_kern_size_param(const TensorLayout& A, | ||||
const TensorLayout& B, | const TensorLayout& B, | ||||
@@ -48,6 +48,10 @@ public: | |||||
} | } | ||||
void* type() const override; | void* type() const override; | ||||
//! Type tag used for algo selection: FLOAT32 data, DIRECT category. | |||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
/* ===================== direct-stride2 algo ===================== */ | /* ===================== direct-stride2 algo ===================== */ | ||||
@@ -81,6 +85,10 @@ public: | |||||
} | } | ||||
void* type() const override; | void* type() const override; | ||||
//! Type tag used for algo selection: FLOAT32 data, DIRECT category. | |||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
/* =========================== winograd ======================== */ | /* =========================== winograd ======================== */ | ||||
class ConvBiasImpl::AlgoFP32WinogradF63_8x8 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF63_8x8 final : public AlgoBase { | ||||
@@ -96,7 +104,7 @@ public: | |||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
void* type() const override; | void* type() const override; | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||||
}; | }; | ||||
class ConvBiasImpl::AlgoFP32WinogradF23_8x8 final : public AlgoBase { | class ConvBiasImpl::AlgoFP32WinogradF23_8x8 final : public AlgoBase { | ||||
@@ -112,7 +120,7 @@ public: | |||||
return m_name.c_str(); | return m_name.c_str(); | ||||
} | } | ||||
void* type() const override; | void* type() const override; | ||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||||
}; | }; | ||||
/* ===================== matmul algo ===================== */ | /* ===================== matmul algo ===================== */ | ||||
@@ -151,6 +159,9 @@ public: | |||||
} | } | ||||
void* type() const override; | void* type() const override; | ||||
//! Type tag used for algo selection: FLOAT32 data, IM2COL category. | |||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::FLOAT32, AlgoCategory::IM2COL}; | |||||
} | |||||
}; | }; | ||||
#if MEGDNN_X86_WITH_MKL_DNN | #if MEGDNN_X86_WITH_MKL_DNN | ||||
@@ -192,6 +203,10 @@ public: | |||||
return {{kern, {1_z, 1_z, 1_z}}}; | return {{kern, {1_z, 1_z, 1_z}}}; | ||||
} | } | ||||
void* type() const override; | void* type() const override; | ||||
//! Type tag used for algo selection: FLOAT32 data, DIRECT category. | |||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
#endif | #endif | ||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -224,8 +224,6 @@ bool mkldnn_matmul_qint8_preferred( | |||||
const ConvBiasImpl::NCBKernSizeParam& param) { | const ConvBiasImpl::NCBKernSizeParam& param) { | ||||
auto is_preferred = true; | auto is_preferred = true; | ||||
auto&& fm = param.filter_meta; | auto&& fm = param.filter_meta; | ||||
megdnn_assert_internal(fm.group == 1 && fm.dilation[0] == 1 && | |||||
fm.dilation[1] == 1); | |||||
// single channel conv should never use matrix mul | // single channel conv should never use matrix mul | ||||
if (fm.ocpg == 1 || fm.icpg == 1) | if (fm.ocpg == 1 || fm.icpg == 1) | ||||
@@ -34,6 +34,10 @@ public: | |||||
} | } | ||||
void* type() const override; | void* type() const override; | ||||
bool is_preferred(const NCBKernSizeParam& param) const override; | bool is_preferred(const NCBKernSizeParam& param) const override; | ||||
//! Type tag used for algo selection: QINT8X8X32 data, DIRECT category. | |||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
/* ===================== avx2 stride2 chanwise algo ===================== */ | /* ===================== avx2 stride2 chanwise algo ===================== */ | ||||
@@ -55,6 +59,10 @@ public: | |||||
} | } | ||||
void* type() const override; | void* type() const override; | ||||
bool is_preferred(const NCBKernSizeParam& param) const override; | bool is_preferred(const NCBKernSizeParam& param) const override; | ||||
//! Type tag used for algo selection: QINT8X8X32 data, DIRECT category. | |||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
/* ===================== avx2 stride1 direct algo ===================== */ | /* ===================== avx2 stride1 direct algo ===================== */ | ||||
@@ -76,6 +84,10 @@ public: | |||||
} | } | ||||
void* type() const override; | void* type() const override; | ||||
bool is_preferred(const NCBKernSizeParam& param) const override; | bool is_preferred(const NCBKernSizeParam& param) const override; | ||||
//! Type tag used for algo selection: QINT8X8X32 data, DIRECT category. | |||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
/* ================== avx2 int8 direct conv stride2 algo ================== */ | /* ================== avx2 int8 direct conv stride2 algo ================== */ | ||||
@@ -97,6 +109,10 @@ public: | |||||
} | } | ||||
void* type() const override; | void* type() const override; | ||||
bool is_preferred(const NCBKernSizeParam& param) const override; | bool is_preferred(const NCBKernSizeParam& param) const override; | ||||
//! Type tag used for algo selection: QINT8X8X32 data, DIRECT category. | |||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
#if MEGDNN_X86_WITH_MKL_DNN | #if MEGDNN_X86_WITH_MKL_DNN | ||||
@@ -134,6 +150,10 @@ public: | |||||
} | } | ||||
void* type() const override; | void* type() const override; | ||||
bool is_preferred(const NCBKernSizeParam& param) const override; | bool is_preferred(const NCBKernSizeParam& param) const override; | ||||
//! Type tag used for algo selection: QINT8X8X32 data, DIRECT category. | |||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||||
} | |||||
}; | }; | ||||
/* ===================== mkldnn qint8 matmul algo ===================== */ | /* ===================== mkldnn qint8 matmul algo ===================== */ | ||||
class ConvBiasImpl::AlgoMkldnnMatmulQint8 final : public AlgoBase { | class ConvBiasImpl::AlgoMkldnnMatmulQint8 final : public AlgoBase { | ||||
@@ -160,6 +180,10 @@ public: | |||||
bool is_preferred(const NCBKernSizeParam& param) const override; | bool is_preferred(const NCBKernSizeParam& param) const override; | ||||
void* type() const override; | void* type() const override; | ||||
//! Type tag used for algo selection: QINT8X8X32 data, IM2COL category. | |||||
ConvAlgoTypePack get_algo_type() const override { | |||||
return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; | |||||
} | |||||
}; | }; | ||||
#endif | #endif | ||||
@@ -103,10 +103,10 @@ public: | |||||
#endif | #endif | ||||
all_algos.emplace_back(&stride1_direct); | all_algos.emplace_back(&stride1_direct); | ||||
all_algos.emplace_back(&stride2_direct); | all_algos.emplace_back(&stride2_direct); | ||||
all_algos.emplace_back(&avx2_stride1_direct_int8); | |||||
all_algos.emplace_back(&avx2_stride2_direct); | |||||
all_algos.emplace_back(&avx2_stride1_chanwsie_qint8); | all_algos.emplace_back(&avx2_stride1_chanwsie_qint8); | ||||
all_algos.emplace_back(&avx2_stride2_chanwsie_qint8); | all_algos.emplace_back(&avx2_stride2_chanwsie_qint8); | ||||
all_algos.emplace_back(&avx2_stride1_direct_int8); | |||||
all_algos.emplace_back(&avx2_stride2_direct); | |||||
all_algos.emplace_back(&matmul); | all_algos.emplace_back(&matmul); | ||||
static CpuOprDelegationStorage<> storage; | static CpuOprDelegationStorage<> storage; | ||||
@@ -182,4 +182,41 @@ bool ConvBiasImpl::is_matmul_quantized_prefer( | |||||
!chanwise_avx2_stride2_qint8_usable_preferred(param)); | !chanwise_avx2_stride2_qint8_usable_preferred(param)); | ||||
} | } | ||||
//! Return the algo categories to try for \p param, most preferred first. | |||||
//! | |||||
//! - any *_WINOGRAD filter format: only WINOGRAD | |||||
//! - NCHW88 filter format: DIRECT, then IM2COL | |||||
//! - otherwise IM2COL/DIRECT/NAIVE or DIRECT/IM2COL/NAIVE, depending on the | |||||
//! im2col_prefer heuristic computed below | |||||
SmallVector<AlgoCategory> | |||||
ConvBiasImpl::suggest_algo_category_order(const NCBKernSizeParam& param) const { | |||||
auto IC = param.filter_meta.icpg; | |||||
auto OC = param.filter_meta.ocpg; | |||||
auto FH = param.filter_meta.spatial[0]; | |||||
auto FW = param.filter_meta.spatial[1]; | |||||
//! TODO: winograd currently only supports fast-run | |||||
if (param.filter_meta.format == param::ConvBias::Format::NCHW_WINOGRAD || | |||||
param.filter_meta.format == param::ConvBias::Format::NCHW44_WINOGRAD || | |||||
param.filter_meta.format == param::ConvBias::Format::NCHW88_WINOGRAD) { | |||||
return {AlgoCategory::WINOGRAD}; | |||||
} | |||||
//! nchw88 uses mkl-dnn, whose algo is a direct one | |||||
if (param.filter_meta.format == param::ConvBias::Format::NCHW88) { | |||||
return {AlgoCategory::DIRECT, AlgoCategory::IM2COL}; | |||||
} | |||||
//! im2col + matmul is preferred when either per-group channel count is | |||||
//! at least 32 | |||||
bool im2col_prefer = (IC >= 32 || OC >= 32); | |||||
//! for quantized src the rule above is replaced by | |||||
//! is_matmul_quantized_prefer(): matmul is used when the direct algo | |||||
//! is unusable | |||||
if (param.src_type.category() == DTypeCategory::QUANTIZED) { | |||||
im2col_prefer = is_matmul_quantized_prefer(param); | |||||
} | |||||
//! conv1x1 also prefers im2col | |||||
im2col_prefer |= (FH == 1 && FW == 1); | |||||
//! x86 8x8x16 is not optimized, so it uses the fallback im2col+matmul | |||||
if (param.deduce_algo_data_type() == AlgoDataType::INT8X8X16) { | |||||
im2col_prefer = true; | |||||
} | |||||
if (im2col_prefer) { | |||||
return {AlgoCategory::IM2COL, AlgoCategory::DIRECT, | |||||
AlgoCategory::NAIVE}; | |||||
} else { | |||||
return {AlgoCategory::DIRECT, AlgoCategory::IM2COL, | |||||
AlgoCategory::NAIVE}; | |||||
} | |||||
} | |||||
// vim: syntax=cpp.doxygen | // vim: syntax=cpp.doxygen |
@@ -24,6 +24,8 @@ public: | |||||
bool is_thread_safe() const override { return true; } | bool is_thread_safe() const override { return true; } | ||||
SmallVector<AlgoBase*> algo_pack() override; | SmallVector<AlgoBase*> algo_pack() override; | ||||
SmallVector<AlgoCategory> suggest_algo_category_order( | |||||
const NCBKernSizeParam& param) const override; | |||||
class AlgoDirect; | class AlgoDirect; | ||||
class AlgoDirectStride2; | class AlgoDirectStride2; | ||||
@@ -184,11 +184,10 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Vnni::get_kern( | |||||
return int8x8x32_kern_vnni; | return int8x8x32_kern_vnni; | ||||
} | } | ||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32Vnni, | |||||
megdnn_x86_matmul_kern, | |||||
"AlgoInt8x8x32Vnni"_hash, | |||||
x86::matmul::gemm_int8_vnni_12x32x4, | |||||
dt_int8, dt_int32, dt_uint8); | |||||
//! Instantiate MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL for | |||||
//! AlgoInt8x8x32Vnni. The packed-A type (dt_uint8), the supported data type | |||||
//! and the format are three separate macro arguments, so dt_uint8 must be | |||||
//! followed by a comma. | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( | |||||
AlgoInt8x8x32Vnni, megdnn_x86_matmul_kern, "AlgoInt8x8x32Vnni"_hash, | |||||
x86::matmul::gemm_int8_vnni_12x32x4, dt_int8, dt_int32, | |||||
dt_uint8, AlgoDataType::QINT8X8X32, DEFAULT); | |||||
#endif | #endif | ||||
/* ===================== Int8 mkldnn algo ===================== */ | /* ===================== Int8 mkldnn algo ===================== */ | ||||
@@ -397,7 +396,8 @@ size_t MatrixMulImpl::AlgoInt8x8x16AVX2::get_workspace( | |||||
} | } | ||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( | ||||
AlgoInt8x8x16AVX2, megdnn_x86_matmul_kern, "AlgoInt8x8x16AVX2"_hash, | AlgoInt8x8x16AVX2, megdnn_x86_matmul_kern, "AlgoInt8x8x16AVX2"_hash, | ||||
x86::matmul::gemm_avx2_s8s8s16_4x16x2, dt_int8, dt_int16, dt_int16); | |||||
x86::matmul::gemm_avx2_s8s8s16_4x16x2, dt_int8, dt_int16, dt_int16, | |||||
AlgoDataType::INT8X8X16, DEFAULT); | |||||
/*************************AlgoInt8x8x16SSE********************/ | /*************************AlgoInt8x8x16SSE********************/ | ||||
void MatrixMulImpl::AlgoInt8x8x16SSE::gemm_s8s8s16_sse_4x8x2( | void MatrixMulImpl::AlgoInt8x8x16SSE::gemm_s8s8s16_sse_4x8x2( | ||||
@@ -474,7 +474,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x16SSE, | |||||
megdnn_x86_matmul_kern, | megdnn_x86_matmul_kern, | ||||
"AlgoInt8x8x16SSE"_hash, | "AlgoInt8x8x16SSE"_hash, | ||||
x86::matmul::gemm_sse_s8s8s16_4x8x2, | x86::matmul::gemm_sse_s8s8s16_4x8x2, | ||||
dt_int8, dt_int16, dt_int16); | |||||
dt_int8, dt_int16, dt_int16, | |||||
AlgoDataType::INT8X8X16, DEFAULT); | |||||
/*************************AlgoInt8x8x32AVX2M4N16K2********************/ | /*************************AlgoInt8x8x32AVX2M4N16K2********************/ | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_kern( | ||||
@@ -516,7 +517,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_workspace( | |||||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( | MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( | ||||
AlgoInt8x8x32AVX2M4N16K2, megdnn_x86_matmul_kern, | AlgoInt8x8x32AVX2M4N16K2, megdnn_x86_matmul_kern, | ||||
"AlgoInt8x8x32AVX2M4N16K2"_hash, x86::matmul::gemm_avx2_s8s8s32_4x16x2, | "AlgoInt8x8x32AVX2M4N16K2"_hash, x86::matmul::gemm_avx2_s8s8s32_4x16x2, | ||||
dt_int8, dt_int32, dt_int16); | |||||
dt_int8, dt_int32, dt_int16, AlgoDataType::QINT8X8X32, DEFAULT); | |||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_kern( | ||||
const KernSizeParam&) const { | const KernSizeParam&) const { | ||||
@@ -556,7 +557,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32AVX2M2N4K16, | |||||
megdnn_x86_matmul_kern, | megdnn_x86_matmul_kern, | ||||
"AlgoInt8x8x32AVX2M2N4K16"_hash, | "AlgoInt8x8x32AVX2M2N4K16"_hash, | ||||
x86::matmul::gemm_avx2_s8s8s32_2x4x16, | x86::matmul::gemm_avx2_s8s8s32_2x4x16, | ||||
dt_int8, dt_int32); | |||||
dt_int8, dt_int32, | |||||
AlgoDataType::QINT8X8X32, DEFAULT); | |||||
/*************************AlgoInt8x8x32SSEM4N8K2********************/ | /*************************AlgoInt8x8x32SSEM4N8K2********************/ | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_kern( | ||||
@@ -596,7 +598,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32SSEM4N8K2, | |||||
megdnn_x86_matmul_kern, | megdnn_x86_matmul_kern, | ||||
"AlgoInt8x8x32SSEM4N8K2"_hash, | "AlgoInt8x8x32SSEM4N8K2"_hash, | ||||
x86::matmul::gemm_sse_s8s8s32_4x8x2, | x86::matmul::gemm_sse_s8s8s32_4x8x2, | ||||
dt_int8, dt_int32, dt_int16); | |||||
dt_int8, dt_int32, dt_int16, | |||||
AlgoDataType::QINT8X8X32, DEFAULT); | |||||
/*************************AlgoF32MK8_8x8********************/ | /*************************AlgoF32MK8_8x8********************/ | ||||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK8_8x8::get_kern( | MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK8_8x8::get_kern( | ||||
@@ -27,7 +27,7 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_x86_algo_type; } | void* type() const override { return sm_x86_algo_type; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) | |||||
}; | }; | ||||
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM | #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM | ||||
@@ -49,7 +49,7 @@ public: | |||||
WorkspaceBundle get_bundle(const KernSizeParam& param) const override; | WorkspaceBundle get_bundle(const KernSizeParam& param) const override; | ||||
InnerBlockSize get_inner_block_size() const override{ return {8, 16, 1}; }; | InnerBlockSize get_inner_block_size() const override{ return {8, 16, 1}; }; | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) | |||||
}; | }; | ||||
#endif | #endif | ||||
@@ -127,7 +127,7 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_x86_algo_type; } | void* type() const override { return sm_x86_algo_type; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 4) | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 4, AlgoDataType::FLOAT32, MK8) | |||||
}; | }; | ||||
#if MEGDNN_X86_WITH_VNNI | #if MEGDNN_X86_WITH_VNNI | ||||
@@ -153,7 +153,7 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_x86_algo_type; } | void* type() const override { return sm_x86_algo_type; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT) | |||||
}; | }; | ||||
#endif | #endif | ||||
} // namespace x86 | } // namespace x86 | ||||
@@ -495,8 +495,9 @@ class AlgoChooser { | |||||
} | } | ||||
} | } | ||||
mgb_assert(found, | mgb_assert(found, | ||||
"algo got by heuristic not found in " | |||||
"candidate list"); | |||||
"algo %s got by heuristic not found in " | |||||
"candidate list", | |||||
heu->name()); | |||||
return std::move(ret); | return std::move(ret); | ||||
} | } | ||||
@@ -628,7 +629,7 @@ public: | |||||
auto algo = get_algo(ctx); | auto algo = get_algo(ctx); | ||||
size_t workspace = ctx.get_workspace_size_bytes(algo); | size_t workspace = ctx.get_workspace_size_bytes(algo); | ||||
mgb_log_debug( | mgb_log_debug( | ||||
"%s: input shapes (%s %s, %s %s) -> (%s %s): algo=%s " | |||||
"%s:tensor layouts (%s %s, %s %s)->(%s %s) :algo=%s " | |||||
"workspace=%.2fMiB reproducible=%d", | "workspace=%.2fMiB reproducible=%d", | ||||
mgb_opr->dyn_typeinfo()->name, | mgb_opr->dyn_typeinfo()->name, | ||||
layouts[0].to_string().c_str(), | layouts[0].to_string().c_str(), | ||||
@@ -636,8 +637,7 @@ public: | |||||
layouts[1].to_string().c_str(), | layouts[1].to_string().c_str(), | ||||
layouts[1].dtype.name(), | layouts[1].dtype.name(), | ||||
layouts[layouts.size() - 1].to_string().c_str(), | layouts[layouts.size() - 1].to_string().c_str(), | ||||
layouts[layouts.size() - 1].dtype.name(), | |||||
algo->name(), | |||||
layouts[layouts.size() - 1].dtype.name(), algo->name(), | |||||
workspace / (1024 * 1024.0), algo->is_reproducible()); | workspace / (1024 * 1024.0), algo->is_reproducible()); | ||||
megdnn_opr->execution_policy() = {algo}; | megdnn_opr->execution_policy() = {algo}; | ||||
return workspace; | return workspace; | ||||