GitOrigin-RevId: 60d2646bb3
release-1.1
@@ -76,6 +76,18 @@ enum class AlgoSelectionStrategy { | |||
FULL_RUN = 2, | |||
}; | |||
/** | |||
* \brief separate algo by datatype for Matmul and conv | |||
*/ | |||
enum class AlgoDataType : uint32_t { | |||
FLOAT32 = 1 << 0, | |||
FLOAT16 = 1 << 1, | |||
QINT8X8X32 = 1 << 2, | |||
QUINT8X8X32 = 1 << 3, | |||
INT8X8X16 = 1 << 4, | |||
INT16X16X32 = 1 << 5, | |||
}; | |||
/*! | |||
* \brief Abstract representation of an algorithm for implementing | |||
* the operator | |||
@@ -27,6 +27,10 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
@@ -32,6 +32,10 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
} // namespace aarch64 | |||
@@ -45,6 +45,9 @@ public: | |||
return static_cast<ConvBiasImpl*>(conv_bias_opr) | |||
->is_matmul_quantized_prefer(param); | |||
} | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; | |||
} | |||
}; | |||
} // namespace aarch64 | |||
@@ -50,10 +50,9 @@ SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() { | |||
auto&& algos = arm_common::ConvBiasImpl::algo_pack(); | |||
algos.insert(algos.begin(), sl_algo_pack.direct_algos.begin(), | |||
sl_algo_pack.direct_algos.end()); | |||
//! We put matmul algos at the end. Because matmul will get privilege when | |||
//! We put matmul algos at the begin. Because matmul will get privilege when | |||
//! prefer return true. See | |||
//! fallback::ConvolutionImpl::ncb_1g_get_all_algorithms for more details. | |||
algos.insert(algos.end(), sl_algo_pack.matmul_algos.begin(), | |||
algos.insert(algos.begin(), sl_algo_pack.matmul_algos.begin(), | |||
sl_algo_pack.matmul_algos.end()); | |||
return std::move(algos); | |||
} | |||
@@ -45,6 +45,9 @@ public: | |||
return static_cast<ConvBiasImpl*>(conv_bias_opr) | |||
->is_matmul_quantized_prefer(param); | |||
} | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL}; | |||
} | |||
}; | |||
} // namespace aarch64 | |||
} // namespace megdnn | |||
@@ -89,7 +89,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern( | |||
} | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_aarch64_matmul_kern, | |||
"AlgoF32K8x12x1Impl"_hash, | |||
aarch64::matmul::sgemm_8x12, float, float); | |||
aarch64::matmul::sgemm_8x12, float, float, | |||
AlgoDataType::FLOAT32, DEFAULT); | |||
/* ===================== F32_MK4_8X12X1 algo ===================== */ | |||
bool MatrixMulImpl::AlgoF32MK4_8x12x1::usable( | |||
@@ -151,7 +152,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4_8x12x1, | |||
megdnn_aarch64_matmul_kern, | |||
"AlgoF32MK4_8x12x1Impl"_hash, | |||
aarch64::matmul::sgemm_mk4_8x12, float, | |||
float); | |||
float, AlgoDataType::FLOAT32, MK4); | |||
/* ===================== F32K4X16X1 algo ===================== */ | |||
@@ -210,7 +211,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K4x16x1::get_kern( | |||
} | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K4x16x1, megdnn_aarch64_matmul_kern, | |||
"AlgoF32K4x16x1Impl"_hash, | |||
aarch64::matmul::sgemm_4x16, float, float); | |||
aarch64::matmul::sgemm_4x16, float, float, | |||
AlgoDataType::FLOAT32, MK4); | |||
/* ===================== F32MK4_4x16 algo ===================== */ | |||
@@ -328,7 +330,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K8x24x1::get_kern( | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF16K8x24x1, megdnn_aarch64_matmul_kern, | |||
"AlogF16K8x24x1Impl"_hash, | |||
aarch64::matmul::hgemm_8x24, dt_float16, | |||
dt_float16); | |||
dt_float16, AlgoDataType::FLOAT16, | |||
DEFAULT); | |||
/* ===================== F16_MK8_8x8 algo ===================== */ | |||
bool MatrixMulImpl::AlgoF16MK8_8x8::usable( | |||
@@ -449,7 +452,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x12x4DotProd, | |||
megdnn_aarch64_matmul_kern, | |||
"AlgoInt8x8x32K8x12x4DotProdImpl"_hash, | |||
aarch64::matmul::gemm_s8_8x12, int8_t, | |||
int32_t); | |||
int32_t, AlgoDataType::QINT8X8X32, | |||
DEFAULT); | |||
/* =================== Int8x8x32 MK4 8X12X4 Dotprod algo =================== */ | |||
namespace { | |||
@@ -520,7 +524,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x12x4DotProd, | |||
megdnn_aarch64_matmul_kern, | |||
"AlgoInt8x8x32MK4_8x12x4DotProdImpl"_hash, | |||
aarch64::matmul::gemm_mk4_s8_8x12, int8_t, | |||
int32_t); | |||
int32_t, AlgoDataType::QINT8X8X32, | |||
MK4_DOT); | |||
#else | |||
/* ===================== Int8x8x32 MK4 4x4x16 algo ===================== */ | |||
@@ -593,7 +598,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_4x4x16, | |||
megdnn_aarch64_matmul_kern, | |||
"AlgoInt8x8x32MK4_4x4x16Impl"_hash, | |||
aarch64::matmul::gemm_mk4_s8_4x4, int8_t, | |||
int32_t); | |||
int32_t, AlgoDataType::QINT8X8X32, | |||
MK4); | |||
/* ===================== Int8x8x32 K4x4x16 algo ===================== */ | |||
namespace { | |||
@@ -656,7 +662,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x4x16, | |||
megdnn_aarch64_matmul_kern, | |||
"AlgoInt8x8x32K4x4x16Impl"_hash, | |||
aarch64::matmul::gemm_s8_4x4, int8_t, | |||
int32_t); | |||
int32_t, AlgoDataType::QINT8X8X32, | |||
DEFAULT); | |||
/* ===================== Int8x8x32 K8x8x8 algo ===================== */ | |||
namespace { | |||
void int8x8x32_k8x8x8_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
@@ -717,7 +724,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K8x8x8, | |||
megdnn_aarch64_matmul_kern, | |||
"AlgoInt8x8x32K8x8x8Impl"_hash, | |||
aarch64::matmul::gemm_s8_8x8, int8_t, | |||
int32_t); | |||
int32_t, AlgoDataType::QINT8X8X32, | |||
DEFAULT); | |||
#endif | |||
/* ===================== Int8x8x16 K8x8x8 algo ===================== */ | |||
@@ -785,7 +793,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K8x8x8, | |||
megdnn_aarch64_matmul_kern, | |||
"AlgoInt8x8x16K8x8x8Impl"_hash, | |||
aarch64::matmul::gemm_s8x8x16_8x8, int8_t, | |||
int16_t); | |||
int16_t, AlgoDataType::INT8X8X16, DEFAULT); | |||
/* ===================== Int8x8x16 K4x4x16 algo ===================== */ | |||
namespace { | |||
void int8x8x16_k4x4x16_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
@@ -852,7 +860,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x4x16, | |||
megdnn_aarch64_matmul_kern, | |||
"AlgoInt8x8x16K4x4x16Impl"_hash, | |||
aarch64::matmul::gemm_s8x8x16_4x4, int8_t, | |||
int16_t); | |||
int16_t, AlgoDataType::INT8X8X16, DEFAULT); | |||
/* ===================== Int8x8x16 K16x12x4 algo ===================== */ | |||
namespace { | |||
@@ -929,7 +937,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4::get_kern( | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( | |||
AlgoInt8x8x16MK4_16x12x4, megdnn_aarch64_matmul_kern, | |||
"AlgoInt8x8x16MK4_16x12x4Impl"_hash, | |||
aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53, int8_t, int16_t, int16_t); | |||
aarch64::matmul::gemm_s8x8x16_mk4_16x12_a53, int8_t, int16_t, int16_t, | |||
AlgoDataType::INT8X8X16, MK4); | |||
/* ===================== Int8x8x16 MK4 4x4x8 algo ===================== */ | |||
namespace { | |||
@@ -1007,7 +1016,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_4x4x8, | |||
megdnn_aarch64_matmul_kern, | |||
"AlgoInt8x8x16MK4_4x4x8_Impl"_hash, | |||
aarch64::matmul::gemm_s8x8x16_mk4_4x4_a72, | |||
int8_t, int16_t); | |||
int8_t, int16_t, AlgoDataType::INT8X8X16, | |||
MK4); | |||
/* ===================== Int16x16x32 K12x8x1 algo ===================== */ | |||
namespace { | |||
@@ -1078,7 +1088,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x8x1, | |||
megdnn_aarch64_matmul_kern, | |||
"AlgoInt16x16x32K12x8x1Impl"_hash, | |||
aarch64::matmul::gemm_s16_12x8x1, int16_t, | |||
int32_t); | |||
int32_t, AlgoDataType::INT16X16X32, | |||
DEFAULT); | |||
/* ===================== Int16x16x32MK8_8x8 algo ===================== */ | |||
@@ -1201,7 +1212,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x4DotProd, | |||
megdnn_aarch64_matmul_kern, | |||
"AlgoQuint8K8x8x4DotProdImpl"_hash, | |||
aarch64::matmul::gemm_u8_8x8, uint8_t, | |||
int32_t); | |||
int32_t, AlgoDataType::QUINT8X8X32, | |||
DEFAULT); | |||
/* ===================== Quint8 Gemv DotProd algo ===================== */ | |||
namespace { | |||
void quint8_gemv_dotprod_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
@@ -1307,7 +1319,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K8x8x8, | |||
megdnn_aarch64_matmul_kern, | |||
"AlgoQuint8K8x8x8Impl"_hash, | |||
aarch64::matmul::gemm_u8_8x8, uint8_t, | |||
int32_t); | |||
int32_t, AlgoDataType::QUINT8X8X32, | |||
DEFAULT); | |||
#endif | |||
/* ===================== Int8x8x16 K8x8x8 algo ===================== */ | |||
@@ -1378,6 +1391,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8::get_kern( | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16MK4_K8x8x8, | |||
megdnn_aarch64_matmul_kern, | |||
"AlgoInt8x8x16MK4_K8x8x8Impl"_hash, | |||
aarch64::matmul::gemm_s8x8x16_mk4_8x8x8, int8_t, | |||
int16_t); | |||
aarch64::matmul::gemm_s8x8x16_mk4_8x8x8, | |||
int8_t, int16_t, AlgoDataType::INT8X8X16, | |||
MK4); | |||
// vim: syntax=cpp.doxygen |
@@ -61,7 +61,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4) | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4, AlgoDataType::FLOAT32, MK4) | |||
}; | |||
class MatrixMulImpl::AlgoF32Gemv final | |||
@@ -88,7 +88,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2) | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::FLOAT16, MK8) | |||
}; | |||
#endif | |||
@@ -253,7 +253,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2) | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) | |||
}; | |||
#if __ARM_FEATURE_DOTPROD | |||
@@ -281,7 +281,7 @@ public: | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QUINT8X8X32, DEFAULT) | |||
}; | |||
#else | |||
@@ -29,7 +29,7 @@ public: | |||
} | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | |||
}; | |||
class ConvBiasImpl::AlgoFP16WinogradF45 final : public AlgoBase { | |||
@@ -44,7 +44,7 @@ public: | |||
} | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | |||
}; | |||
class ConvBiasImpl::AlgoFP16WinogradF63 final : public AlgoBase { | |||
@@ -60,7 +60,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | |||
}; | |||
class ConvBiasImpl::AlgoFP16WinogradF23_8x8 final : public AlgoBase { | |||
public: | |||
@@ -74,7 +74,7 @@ public: | |||
} | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT16); | |||
}; | |||
class ConvBiasImpl::AlgoF16Direct final : public AlgoBase { | |||
@@ -90,6 +90,10 @@ public: | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override{ | |||
return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoF16DirectStride1 final : public AlgoBase { | |||
@@ -103,6 +107,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT16, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
} // namespace arm_common | |||
@@ -29,7 +29,7 @@ public: | |||
} | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF63 final : public AlgoBase { | |||
@@ -44,7 +44,7 @@ public: | |||
} | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF63_4x4 final : public AlgoBase { | |||
@@ -59,7 +59,7 @@ public: | |||
} | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF54 final : public AlgoBase { | |||
@@ -74,7 +74,7 @@ public: | |||
} | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF45 final : public AlgoBase { | |||
@@ -89,7 +89,7 @@ public: | |||
} | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
}; | |||
//===================== NCHW44 Winograd Support =====================// | |||
@@ -106,7 +106,7 @@ public: | |||
} | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF63_4x4_NCHW44 final : public AlgoBase { | |||
@@ -122,7 +122,7 @@ public: | |||
} | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF73_4x4_NCHW44 final : public AlgoBase { | |||
@@ -138,7 +138,7 @@ public: | |||
} | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
}; | |||
// ================================================================= // | |||
@@ -154,6 +154,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoF32DirectStride1 final : public AlgoBase { | |||
@@ -168,6 +171,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoF32DirectStride2 final : public AlgoBase { | |||
@@ -182,6 +188,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoF32DirectNCHW44 final : public AlgoBase { | |||
@@ -197,6 +206,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoF32DirectNCHWNCHW44 final : public AlgoBase { | |||
@@ -212,6 +224,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoF32ChannelWiseNCHW44 final : public AlgoBase { | |||
@@ -226,6 +241,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
} // namespace arm_common | |||
@@ -29,6 +29,10 @@ public: | |||
const NCBKernSizeParam& param) const override; | |||
bool is_preferred(const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoS8DirectStride2 final : public AlgoBase { | |||
@@ -42,6 +46,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoS8DirectNCHW44 final : public AlgoBase { | |||
@@ -55,6 +62,9 @@ public: | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
bool is_preferred(const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoS8DirectNCHWNCHW44 final : public AlgoBase { | |||
@@ -68,6 +78,9 @@ public: | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
bool is_preferred(const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoS8ChanWiseStride1NCHW44 final : public AlgoBase { | |||
@@ -79,6 +92,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoS8ChanWiseStride2NCHW44 final : public AlgoBase { | |||
@@ -90,6 +106,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
#if __ARM_FEATURE_DOTPROD | |||
@@ -104,6 +123,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam&) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoDotS8DirectStride1 final : public AlgoBase { | |||
@@ -117,6 +139,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam&) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoDotS8DirectStride2 final : public AlgoBase { | |||
@@ -131,6 +156,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam&) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoDotS8Direct_NCHW44 final : public AlgoBase { | |||
@@ -148,6 +176,10 @@ public: | |||
const NCBKernSizeParam& param) const override; | |||
bool is_preferred(const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
#endif | |||
@@ -163,7 +195,7 @@ public: | |||
} | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | |||
}; | |||
//=======================input int8 compute fp32 output int8============ | |||
@@ -180,7 +212,7 @@ public: | |||
} | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | |||
}; | |||
//=======================input int8 compute int16 output int8============ | |||
@@ -198,7 +230,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::QINT8X8X32); | |||
}; | |||
} // namespace arm_common | |||
@@ -36,6 +36,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoS8x8x16DirectNCHW44 final : public AlgoBase { | |||
@@ -48,6 +51,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoI8x8x16Stride2 final : public AlgoBase { | |||
@@ -71,6 +77,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoI8x8x16Stride2Filter2 final : public AlgoBase { | |||
@@ -84,6 +93,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoS8x8x16ChanWiseStride1Stride2NCHW44 final : public AlgoBase { | |||
@@ -96,6 +108,9 @@ public: | |||
const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoI8x8x16DirectNCHWNCHW44 final : public AlgoBase { | |||
@@ -111,6 +126,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::INT8X8X16, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
} // namespace arm_common | |||
@@ -10,6 +10,7 @@ | |||
* implied. | |||
*/ | |||
#include "megdnn/opr_param_defs.h" | |||
#include "src/arm_common/conv_bias/int8/algos.h" | |||
#include "src/arm_common/conv_bias/int8x8x16/algos.h" | |||
#include "src/arm_common/conv_bias/quint8/algos.h" | |||
@@ -122,9 +123,11 @@ public: | |||
static CpuOprDelegationStorage<2> storage; | |||
auto matmul_opr = storage.get<MatrixMul, 0>(); | |||
using MatmulFormat = param::MatrixMul::Format; | |||
auto&& matmul_algos = | |||
static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | |||
->algo_pack(); | |||
->select_algo_type( | |||
{AlgoDataType::FLOAT32, MatmulFormat::MK4}); | |||
for (auto&& algo : matmul_algos) { | |||
if (algo->type() == nullptr) | |||
continue; | |||
@@ -133,38 +136,62 @@ public: | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP32WinogradF63( | |||
refhold.emplace_back(new AlgoFP32WinogradF63_4x4( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP32WinogradF63_4x4( | |||
refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP32WinogradF54( | |||
refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP32WinogradF45( | |||
//! uncomment this when low precision mode is done | |||
#if 0 | |||
refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP32WinogradF23_4x4_NCHW44( | |||
#endif | |||
//! Qint8x8x32 winograd compute with fp32 | |||
refhold.emplace_back(new AlgoS8CF32WinogradF23_4x4_NCHW44( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoFP32WinogradF63_4x4_NCHW44( | |||
} | |||
} | |||
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | |||
->select_algo_type({AlgoDataType::FLOAT32, | |||
MatmulFormat::DEFAULT}); | |||
for (auto&& algo : matmul_algos) { | |||
if (algo->type() == nullptr) | |||
continue; | |||
for (uint32_t tile_size : {16, 8, 24, 32}) { | |||
refhold.emplace_back(new AlgoFP32WinogradF63( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
//! uncomment this when low precision mode is done | |||
#if 0 | |||
refhold.emplace_back(new AlgoFP32WinogradF73_4x4_NCHW44( | |||
refhold.emplace_back(new AlgoFP32WinogradF54( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
#endif | |||
refhold.emplace_back(new AlgoFP32WinogradF45( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
} | |||
} | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | |||
->select_algo_type({AlgoDataType::FLOAT16, | |||
MatmulFormat::DEFAULT}); | |||
for (auto&& algo : matmul_algos) { | |||
if (algo->type() == nullptr) | |||
continue; | |||
for (uint32_t tile_size : {16, 8, 24, 32}) { | |||
refhold.emplace_back(new AlgoFP16WinogradF23( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
@@ -177,19 +204,33 @@ public: | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
} | |||
} | |||
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | |||
->select_algo_type({AlgoDataType::FLOAT16, | |||
MatmulFormat::MK8}); | |||
for (auto&& algo : matmul_algos) { | |||
if (algo->type() == nullptr) | |||
continue; | |||
for (uint32_t tile_size : {16, 8, 24, 32}) { | |||
refhold.emplace_back(new AlgoFP16WinogradF23_8x8( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
} | |||
} | |||
#endif | |||
matmul_algos = static_cast<arm_common::MatrixMulImpl*>(matmul_opr) | |||
->select_algo_type({AlgoDataType::INT16X16X32, | |||
MatmulFormat::MK8}); | |||
for (auto&& algo : matmul_algos) { | |||
if (algo->type() == nullptr) | |||
continue; | |||
for (uint32_t tile_size : {16, 8, 24, 32}) { | |||
refhold.emplace_back(new AlgoS8WinogradF23_8x8( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoS8CF32WinogradF23_4x4_NCHW44( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
winograd_algos.emplace_back(refhold.back().get()); | |||
refhold.emplace_back(new AlgoS8WinogradF23_8x8_NCHW44( | |||
static_cast<fallback::MatrixMulImpl::AlgoBase*>(algo), | |||
tile_size)); | |||
@@ -240,6 +281,42 @@ bool ConvBiasImpl::is_matmul_quantized_prefer( | |||
return conv_direct_unusable; | |||
} | |||
SmallVector<AlgoCategory> ConvBiasImpl::suggest_algo_category_order( | |||
const NCBKernSizeParam& param) const { | |||
auto IC = param.filter_meta.icpg; | |||
auto OC = param.filter_meta.ocpg; | |||
auto FH = param.filter_meta.spatial[0]; | |||
auto FW = param.filter_meta.spatial[1]; | |||
//! TODO: now winograd only support fast-run | |||
if (param.filter_meta.format == param::ConvBias::Format::NCHW_WINOGRAD || | |||
param.filter_meta.format == param::ConvBias::Format::NCHW44_WINOGRAD || | |||
param.filter_meta.format == param::ConvBias::Format::NCHW88_WINOGRAD) { | |||
return {AlgoCategory::WINOGRAD}; | |||
} | |||
//! im2col | |||
bool im2col_prefer = (IC >= 32 || OC >= 32); | |||
//! quantized algo use matmul when direct algo is unusable | |||
if (param.src_type.category() == DTypeCategory::QUANTIZED) { | |||
im2col_prefer = is_matmul_quantized_prefer(param); | |||
} | |||
//! conv1x1 | |||
im2col_prefer |= (FH == 1 && FW == 1); | |||
//! nchw44 and nchw44-dot hybird mode is direct | |||
if (param.filter_meta.format == param::ConvBias::Format::NCHW44 || | |||
param.filter_meta.format == param::ConvBias::Format::NCHW44_DOT) { | |||
if (IC < 4) { | |||
im2col_prefer = false; | |||
} | |||
} | |||
if (im2col_prefer) { | |||
return {AlgoCategory::IM2COL, AlgoCategory::DIRECT, | |||
AlgoCategory::NAIVE}; | |||
} else { | |||
return {AlgoCategory::DIRECT, AlgoCategory::IM2COL, | |||
AlgoCategory::NAIVE}; | |||
} | |||
} | |||
const char* ConvBiasImpl::get_algorithm_set_name() const { | |||
// arm common version 0 | |||
return "AC0"; | |||
@@ -28,6 +28,9 @@ public: | |||
bool is_matmul_quantized_prefer( | |||
const ConvBiasImpl::NCBKernSizeParam& ncb_param) const override; | |||
SmallVector<AlgoCategory> suggest_algo_category_order( | |||
const NCBKernSizeParam& param) const override; | |||
class AlgoPack; | |||
protected: | |||
@@ -90,7 +93,7 @@ private: | |||
class AlgoF16Direct; | |||
class AlgoF16DirectStride1; | |||
#endif | |||
}; | |||
}; | |||
} // namespace arm_common | |||
} // namespace megdnn | |||
@@ -29,6 +29,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoQU8DirectStride2 final : public AlgoBase { | |||
@@ -42,6 +45,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
#if __ARM_FEATURE_DOTPROD | |||
class ConvBiasImpl::AlgoDotU8DirectStride1 final : public AlgoBase { | |||
@@ -56,6 +62,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoDotU8DirectStride2 final : public AlgoBase { | |||
@@ -69,6 +78,9 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
virtual SmallVector<NCBKern> dispatch_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QUINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
#endif | |||
} // namespace arm_common | |||
@@ -26,7 +26,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::INT8X8X16, DEFAULT) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { | |||
@@ -40,7 +40,7 @@ public: | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase { | |||
@@ -54,7 +54,7 @@ public: | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4) | |||
}; | |||
#if __ARM_FEATURE_DOTPROD | |||
@@ -69,7 +69,7 @@ public: | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, MK4_DOT) | |||
}; | |||
#endif | |||
@@ -87,7 +87,7 @@ public: | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) | |||
}; | |||
class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase { | |||
@@ -101,7 +101,7 @@ public: | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4) | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 1, 1, 4, AlgoDataType::FLOAT32, MK4) | |||
}; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
@@ -116,7 +116,7 @@ public: | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::FLOAT16, DEFAULT) | |||
}; | |||
#endif | |||
@@ -131,7 +131,13 @@ public: | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(1, 1, 1, 4) | |||
MEGDNN_OVERRIDE_MATMUL_DESC( | |||
1, 1, 1, 4, | |||
static_cast<AlgoDataType>( | |||
static_cast<uint32_t>(AlgoDataType::FLOAT16) | | |||
static_cast<uint32_t>(AlgoDataType::FLOAT32) | | |||
static_cast<uint32_t>(AlgoDataType::QINT8X8X32)), | |||
DEFAULT) | |||
}; | |||
} // namespace arm_common | |||
@@ -25,7 +25,7 @@ void* const MatrixMulImpl::sm_arm_common_algo_type = | |||
class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
AlgoInt8x8x16 int8x8x16; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
AlgoF16Gemv f16gemv; | |||
AlgoF16Gemv f16gemv; | |||
#endif | |||
AlgoInt8x8x32Gemv int8x8x32_gemv; | |||
AlgoInt8x8x32GemvMK4 int8x8x32_gemv_mk4; | |||
@@ -34,10 +34,11 @@ class MatrixMulImpl::AlgoPack : NonCopyableObj { | |||
#endif | |||
AlgoGevm gevm; | |||
AlgoF32GemvMK4 f32_gemv_mk4; | |||
public: | |||
AlgoPack() { | |||
all_algos.emplace_back(&int8x8x16); | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
all_algos.emplace_back(&f16gemv); | |||
#endif | |||
#if __ARM_FEATURE_DOTPROD | |||
@@ -47,7 +48,7 @@ public: | |||
all_algos.emplace_back(&int8x8x32_gemv_mk4); | |||
all_algos.emplace_back(&f32_gemv_mk4); | |||
all_algos.emplace_back(&gevm); | |||
} | |||
} | |||
SmallVector<AlgoBase*> all_algos; | |||
}; | |||
@@ -37,6 +37,9 @@ public: | |||
size_t group = param.filter_meta.group; | |||
return {{kimpl, {group, 1_z, 1_z}}}; | |||
} | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; | |||
} | |||
}; | |||
} // namespace armv7 | |||
@@ -38,6 +38,10 @@ public: | |||
size_t group = param.filter_meta.group; | |||
return {{kimpl, {group, 1_z, 1_z}}}; | |||
} | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QUINT8X8X32, AlgoCategory::IM2COL}; | |||
} | |||
}; | |||
} // namespace armv7 | |||
@@ -85,7 +85,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32::get_kern( | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32, megdnn_armv7_matmul_kern, | |||
"AlgoF32Impl"_hash, | |||
armv7::matmul::sgemm_4x12, float, float); | |||
armv7::matmul::sgemm_4x12, float, float, | |||
AlgoDataType::FLOAT32, DEFAULT); | |||
/* ===================== F32 algo mk4 K4x12 ===================== */ | |||
@@ -154,7 +155,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32MK4Pack4x12, | |||
megdnn_armv7_matmul_kern, | |||
"AlgoF32MK4Pack4x12"_hash, | |||
armv7::matmul::sgemm_mk4_pack_4x12, float, | |||
float); | |||
float, AlgoDataType::FLOAT32, MK4); | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
/* ===================== F16 K4x16x1 algo ===================== */ | |||
@@ -215,7 +216,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF16K4x16x1::get_kern( | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF16K4x16x1, megdnn_armv7_matmul_kern, | |||
"AlgoF16K4x16x1"_hash, | |||
armv7::matmul::hgemm_4x16, dt_float16, | |||
dt_float16); | |||
dt_float16, AlgoDataType::FLOAT16, | |||
DEFAULT); | |||
#endif | |||
@@ -280,7 +282,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x2x16, | |||
megdnn_armv7_matmul_kern, | |||
"AlgoInt8x8x32K4x2x16"_hash, | |||
armv7::matmul::gemm_s8_4x2, int8_t, | |||
int32_t); | |||
int32_t, AlgoDataType::QINT8X8X32, | |||
DEFAULT); | |||
/* ===================== Int8x8x32 Kernel 4x8x8 algo ===================== */ | |||
namespace { | |||
@@ -342,7 +345,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K4x8x8, | |||
megdnn_armv7_matmul_kern, | |||
"AlgoInt8x8x32K4x8x8"_hash, | |||
armv7::matmul::gemm_s8_4x8, int8_t, | |||
int32_t); | |||
int32_t, AlgoDataType::QINT8X8X32, | |||
DEFAULT); | |||
/* ===================== Quint8 Kernel 4x8x8 algo ===================== */ | |||
namespace { | |||
@@ -402,7 +406,8 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoQuint8K4x8x8::get_kern( | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8K4x8x8, megdnn_armv7_matmul_kern, | |||
"AlgoQuint8K4x8x8"_hash, | |||
armv7::matmul::gemm_u8_4x8, uint8_t, | |||
int32_t); | |||
int32_t, AlgoDataType::QUINT8X8X32, | |||
DEFAULT); | |||
/* ===================== Int8x8x16 Kernel 2x4x16 algo ===================== */ | |||
namespace { | |||
@@ -468,7 +473,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x2x16, | |||
megdnn_armv7_matmul_kern, | |||
"AlgoInt8x8x16K4x2x16"_hash, | |||
armv7::matmul::gemm_s8x8x16_4x2, int8_t, | |||
int16_t); | |||
int16_t, AlgoDataType::INT8X8X16, DEFAULT); | |||
/* ===================== Int8x8x16 Kernel 4x8x8 algo ===================== */ | |||
namespace { | |||
@@ -534,7 +539,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x16K4x8x8, | |||
megdnn_armv7_matmul_kern, | |||
"AlgoInt8x8x16K4x8x8"_hash, | |||
armv7::matmul::gemm_s8x8x16_4x8, int8_t, | |||
int16_t); | |||
int16_t, AlgoDataType::INT8X8X16, DEFAULT); | |||
/* =================== Int8x8x16 Kernel MK4 8x8x4 algo ===================*/ | |||
@@ -602,7 +607,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x16MK4_8x8x4, | |||
megdnn_armv7_matmul_kern, | |||
"AlgoInt8x8x16MK4_8x8x4"_hash, | |||
armv7::matmul::gemm_s8x8x16_mk4_8x8, | |||
int8_t, int16_t, int16_t); | |||
int8_t, int16_t, int16_t, | |||
AlgoDataType::INT8X8X16, MK4); | |||
/* ===================== Int16x16x32 Kernel 12x4x1 algo ===================== */ | |||
@@ -668,7 +674,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt16x16x32K12x4x1, | |||
megdnn_armv7_matmul_kern, | |||
"AlgoInt16x16x32K12x4x1"_hash, | |||
armv7::matmul::gemm_s16x16x32_12x4, | |||
int16_t, int32_t); | |||
int16_t, int32_t, | |||
AlgoDataType::INT16X16X32, DEFAULT); | |||
#if __ARM_FEATURE_DOTPROD | |||
/* ===================== Int8 K6x8x4 algo ===================== */ | |||
namespace { | |||
@@ -724,7 +731,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32K6x8x4, | |||
megdnn_armv7_matmul_kern, | |||
"AlgoInt8x8x32K6x8x4"_hash, | |||
armv7::matmul::gemm_dots8_6x8, int8_t, | |||
int32_t); | |||
int32_t, AlgoDataType::QINT8X8X32, | |||
DEFAULT); | |||
/* ===================== Quint8 K4x8x4 algo ===================== */ | |||
namespace { | |||
void quint8_dot_k4x8x4_kern(const MatrixMulImpl::KernParam& kern_param) { | |||
@@ -786,7 +794,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoQuint8DotK4x8x4, | |||
megdnn_armv7_matmul_kern, | |||
"AlgoQuint8DotK4x8x4"_hash, | |||
armv7::matmul::gemm_dot_quint8_4x8, | |||
uint8_t, int32_t); | |||
uint8_t, int32_t, | |||
AlgoDataType::QUINT8X8X32, DEFAULT); | |||
/* ======================== Int8 MK4 8x4x4 dot algo ======================== */ | |||
namespace { | |||
@@ -854,7 +863,7 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_8x4x4DotProd, | |||
megdnn_armv7_matmul_kern, | |||
"AlgoInt8x8x32MK4_8x4x4DotProd"_hash, | |||
armv7::matmul::gemm_mk4_dots8_8x4, int8_t, | |||
int32_t); | |||
int32_t, AlgoDataType::QINT8X8X32, MK4_DOT); | |||
#endif | |||
/* ===================== F32 algo K4x8 ===================== */ | |||
@@ -1099,6 +1108,6 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32MK4_4x2x16, | |||
megdnn_armv7_matmul_kern, | |||
"AlgoInt8x8x32MK4_4x2x16"_hash, | |||
armv7::matmul::gemm_mk4_s8_4x2, int8_t, | |||
int32_t); | |||
int32_t, AlgoDataType::QINT8X8X32, MK4); | |||
// vim: syntax=cpp.doxygen |
@@ -50,7 +50,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4) | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4, AlgoDataType::FLOAT32, MK4) | |||
}; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
@@ -73,7 +73,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2) | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::FLOAT16, MK8) | |||
}; | |||
#endif | |||
#if __ARM_FEATURE_DOTPROD | |||
@@ -205,7 +205,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2) | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2, AlgoDataType::INT16X16X32, MK8) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { | |||
@@ -18,7 +18,6 @@ namespace armv7 { | |||
class MatrixMulImpl : public arm_common::MatrixMulImpl { | |||
public: | |||
using arm_common::MatrixMulImpl::MatrixMulImpl; | |||
SmallVector<AlgoBase*> algo_pack() override; | |||
private: | |||
@@ -110,6 +110,11 @@ void __log__(LogLevel level, const char* file, const char* func, int line, | |||
} while (0) | |||
#endif // megdnn_ENABLE_LOGGING | |||
template <typename T> | |||
constexpr int32_t cast_int(T data) { | |||
return static_cast<int32_t>(data); | |||
} | |||
/* helper functions */ | |||
/** | |||
* \brief Get the next `stride' index lexicographically. | |||
@@ -187,6 +192,29 @@ std::unique_ptr<T> make_unique(Args&&... args) { | |||
return std::unique_ptr<T>(new T(std::forward<Args>(args)...)); | |||
} | |||
/*! | |||
* \brief check whether the source enum contain the target data type enum | |||
*/ | |||
bool inline contain_data_type(detail::AlgoDataType source, | |||
detail::AlgoDataType target) { | |||
return static_cast<bool>(static_cast<uint32_t>(source) & | |||
static_cast<uint32_t>(target)); | |||
} | |||
/*! | |||
* \brief get the source enum contain the data type number | |||
*/ | |||
template<typename T> | |||
size_t nr_type_contain(T index) { | |||
uint32_t sr_index = static_cast<uint32_t>(index); | |||
size_t nr_type = 0; | |||
while (sr_index != 0) { | |||
nr_type++; | |||
sr_index &= (sr_index - 1); | |||
} | |||
return nr_type; | |||
} | |||
/** | |||
* \brief Aligned workspace bundle. | |||
* | |||
@@ -26,6 +26,16 @@ public: | |||
AlgoSelectionStrategy algo_selection_strategy) const override; | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
auto support_data_type = static_cast<AlgoDataType>( | |||
static_cast<uint32_t>(AlgoDataType::FLOAT16) | | |||
static_cast<uint32_t>(AlgoDataType::FLOAT32) | | |||
static_cast<uint32_t>(AlgoDataType::INT8X8X16) | | |||
static_cast<uint32_t>(AlgoDataType::QINT8X8X32) | | |||
static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)); | |||
return {support_data_type, AlgoCategory::NAIVE}; | |||
} | |||
}; | |||
class ConvBiasImpl::AlgoWinogradF32 final : public AlgoBase { | |||
@@ -46,6 +56,10 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD}; | |||
} | |||
private: | |||
MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
mutable std::string m_name; | |||
@@ -70,6 +84,10 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::WINOGRAD}; | |||
} | |||
private: | |||
MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
mutable std::string m_name; | |||
@@ -94,6 +112,10 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD}; | |||
} | |||
private: | |||
MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
mutable std::string m_name; | |||
@@ -118,6 +140,10 @@ public: | |||
size_t get_workspace(const NCBKernSizeParam& param) const override; | |||
SmallVector<NCBKern> dispatch_kerns(const NCBKernSizeParam&) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::WINOGRAD}; | |||
} | |||
private: | |||
MatrixMulImpl::AlgoBase* m_matmul_algo; | |||
mutable std::string m_name; | |||
@@ -140,7 +140,7 @@ using BiasMode = ConvBiasForward::BiasMode; | |||
break; \ | |||
} | |||
#define MEGDNN_WINOGRAD_ALGO_FUN_DECLARE() \ | |||
#define MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(_algo_data_type) \ | |||
bool is_reproducible() const override { return true; } \ | |||
bool usable(const NCBKernSizeParam& param, \ | |||
AlgoSelectionStrategy algo_selection_strategy) const override; \ | |||
@@ -153,6 +153,9 @@ using BiasMode = ConvBiasForward::BiasMode; | |||
const override; \ | |||
virtual SmallVector<NCBKern> dispatch_preprocess_kerns( \ | |||
const NCBKernSizeParam& param) const override; \ | |||
ConvAlgoTypePack get_algo_type() const override { \ | |||
return {_algo_data_type, AlgoCategory::WINOGRAD}; \ | |||
} \ | |||
\ | |||
private: \ | |||
fallback::MatrixMulImpl::AlgoBase* m_matmul_algo; \ | |||
@@ -288,7 +288,8 @@ bool ConvBiasImpl::AlgoConv1x1::is_preferred( | |||
size_t OH = param.osz[0]; | |||
size_t OW = param.osz[1]; | |||
if (OH * OW != 1) { | |||
return true; | |||
return m_matmul_algo->algoset() != | |||
MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV; | |||
} else { | |||
#if (MEGDNN_ARMV7 || MEGDNN_AARCH64) | |||
if (param.src_type.enumv() == DTypeEnum::Int8 && | |||
@@ -56,6 +56,11 @@ public: | |||
SmallVector<NCBKern> dispatch_preprocess_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override{ | |||
return {m_matmul_algo->matmul_description().algo_type.data_type, | |||
AlgoCategory::IM2COL}; | |||
} | |||
protected: | |||
size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const; | |||
@@ -34,6 +34,16 @@ public: | |||
bool is_preferred(const NCBKernSizeParam&) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
auto support_data_type = static_cast<AlgoDataType>( | |||
static_cast<uint32_t>(AlgoDataType::FLOAT16) | | |||
static_cast<uint32_t>(AlgoDataType::FLOAT32) | | |||
static_cast<uint32_t>(AlgoDataType::INT8X8X16) | | |||
static_cast<uint32_t>(AlgoDataType::QINT8X8X32) | | |||
static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)); | |||
return {support_data_type, AlgoCategory::IM2COL}; | |||
} | |||
protected: | |||
size_t get_oc_tile_size_heuristic(const NCBKernSizeParam& param) const; | |||
}; | |||
@@ -48,15 +48,25 @@ public: | |||
SmallVector<NCBKern> dispatch_preprocess_kerns( | |||
const NCBKernSizeParam& param) const override; | |||
bool is_preferred(const NCBKernSizeParam& param) const override { | |||
if (param.src_type.category() == DTypeCategory::QUANTIZED) { | |||
static CpuOprDelegationStorage<1> storage; | |||
auto conv_bias_opr = storage.get<ConvBias, 0>(); | |||
return static_cast<ConvBiasImpl*>(conv_bias_opr) | |||
->is_matmul_quantized_prefer(param); | |||
size_t OH = param.osz[0]; | |||
size_t OW = param.osz[1]; | |||
//! gemm and oh * ow > 1 is prefer | |||
//! gemv and oh * ow == 1 is prefer | |||
if ((m_matmul_algo->algoset() != | |||
MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV && | |||
OH * OW > 1) || | |||
(m_matmul_algo->algoset() == | |||
MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV && | |||
OH * OW == 1)) { | |||
return true; | |||
} else { | |||
return false; | |||
} | |||
auto&& fm = param.filter_meta; | |||
auto OC = fm.ocpg, IC = fm.icpg; | |||
return OC >= 32 || IC >= 32; | |||
} | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {m_matmul_algo->matmul_description().algo_type.data_type, | |||
AlgoCategory::IM2COL}; | |||
} | |||
private: | |||
@@ -48,11 +48,26 @@ void incr_ptr(T*& dst, ptrdiff_t delta) { | |||
} // namespace | |||
#if MEGDNN_X86 | |||
#define SKIP_GEMV() | |||
//! As we haven't direct conv for int8x8x16 yet, if we disable gemv here, it may | |||
//! fallback to naive implementation, which may cause performance very low, so | |||
//! here we just enable im2col for gemv in x86 backend. | |||
//! FIXME: remove it when we add direct conv support for int8x8x16 | |||
#else | |||
#define SKIP_GEMV() \ | |||
if (algo->algoset() == MatrixMulImpl::AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { \ | |||
continue; \ | |||
} | |||
#endif | |||
class ConvBiasImpl::AlgoPack : NonCopyableObj { | |||
AlgoNaive algo_naive; | |||
SmallVector<std::unique_ptr<AlgoBase>> refhold; | |||
public: | |||
AlgoPack() { | |||
refhold.emplace_back(new AlgoConv1x1Gemv()); | |||
all_algos.emplace_back(refhold.back().get()); | |||
@@ -110,8 +125,6 @@ public: | |||
all_algos.emplace_back(refhold.back().get()); | |||
#endif | |||
} | |||
//! reverse matmul algo, when the algo is_prefer can be selected first | |||
std::reverse(all_algos.begin(), all_algos.end()); | |||
all_algos.emplace_back(&algo_naive); | |||
} | |||
SmallVector<AlgoBase*> all_algos; | |||
@@ -121,6 +134,22 @@ SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::algo_pack() { | |||
static AlgoPack sl_algo_pack; | |||
return sl_algo_pack.all_algos; | |||
} | |||
SmallVector<ConvBiasImpl::AlgoBase*> ConvBiasImpl::select_algo_type( | |||
ConvAlgoTypePack target_type) { | |||
megdnn_assert(nr_type_contain(target_type.data_type), | |||
"ConvBias algo selection only support one type"); | |||
SmallVector<ConvBiasImpl::AlgoBase*> algos; | |||
for (auto&& algo : algo_pack()) { | |||
auto algo_type = algo->get_algo_type(); | |||
if (contain_data_type(algo_type.data_type, target_type.data_type) && | |||
algo_type.algo_category == target_type.algo_category) { | |||
algos.push_back(algo); | |||
} | |||
} | |||
return algos; | |||
} | |||
bool ConvBiasImpl::is_naive_algo(ConvBiasImpl::Algorithm* algo) { | |||
return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0; | |||
} | |||
@@ -248,12 +277,32 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic( | |||
ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm_heuristic_with_ncb( | |||
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | |||
bool reproducible) { | |||
for (auto i : get_all_algorithms_with_ncb(param)) { | |||
if (static_cast<AlgoBase*>(i)->usable_reproducible( | |||
param, AlgoSelectionStrategy::HEURISTIC, reproducible) && | |||
NCB_ALGO_FUNC(get_workspace, i, param) <= | |||
workspace_limit_in_bytes) { | |||
return i; | |||
auto algo_data_type = param.deduce_algo_data_type(); | |||
auto suggest_category_order = suggest_algo_category_order(param); | |||
for (auto category : suggest_category_order) { | |||
auto&& origin_algos = select_algo_type({algo_data_type, category}); | |||
ConvBiasImpl::Algorithm* heuristic_algo = nullptr; | |||
for (auto i : origin_algos) { | |||
bool usable_reproducible = | |||
static_cast<AlgoBase*>(i)->usable_reproducible( | |||
param, AlgoSelectionStrategy::HEURISTIC, | |||
reproducible); | |||
if (usable_reproducible && | |||
static_cast<AlgoBase*>(i)->get_workspace(param) <= | |||
workspace_limit_in_bytes) { | |||
//! store the first usable algo if no prefer algo, choose it as | |||
//! the target algo | |||
if (!heuristic_algo) { | |||
heuristic_algo = i; | |||
} | |||
//! choose the first prefer algo | |||
if (i->is_preferred(param)) { | |||
return i; | |||
} | |||
} | |||
} | |||
if (heuristic_algo) { | |||
return heuristic_algo; | |||
} | |||
} | |||
return nullptr; | |||
@@ -300,9 +349,8 @@ ConvBiasImpl::NCBKernSizeParam ConvBiasImpl::make_ncb_kern_size_param( | |||
sizeof(ConvolutionImpl::CanonizedFilterMeta), | |||
"sizeof CanonizedFilterMeta in convolution and conv_bias " | |||
"should be equal"); | |||
CanonizedFilterMeta fm = check_layout_fwd(src, filter, dst); | |||
ConvolutionImpl::CanonizedFilterMeta conv_fm; | |||
conv_fm.copy_from(fm); | |||
auto&& fm = check_layout_fwd(src, filter, dst); | |||
auto& conv_fm = reinterpret_cast<ConvolutionImpl::CanonizedFilterMeta&>(fm); | |||
param::MatrixMul::Format format = param::MatrixMul::Format::DEFAULT; | |||
if (param().format == Param::Format::NCHW_WINOGRAD || | |||
@@ -367,7 +415,7 @@ ConvBiasImpl::NCBKernParam ConvBiasImpl::make_ncb_kern_param( | |||
void ConvBiasImpl::exec_with_ncb_kern(const NCBKernParam& param, | |||
ConvBiasImpl::Algorithm* algo) { | |||
auto ncb_kerns = NCB_ALGO_FUNC(dispatch_kerns, algo, param); | |||
auto&& ncb_kerns = NCB_ALGO_FUNC(dispatch_kerns, algo, param); | |||
for (auto&& kernel : ncb_kerns) { | |||
auto run = [kernel, param](size_t index, size_t thread_id) { | |||
CpuNDRange ndrange_id(kernel.global_size, index); | |||
@@ -380,7 +428,7 @@ void ConvBiasImpl::exec_with_ncb_kern(const NCBKernParam& param, | |||
void ConvBiasImpl::exec_preprocess_with_ncb_kern( | |||
const NCBKernParam& param, ConvBiasImpl::Algorithm* algo) { | |||
auto ncb_kerns = NCB_ALGO_FUNC(dispatch_preprocess_kerns, algo, param); | |||
auto&& ncb_kerns = NCB_ALGO_FUNC(dispatch_preprocess_kerns, algo, param); | |||
for (auto&& kernel : ncb_kerns) { | |||
auto run = [kernel, param](size_t index, size_t thread_id) { | |||
CpuNDRange ndrange_id(kernel.global_size, index); | |||
@@ -405,7 +453,6 @@ std::vector<ConvBiasImpl::Algorithm*> ConvBiasImpl::get_all_algorithms_with_ncb( | |||
} | |||
} | |||
} | |||
std::reverse(prefer_algos.begin(), prefer_algos.end()); | |||
//! Prefer algo inserted from begin | |||
algos.insert(algos.begin(), prefer_algos.begin(), prefer_algos.end()); | |||
return algos; | |||
@@ -425,6 +472,35 @@ ConvBiasImpl::Algorithm* ConvBiasImpl::get_algorithm( | |||
return m_prev_selected_algo; | |||
} | |||
SmallVector<AlgoCategory> ConvBiasImpl::suggest_algo_category_order( | |||
const NCBKernSizeParam& param) const { | |||
auto IC = param.filter_meta.icpg; | |||
auto OC = param.filter_meta.ocpg; | |||
auto FH = param.filter_meta.spatial[0]; | |||
auto FW = param.filter_meta.spatial[1]; | |||
//! TODO: now winograd only support in fast-run | |||
if (param.filter_meta.format == param::ConvBias::Format::NCHW_WINOGRAD || | |||
param.filter_meta.format == param::ConvBias::Format::NCHW44_WINOGRAD || | |||
param.filter_meta.format == param::ConvBias::Format::NCHW88_WINOGRAD) { | |||
return {AlgoCategory::WINOGRAD}; | |||
} | |||
//! im2col + matmul | |||
bool im2col_prefer = (IC >= 32 || OC >= 32); | |||
//! quantized algo use matmul when direct algo is unusable | |||
if (param.src_type.category() == DTypeCategory::QUANTIZED) { | |||
im2col_prefer = is_matmul_quantized_prefer(param); | |||
} | |||
//! conv1x1 | |||
im2col_prefer |= (FH == 1 && FW == 1); | |||
if (im2col_prefer) { | |||
return {AlgoCategory::IM2COL, AlgoCategory::DIRECT, | |||
AlgoCategory::NAIVE}; | |||
} else { | |||
return {AlgoCategory::DIRECT, AlgoCategory::IM2COL, | |||
AlgoCategory::NAIVE}; | |||
} | |||
} | |||
const char* ConvBiasImpl::get_algorithm_set_name() const { | |||
// fallback version 0 | |||
return "F0"; | |||
@@ -18,6 +18,8 @@ | |||
#include "src/fallback/matrix_mul/opr_impl.h" | |||
#include "src/naive/conv_bias/opr_impl.h" | |||
#include <unordered_map> | |||
namespace megdnn { | |||
namespace fallback { | |||
@@ -44,6 +46,7 @@ class ConvBiasImpl : public naive::ConvBiasForwardImpl { | |||
public: | |||
using naive::ConvBiasForwardImpl::ConvBiasForwardImpl; | |||
using AlgoSelectionStrategy = detail::AlgoSelectionStrategy; | |||
using AlgoDataType = detail::AlgoDataType; | |||
//! implemented by exec_with_ncb_kern() | |||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||
@@ -94,6 +97,8 @@ public: | |||
size_t workspace_limit_in_bytes, | |||
bool reproducible) override; | |||
//! size param for kernels with non-contiguous batch | |||
struct NCBKernSizeParam : ConvolutionImpl::NCBKernSizeParam { | |||
NCBKernSizeParam() = default; | |||
@@ -244,6 +249,9 @@ public: | |||
return (!reproducible || is_reproducible()) && | |||
usable(param, algo_selection_strategy); | |||
} | |||
//! get the type of the algo | |||
virtual ConvAlgoTypePack get_algo_type() const = 0; | |||
}; | |||
/** | |||
@@ -251,6 +259,17 @@ public: | |||
*/ | |||
virtual SmallVector<AlgoBase*> algo_pack(); | |||
/** | |||
* \brief select algo according to input algo type | |||
*/ | |||
SmallVector<AlgoBase*> select_algo_type(ConvAlgoTypePack algo_type); | |||
/** | |||
* \brief suggest algo category according to the param | |||
*/ | |||
virtual SmallVector<AlgoCategory> suggest_algo_category_order( | |||
const NCBKernSizeParam& param) const; | |||
protected: | |||
virtual void exec_with_ncb_kern(const NCBKernParam& param, | |||
ConvBiasImpl::Algorithm* algo); | |||
@@ -83,6 +83,10 @@ public: | |||
SmallVector<NCBKern> dispatch_kern( | |||
const NCBKernSizeParam& /*param*/) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::NAIVE}; | |||
} | |||
}; | |||
class ConvolutionImpl::AlgoNaive final : public AlgoBase { | |||
@@ -96,11 +100,17 @@ public: | |||
SmallVector<NCBKern> dispatch_kern( | |||
const NCBKernSizeParam& /*param*/) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
auto support_data_type = static_cast<AlgoDataType>( | |||
static_cast<uint32_t>(AlgoDataType::INT8X8X16) | | |||
static_cast<uint32_t>(AlgoDataType::QINT8X8X32) | | |||
static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)); | |||
return {support_data_type, AlgoCategory::NAIVE}; | |||
} | |||
}; | |||
class ConvolutionImpl::AlgoDefault final : public AlgoBase { | |||
static ConvBiasImpl::NCBKernSizeParam init_conv_bias_param( | |||
const NCBKernSizeParam& param); | |||
WorkspaceBundle get_bundle(const NCBKernSizeParam& param) const; | |||
static SmallVector<NCBKern> get_kimpl(ConvBiasImpl::AlgoBase* algo, | |||
const NCBKernSizeParam& param); | |||
@@ -136,6 +146,13 @@ public: | |||
//! select matmul to the highest preference | |||
bool is_preferred(const NCBKernSizeParam& param) const override; | |||
static ConvBiasImpl::NCBKernSizeParam init_conv_bias_param( | |||
const NCBKernSizeParam& param); | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return m_algorithm->get_algo_type(); | |||
} | |||
private: | |||
std::string m_name; | |||
ConvBiasImpl::AlgoBase* m_algorithm; | |||
@@ -23,6 +23,7 @@ | |||
#include "midout.h" | |||
#include <cstring> | |||
#include <unordered_map> | |||
MIDOUT_DECL(megdnn_fb_convbwd_float) | |||
@@ -75,6 +76,22 @@ SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::algo_pack() { | |||
static AlgoPack sl_algo_pack; | |||
return sl_algo_pack.all_algos; | |||
} | |||
SmallVector<ConvolutionImpl::AlgoBase*> ConvolutionImpl::select_algo_type( | |||
ConvAlgoTypePack target_type) { | |||
megdnn_assert(nr_type_contain(target_type.data_type), | |||
"ConvBias algo selection only support one type"); | |||
SmallVector<ConvolutionImpl::AlgoBase*> algos; | |||
for (auto&& algo : algo_pack()) { | |||
auto algo_type = algo->get_algo_type(); | |||
if (contain_data_type(algo_type.data_type, target_type.data_type) && | |||
algo_type.algo_category == target_type.algo_category) { | |||
algos.push_back(algo); | |||
} | |||
} | |||
return algos; | |||
} | |||
bool ConvolutionImpl::is_naive_algo(ConvolutionImpl::Algorithm* algo) { | |||
return algo == nullptr || strcmp(algo->name(), "DEFAULT") == 0; | |||
} | |||
@@ -249,9 +266,9 @@ ConvolutionImpl::NCBKernParam ConvolutionImpl::make_ncb_kern_param( | |||
void ConvolutionImpl::exec_preprocess_with_ncb_kern(const NCBKernParam& param, | |||
Algorithm* algo) { | |||
auto kerns = NCB_ALGO_FUNC(dispatch_preprocess_kern, algo, param); | |||
auto fallback_handle = handle(); | |||
for (auto kernel : kerns) { | |||
auto&& kerns = NCB_ALGO_FUNC(dispatch_preprocess_kern, algo, param); | |||
auto&& fallback_handle = handle(); | |||
for (auto&& kernel : kerns) { | |||
megdnn_assert( | |||
param.filter_meta.format == Param::Format::NCHW || | |||
param.filter_meta.format == Param::Format::NHWC || | |||
@@ -270,9 +287,9 @@ void ConvolutionImpl::exec_preprocess_with_ncb_kern(const NCBKernParam& param, | |||
void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param, | |||
Algorithm* algo) { | |||
auto kerns = NCB_ALGO_FUNC(dispatch_kern, algo, param); | |||
auto fallback_handle = handle(); | |||
for (auto kernel : kerns) { | |||
auto&& kerns = NCB_ALGO_FUNC(dispatch_kern, algo, param); | |||
auto&& fallback_handle = handle(); | |||
for (auto&& kernel : kerns) { | |||
megdnn_assert( | |||
param.filter_meta.format == Param::Format::NCHW || | |||
param.filter_meta.format == Param::Format::NHWC || | |||
@@ -292,13 +309,32 @@ void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param, | |||
ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm_heuristic_with_ncb( | |||
const NCBKernSizeParam& param, size_t workspace_limit_in_bytes, | |||
bool reproducible) { | |||
for (auto i : get_all_algorithms_with_ncb(param)) { | |||
bool usable_reproducible = | |||
static_cast<AlgoBase*>(i)->usable_reproducible( | |||
param, AlgoSelectionStrategy::HEURISTIC, reproducible); | |||
if (usable_reproducible && NCB_ALGO_FUNC(get_workspace, i, param) <= | |||
workspace_limit_in_bytes) { | |||
return i; | |||
auto algo_data_type = param.deduce_algo_data_type(); | |||
auto suggest_category_order = suggest_algo_category_order(param); | |||
for (auto category : suggest_category_order) { | |||
auto&& origin_algos = select_algo_type({algo_data_type, category}); | |||
ConvolutionImpl::Algorithm* heuristic_algo = nullptr; | |||
for (auto i : origin_algos) { | |||
bool usable_reproducible = | |||
static_cast<AlgoBase*>(i)->usable_reproducible( | |||
param, AlgoSelectionStrategy::HEURISTIC, | |||
reproducible); | |||
if (usable_reproducible && | |||
static_cast<AlgoBase*>(i)->get_workspace(param) <= | |||
workspace_limit_in_bytes) { | |||
//! store the first usable algo if no prefer algo, choose it as | |||
//! the target algo | |||
if (!heuristic_algo) { | |||
heuristic_algo = i; | |||
} | |||
//! choose the first prefer algo | |||
if (i->is_preferred(param)) { | |||
return i; | |||
} | |||
} | |||
} | |||
if (heuristic_algo) { | |||
return heuristic_algo; | |||
} | |||
} | |||
return nullptr; | |||
@@ -317,8 +353,6 @@ ConvolutionImpl::get_all_algorithms_with_ncb(const NCBKernSizeParam& param) { | |||
} | |||
} | |||
} | |||
std::reverse(prefer_algos.begin(), prefer_algos.end()); | |||
//! Prefer algo inserted from begin | |||
ret.insert(ret.begin(), prefer_algos.begin(), prefer_algos.end()); | |||
return ret; | |||
} | |||
@@ -337,11 +371,45 @@ ConvolutionImpl::Algorithm* ConvolutionImpl::get_algorithm( | |||
return m_prev_selected_algo; | |||
} | |||
SmallVector<AlgoCategory> ConvolutionImpl::suggest_algo_category_order( | |||
const NCBKernSizeParam& param) const { | |||
static CpuOprDelegationStorage<1> storage; | |||
auto conv_bias_opr = storage.get<ConvBias, 0>(); | |||
auto conv_bias_param = | |||
ConvolutionImpl::AlgoDefault::init_conv_bias_param(param); | |||
return static_cast<ConvBiasImpl*>(conv_bias_opr) | |||
->suggest_algo_category_order(conv_bias_param); | |||
} | |||
const char* ConvolutionImpl::get_algorithm_set_name() const { | |||
// fallback version 0 | |||
return "F0"; | |||
} | |||
ConvolutionImpl::AlgoDataType | |||
ConvolutionImpl::NCBKernSizeParam::deduce_algo_data_type() const { | |||
if (src_type.enumv() == DTypeEnum::Float32) { | |||
return ConvolutionImpl::AlgoDataType::FLOAT32; | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
} else if (src_type.enumv() == DTypeEnum::Float16) { | |||
return ConvolutionImpl::AlgoDataType::FLOAT16; | |||
#endif | |||
} else if (src_type.enumv() == DTypeEnum::Int8 || | |||
src_type.enumv() == DTypeEnum::QuantizedS8) { | |||
if (dst_type.enumv() == DTypeEnum::Int16) { | |||
return ConvolutionImpl::AlgoDataType::INT8X8X16; | |||
} else { | |||
return ConvolutionImpl::AlgoDataType::QINT8X8X32; | |||
} | |||
} else if (src_type.enumv() == DTypeEnum::Quantized8Asymm) { | |||
return ConvolutionImpl::AlgoDataType::QUINT8X8X32; | |||
} else { | |||
megdnn_throw(ssprintf("megdnn not support data type of %s * %s -> %s\n", | |||
src_type.name(), filter_type.name(), | |||
dst_type.name())); | |||
} | |||
} | |||
/* ===================== ConvolutionBackwardData ===================== */ | |||
void* const ConvolutionBackwardDataImpl::sm_fallback_deconv_algo_type = | |||
@@ -10,11 +10,28 @@ | |||
*/ | |||
#pragma once | |||
#include "megdnn/oprs/base.h" | |||
#include "src/common/utils.h" | |||
#include "src/fallback/handle.h" | |||
#include "src/naive/convolution/opr_impl.h" | |||
namespace megdnn { | |||
/** | |||
* \brief Convolutino algo category | |||
*/ | |||
enum class AlgoCategory : int32_t { | |||
DIRECT = 0, | |||
IM2COL = 1, | |||
WINOGRAD = 2, | |||
NAIVE = 3, | |||
}; | |||
struct ConvAlgoTypePack { | |||
detail::AlgoDataType data_type : 32; | |||
AlgoCategory algo_category : 32; | |||
}; | |||
namespace fallback { | |||
/*! | |||
@@ -33,6 +50,7 @@ class ConvolutionImpl : public naive::ConvolutionForwardImpl { | |||
public: | |||
using naive::ConvolutionForwardImpl::ConvolutionForwardImpl; | |||
using AlgoSelectionStrategy = detail::AlgoSelectionStrategy; | |||
using AlgoDataType = detail::AlgoDataType; | |||
//! implemented by exec_with_ncb_kern() | |||
void exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, | |||
@@ -86,6 +104,8 @@ public: | |||
size_t nr_threads; | |||
//! weight_preprocess info | |||
const PreprocessedFilter* preprocessed_filter; | |||
//! get the data type category of the param for select the algo | |||
AlgoDataType deduce_algo_data_type() const; | |||
}; | |||
//! memory param for kernels with non-contiguous batch | |||
@@ -211,6 +231,9 @@ public: | |||
return (!reproducible || is_reproducible()) && | |||
usable(param, algo_selection_strategy); | |||
} | |||
//! get the type of the algo | |||
virtual ConvAlgoTypePack get_algo_type() const = 0; | |||
}; | |||
/** | |||
@@ -218,6 +241,11 @@ public: | |||
*/ | |||
virtual SmallVector<AlgoBase*> algo_pack(); | |||
/** | |||
* \brief select algo according to input algo type | |||
*/ | |||
SmallVector<AlgoBase*> select_algo_type(ConvAlgoTypePack algo_type); | |||
protected: | |||
virtual void exec_with_ncb_kern(const NCBKernParam& param, Algorithm* algo); | |||
@@ -258,6 +286,9 @@ private: | |||
_megdnn_tensor_out dst, | |||
const PreprocessedFilter* preprocessed_filter, | |||
_megdnn_workspace workspace); | |||
SmallVector<AlgoCategory> suggest_algo_category_order( | |||
const NCBKernSizeParam& param) const; | |||
}; | |||
class ConvolutionBackwardDataImpl : public naive::ConvolutionBackwardDataImpl { | |||
@@ -76,7 +76,7 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32K8x12x1::get_kern( | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoF32K8x12x1, megdnn_fb_matmul_f32_kern, | |||
5, matmul::fallback::sgemm_8x12, float, | |||
float); | |||
float, AlgoDataType::FLOAT32, DEFAULT); | |||
/* ===================== gemv algo ===================== */ | |||
bool MatrixMulImpl::AlgoGemv::usable( | |||
@@ -37,7 +37,15 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||
MEGDNN_OVERRIDE_MATMUL_DESC( | |||
8, 16, 1, 4, | |||
static_cast<AlgoDataType>( | |||
static_cast<uint32_t>(AlgoDataType::FLOAT16) | | |||
static_cast<uint32_t>(AlgoDataType::FLOAT32) | | |||
static_cast<uint32_t>(AlgoDataType::INT8X8X16) | | |||
static_cast<uint32_t>(AlgoDataType::QINT8X8X32) | | |||
static_cast<uint32_t>(AlgoDataType::QUINT8X8X32)), | |||
DEFAULT) | |||
}; | |||
} // namespace fallback | |||
@@ -352,13 +352,15 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K, | |||
DType dtype_c) \ | |||
: A_dtype(dtype_a), B_dtype(dtype_b), C_dtype(dtype_c) {} | |||
#define MEGDNN_OVERRIDE_MATMUL_DESC(_m, _n, _k, _packa_type_size) \ | |||
MatmulDescription matmul_description() const override { \ | |||
MatmulDescription mdesc; \ | |||
mdesc.packmode = packmode(); \ | |||
mdesc.innerblocksize = {_m, _n, _k}; \ | |||
mdesc.packa_type_size = _packa_type_size; \ | |||
return mdesc; \ | |||
#define MEGDNN_OVERRIDE_MATMUL_DESC(_m, _n, _k, _packa_type_size, _data_type, \ | |||
_format) \ | |||
MatmulDescription matmul_description() const override { \ | |||
MatmulDescription mdesc; \ | |||
mdesc.packmode = packmode(); \ | |||
mdesc.innerblocksize = {_m, _n, _k}; \ | |||
mdesc.packa_type_size = _packa_type_size; \ | |||
mdesc.algo_type = {_data_type, Param::Format::_format}; \ | |||
return mdesc; \ | |||
} | |||
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL() \ | |||
@@ -373,7 +375,7 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K, | |||
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( \ | |||
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \ | |||
_packa_type) \ | |||
_packa_type, _support_data_type, _format) \ | |||
\ | |||
MatrixMulImpl::kern_naked_t MatrixMulImpl::_algo_name::get_kern_naked( \ | |||
const KernSizeParam&) const { \ | |||
@@ -474,14 +476,16 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K, | |||
mdesc.innerblocksize = {_strategy::KERNEL_H, _strategy::KERNEL_W, \ | |||
_strategy::UNROLL_K}; \ | |||
mdesc.packa_type_size = sizeof(_packa_type); \ | |||
mdesc.algo_type = {_support_data_type, Param::Format::_format}; \ | |||
return mdesc; \ | |||
} | |||
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( \ | |||
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type) \ | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(_algo_name, _midout_name, \ | |||
_mid_index, _strategy, \ | |||
_i_type, _c_type, _i_type) | |||
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( \ | |||
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \ | |||
_support_data_type, _format) \ | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( \ | |||
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \ | |||
_i_type, _support_data_type, _format) | |||
} // namespace matmul | |||
} // namespace megdnn | |||
@@ -38,6 +38,22 @@ SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::algo_pack() { | |||
return s_algo_pack.all_algos; | |||
} | |||
SmallVector<MatrixMulImpl::AlgoBase*> MatrixMulImpl::select_algo_type( | |||
AlgoTypePack index) { | |||
megdnn_assert(nr_type_contain(index.data_type), | |||
"Matmul algo selection only support one type"); | |||
SmallVector<MatrixMulImpl::AlgoBase*> algos; | |||
for (auto&& algo : algo_pack()) { | |||
auto algo_desc = algo->matmul_description(); | |||
if (contain_data_type(algo_desc.algo_type.data_type, | |||
index.data_type) && | |||
algo_desc.algo_type.format == index.format) { | |||
algos.push_back(algo); | |||
} | |||
} | |||
return algos; | |||
} | |||
std::vector<MatrixMul::Algorithm*> MatrixMulImpl::get_all_algorithms( | |||
const TensorLayout& A, const TensorLayout& B, const TensorLayout& C) { | |||
std::vector<Algorithm*> gemm_algos, gemv_algos; | |||
@@ -71,17 +87,25 @@ MatrixMul::Algorithm* MatrixMulImpl::get_algorithm_heuristic( | |||
"require reproducible algorithm, but given algorithm is not " | |||
"reproducible"); | |||
} | |||
auto algos = get_all_algorithms(A, B, C); | |||
AlgoTypePack algo_type; | |||
algo_type.data_type = kern_size_param.deduce_algo_data_type(); | |||
algo_type.format = kern_size_param.format; | |||
auto algos = select_algo_type(algo_type); | |||
Algorithm *heuristic_algo = nullptr; | |||
for (auto&& algo : algos) { | |||
if (static_cast<AlgoBase*>(algo)->preferred_reproducible( | |||
if (static_cast<AlgoBase*>(algo)->usable(kern_size_param) && | |||
static_cast<AlgoBase*>(algo)->preferred_reproducible( | |||
kern_size_param, reproducible) && | |||
static_cast<AlgoBase*>(algo)->get_workspace(kern_size_param) <= | |||
workspace_limit_in_bytes) { | |||
return algo; | |||
if (algo->algoset() == AlgoBase::AlgoSet::ALGO_TYPE_GEMV) { | |||
return algo; | |||
} else if (!heuristic_algo) { | |||
heuristic_algo = algo; | |||
} | |||
} | |||
} | |||
return nullptr; | |||
return heuristic_algo; | |||
} | |||
MatrixMulImpl::KernSizeParam MatrixMulImpl::make_kern_size_param( | |||
@@ -150,4 +174,34 @@ void MatrixMulImpl::exec(_megdnn_tensor_in A, _megdnn_tensor_in B, | |||
naive::MatrixMulForwardImpl::exec(A, B, C, workspace); | |||
} | |||
MatrixMulImpl::AlgoDataType | |||
MatrixMulImpl::KernSizeParam::deduce_algo_data_type() const { | |||
megdnn_assert(A_type.enumv() == B_type.enumv(), | |||
"Matmul A type and B type of different ctype\n"); | |||
if (A_type.enumv() == DTypeEnum::Float32) { | |||
return MatrixMulImpl::AlgoDataType::FLOAT32; | |||
#if !MEGDNN_DISABLE_FLOAT16 | |||
} else if (A_type.enumv() == DTypeEnum::Float16) { | |||
return MatrixMulImpl::AlgoDataType::FLOAT16; | |||
#endif | |||
} else if (A_type.enumv() == DTypeEnum::Int8 || | |||
A_type.enumv() == DTypeEnum::QuantizedS8) { | |||
if (C_type.enumv() == DTypeEnum::Int16) { | |||
return MatrixMulImpl::AlgoDataType::INT8X8X16; | |||
} else { | |||
megdnn_assert(C_type.enumv() == DTypeEnum::Int32 || | |||
C_type.enumv() == DTypeEnum::QuantizedS32); | |||
return MatrixMulImpl::AlgoDataType::QINT8X8X32; | |||
} | |||
} else if (A_type.enumv() == DTypeEnum::Quantized8Asymm) { | |||
return MatrixMulImpl::AlgoDataType::QUINT8X8X32; | |||
} else if (A_type.enumv() == DTypeEnum::Int16) { | |||
return MatrixMulImpl::AlgoDataType::INT16X16X32; | |||
} else { | |||
megdnn_throw(ssprintf( | |||
"megdnn matmul not support data type of %s * %s -> %s\n", | |||
A_type.name(), B_type.name(), C_type.name())); | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -10,14 +10,23 @@ | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megdnn/opr_param_defs.h" | |||
#include "src/common/utils.h" | |||
#include "src/naive/matrix_mul/opr_impl.h" | |||
#include <unordered_map> | |||
namespace megdnn { | |||
namespace fallback { | |||
struct AlgoTypePack { | |||
detail::AlgoDataType data_type : 32; | |||
param::MatrixMul::Format format : 32; | |||
}; | |||
namespace fallback { | |||
class MatrixMulImpl : public naive::MatrixMulForwardImpl { | |||
public: | |||
using naive::MatrixMulForwardImpl::MatrixMulForwardImpl; | |||
using AlgoDataType = detail::AlgoDataType; | |||
bool is_thread_safe() const override { return true; } | |||
@@ -34,6 +43,8 @@ public: | |||
bool trA, trB; | |||
Param::ComputeMode compute_mode; | |||
Param::Format format; | |||
//! get the data type category of the param for select the algo | |||
AlgoDataType deduce_algo_data_type() const; | |||
}; | |||
struct KernParam : public KernSizeParam { | |||
@@ -110,6 +121,7 @@ public: | |||
struct MatmulDescription { | |||
PackMode packmode; | |||
InnerBlockSize innerblocksize; | |||
AlgoTypePack algo_type; | |||
size_t packa_type_size; | |||
}; | |||
@@ -146,6 +158,11 @@ public: | |||
*/ | |||
virtual SmallVector<AlgoBase*> algo_pack(); | |||
/** | |||
* \brief select algo according to input algo type | |||
*/ | |||
SmallVector<AlgoBase*> select_algo_type(AlgoTypePack algo_type); | |||
protected: | |||
KernSizeParam make_kern_size_param(const TensorLayout& A, | |||
const TensorLayout& B, | |||
@@ -48,6 +48,10 @@ public: | |||
} | |||
void* type() const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
/* ===================== direct-stride2 algo ===================== */ | |||
@@ -81,6 +85,10 @@ public: | |||
} | |||
void* type() const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
/* =========================== winograd ======================== */ | |||
class ConvBiasImpl::AlgoFP32WinogradF63_8x8 final : public AlgoBase { | |||
@@ -96,7 +104,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
void* type() const override; | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
}; | |||
class ConvBiasImpl::AlgoFP32WinogradF23_8x8 final : public AlgoBase { | |||
@@ -112,7 +120,7 @@ public: | |||
return m_name.c_str(); | |||
} | |||
void* type() const override; | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(); | |||
MEGDNN_WINOGRAD_ALGO_FUN_DECLARE(AlgoDataType::FLOAT32); | |||
}; | |||
/* ===================== matmul algo ===================== */ | |||
@@ -151,6 +159,9 @@ public: | |||
} | |||
void* type() const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::IM2COL}; | |||
} | |||
}; | |||
#if MEGDNN_X86_WITH_MKL_DNN | |||
@@ -192,6 +203,10 @@ public: | |||
return {{kern, {1_z, 1_z, 1_z}}}; | |||
} | |||
void* type() const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::FLOAT32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
#endif | |||
// vim: syntax=cpp.doxygen |
@@ -224,8 +224,6 @@ bool mkldnn_matmul_qint8_preferred( | |||
const ConvBiasImpl::NCBKernSizeParam& param) { | |||
auto is_preferred = true; | |||
auto&& fm = param.filter_meta; | |||
megdnn_assert_internal(fm.group == 1 && fm.dilation[0] == 1 && | |||
fm.dilation[1] == 1); | |||
// single channel conv should never use matrix mul | |||
if (fm.ocpg == 1 || fm.icpg == 1) | |||
@@ -34,6 +34,10 @@ public: | |||
} | |||
void* type() const override; | |||
bool is_preferred(const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
/* ===================== avx2 stride2 chanwise algo ===================== */ | |||
@@ -55,6 +59,10 @@ public: | |||
} | |||
void* type() const override; | |||
bool is_preferred(const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
/* ===================== avx2 stride1 direct algo ===================== */ | |||
@@ -76,6 +84,10 @@ public: | |||
} | |||
void* type() const override; | |||
bool is_preferred(const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
/* ================== avx2 int8 direct conv stride2 algo ================== */ | |||
@@ -97,6 +109,10 @@ public: | |||
} | |||
void* type() const override; | |||
bool is_preferred(const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
#if MEGDNN_X86_WITH_MKL_DNN | |||
@@ -134,6 +150,10 @@ public: | |||
} | |||
void* type() const override; | |||
bool is_preferred(const NCBKernSizeParam& param) const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::DIRECT}; | |||
} | |||
}; | |||
/* ===================== mkldnn qint8 matmul algo ===================== */ | |||
class ConvBiasImpl::AlgoMkldnnMatmulQint8 final : public AlgoBase { | |||
@@ -160,6 +180,10 @@ public: | |||
bool is_preferred(const NCBKernSizeParam& param) const override; | |||
void* type() const override; | |||
ConvAlgoTypePack get_algo_type() const override { | |||
return {AlgoDataType::QINT8X8X32, AlgoCategory::IM2COL}; | |||
} | |||
}; | |||
#endif | |||
@@ -103,10 +103,10 @@ public: | |||
#endif | |||
all_algos.emplace_back(&stride1_direct); | |||
all_algos.emplace_back(&stride2_direct); | |||
all_algos.emplace_back(&avx2_stride1_direct_int8); | |||
all_algos.emplace_back(&avx2_stride2_direct); | |||
all_algos.emplace_back(&avx2_stride1_chanwsie_qint8); | |||
all_algos.emplace_back(&avx2_stride2_chanwsie_qint8); | |||
all_algos.emplace_back(&avx2_stride1_direct_int8); | |||
all_algos.emplace_back(&avx2_stride2_direct); | |||
all_algos.emplace_back(&matmul); | |||
static CpuOprDelegationStorage<> storage; | |||
@@ -182,4 +182,41 @@ bool ConvBiasImpl::is_matmul_quantized_prefer( | |||
!chanwise_avx2_stride2_qint8_usable_preferred(param)); | |||
} | |||
SmallVector<AlgoCategory> | |||
ConvBiasImpl::suggest_algo_category_order(const NCBKernSizeParam& param) const { | |||
auto IC = param.filter_meta.icpg; | |||
auto OC = param.filter_meta.ocpg; | |||
auto FH = param.filter_meta.spatial[0]; | |||
auto FW = param.filter_meta.spatial[1]; | |||
//! TODO: now winograd only support fast-run | |||
if (param.filter_meta.format == param::ConvBias::Format::NCHW_WINOGRAD || | |||
param.filter_meta.format == param::ConvBias::Format::NCHW44_WINOGRAD || | |||
param.filter_meta.format == param::ConvBias::Format::NCHW88_WINOGRAD) { | |||
return {AlgoCategory::WINOGRAD}; | |||
} | |||
//! nchw88 use mkl-dnn which algo is direct | |||
if (param.filter_meta.format == param::ConvBias::Format::NCHW88) { | |||
return {AlgoCategory::DIRECT, AlgoCategory::IM2COL}; | |||
} | |||
//! im2col + matmul | |||
bool im2col_prefer = (IC >= 32 || OC >= 32); | |||
//! quantized algo use matmul when direct algo is unusable | |||
if (param.src_type.category() == DTypeCategory::QUANTIZED) { | |||
im2col_prefer = is_matmul_quantized_prefer(param); | |||
} | |||
//! conv1x1 | |||
im2col_prefer |= (FH == 1 && FW == 1); | |||
//! x86 8x8x16 not optmized, so it will use fallback im2col+matmul | |||
if (param.deduce_algo_data_type() == AlgoDataType::INT8X8X16) { | |||
im2col_prefer = true; | |||
} | |||
if (im2col_prefer) { | |||
return {AlgoCategory::IM2COL, AlgoCategory::DIRECT, | |||
AlgoCategory::NAIVE}; | |||
} else { | |||
return {AlgoCategory::DIRECT, AlgoCategory::IM2COL, | |||
AlgoCategory::NAIVE}; | |||
} | |||
} | |||
// vim: syntax=cpp.doxygen |
@@ -24,6 +24,8 @@ public: | |||
bool is_thread_safe() const override { return true; } | |||
SmallVector<AlgoBase*> algo_pack() override; | |||
SmallVector<AlgoCategory> suggest_algo_category_order( | |||
const NCBKernSizeParam& param) const override; | |||
class AlgoDirect; | |||
class AlgoDirectStride2; | |||
@@ -184,11 +184,10 @@ MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32Vnni::get_kern( | |||
return int8x8x32_kern_vnni; | |||
} | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32Vnni, | |||
megdnn_x86_matmul_kern, | |||
"AlgoInt8x8x32Vnni"_hash, | |||
x86::matmul::gemm_int8_vnni_12x32x4, | |||
dt_int8, dt_int32, dt_uint8); | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( | |||
AlgoInt8x8x32Vnni, megdnn_x86_matmul_kern, "AlgoInt8x8x32Vnni"_hash, | |||
x86::matmul::gemm_int8_vnni_12x32x4, dt_int8, dt_int32, | |||
dt_uint8AlgoDataType::QINT8X8X32, DEFAULT); | |||
#endif | |||
/* ===================== Int8 mkldnn algo ===================== */ | |||
@@ -397,7 +396,8 @@ size_t MatrixMulImpl::AlgoInt8x8x16AVX2::get_workspace( | |||
} | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( | |||
AlgoInt8x8x16AVX2, megdnn_x86_matmul_kern, "AlgoInt8x8x16AVX2"_hash, | |||
x86::matmul::gemm_avx2_s8s8s16_4x16x2, dt_int8, dt_int16, dt_int16); | |||
x86::matmul::gemm_avx2_s8s8s16_4x16x2, dt_int8, dt_int16, dt_int16, | |||
AlgoDataType::INT8X8X16, DEFAULT); | |||
/*************************AlgoInt8x8x16SSE********************/ | |||
void MatrixMulImpl::AlgoInt8x8x16SSE::gemm_s8s8s16_sse_4x8x2( | |||
@@ -474,7 +474,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x16SSE, | |||
megdnn_x86_matmul_kern, | |||
"AlgoInt8x8x16SSE"_hash, | |||
x86::matmul::gemm_sse_s8s8s16_4x8x2, | |||
dt_int8, dt_int16, dt_int16); | |||
dt_int8, dt_int16, dt_int16, | |||
AlgoDataType::INT8X8X16, DEFAULT); | |||
/*************************AlgoInt8x8x32AVX2M4N16K2********************/ | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_kern( | |||
@@ -516,7 +517,7 @@ size_t MatrixMulImpl::AlgoInt8x8x32AVX2M4N16K2::get_workspace( | |||
MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( | |||
AlgoInt8x8x32AVX2M4N16K2, megdnn_x86_matmul_kern, | |||
"AlgoInt8x8x32AVX2M4N16K2"_hash, x86::matmul::gemm_avx2_s8s8s32_4x16x2, | |||
dt_int8, dt_int32, dt_int16); | |||
dt_int8, dt_int32, dt_int16, AlgoDataType::QINT8X8X32, DEFAULT); | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32AVX2M2N4K16::get_kern( | |||
const KernSizeParam&) const { | |||
@@ -556,7 +557,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL(AlgoInt8x8x32AVX2M2N4K16, | |||
megdnn_x86_matmul_kern, | |||
"AlgoInt8x8x32AVX2M2N4K16"_hash, | |||
x86::matmul::gemm_avx2_s8s8s32_2x4x16, | |||
dt_int8, dt_int32); | |||
dt_int8, dt_int32, | |||
AlgoDataType::QINT8X8X32, DEFAULT); | |||
/*************************AlgoInt8x8x32SSEM4N8K2********************/ | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoInt8x8x32SSEM4N8K2::get_kern( | |||
@@ -596,7 +598,8 @@ MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL(AlgoInt8x8x32SSEM4N8K2, | |||
megdnn_x86_matmul_kern, | |||
"AlgoInt8x8x32SSEM4N8K2"_hash, | |||
x86::matmul::gemm_sse_s8s8s32_4x8x2, | |||
dt_int8, dt_int32, dt_int16); | |||
dt_int8, dt_int32, dt_int16, | |||
AlgoDataType::QINT8X8X32, DEFAULT); | |||
/*************************AlgoF32MK8_8x8********************/ | |||
MatrixMulImpl::kern_t MatrixMulImpl::AlgoF32MK8_8x8::get_kern( | |||
@@ -27,7 +27,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_x86_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) | |||
}; | |||
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM | |||
@@ -49,7 +49,7 @@ public: | |||
WorkspaceBundle get_bundle(const KernSizeParam& param) const override; | |||
InnerBlockSize get_inner_block_size() const override{ return {8, 16, 1}; }; | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4, AlgoDataType::FLOAT32, DEFAULT) | |||
}; | |||
#endif | |||
@@ -127,7 +127,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_x86_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 4) | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 4, AlgoDataType::FLOAT32, MK8) | |||
}; | |||
#if MEGDNN_X86_WITH_VNNI | |||
@@ -153,7 +153,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_x86_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2, AlgoDataType::QINT8X8X32, DEFAULT) | |||
}; | |||
#endif | |||
} // namespace x86 | |||
@@ -495,8 +495,9 @@ class AlgoChooser { | |||
} | |||
} | |||
mgb_assert(found, | |||
"algo got by heuristic not found in " | |||
"candidate list"); | |||
"algo %s got by heuristic not found in " | |||
"candidate list", | |||
heu->name()); | |||
return std::move(ret); | |||
} | |||
@@ -628,7 +629,7 @@ public: | |||
auto algo = get_algo(ctx); | |||
size_t workspace = ctx.get_workspace_size_bytes(algo); | |||
mgb_log_debug( | |||
"%s: input shapes (%s %s, %s %s) -> (%s %s): algo=%s " | |||
"%s:tensor layouts (%s %s, %s %s)->(%s %s) :algo=%s " | |||
"workspace=%.2fMiB reproducible=%d", | |||
mgb_opr->dyn_typeinfo()->name, | |||
layouts[0].to_string().c_str(), | |||
@@ -636,8 +637,7 @@ public: | |||
layouts[1].to_string().c_str(), | |||
layouts[1].dtype.name(), | |||
layouts[layouts.size() - 1].to_string().c_str(), | |||
layouts[layouts.size() - 1].dtype.name(), | |||
algo->name(), | |||
layouts[layouts.size() - 1].dtype.name(), algo->name(), | |||
workspace / (1024 * 1024.0), algo->is_reproducible()); | |||
megdnn_opr->execution_policy() = {algo}; | |||
return workspace; | |||