GitOrigin-RevId: 5bde0b60f0
release-0.6
@@ -60,6 +60,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4) | |||
}; | |||
class MatrixMulImpl::AlgoF32Gemv final | |||
@@ -86,6 +87,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2) | |||
}; | |||
#endif | |||
@@ -207,6 +209,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2) | |||
}; | |||
#if __ARM_FEATURE_DOTPROD | |||
@@ -234,6 +237,7 @@ public: | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||
}; | |||
#else | |||
@@ -12,6 +12,7 @@ | |||
#pragma once | |||
#include "src/arm_common/matrix_mul/opr_impl.h" | |||
#include "src/fallback/matrix_mul/gemm_common.h" | |||
namespace megdnn { | |||
namespace arm_common { | |||
@@ -25,6 +26,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { | |||
@@ -38,6 +40,7 @@ public: | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||
}; | |||
class MatrixMulImpl::AlgoF32Gemv : public AlgoBase { | |||
@@ -54,6 +57,7 @@ public: | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||
}; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
@@ -68,6 +72,7 @@ public: | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||
}; | |||
#endif | |||
@@ -82,6 +87,7 @@ public: | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||
}; | |||
@@ -49,6 +49,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4) | |||
}; | |||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | |||
@@ -71,6 +72,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2) | |||
}; | |||
#endif | |||
#if __ARM_FEATURE_DOTPROD | |||
@@ -190,6 +192,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_arm_common_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2) | |||
}; | |||
class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { | |||
@@ -47,14 +47,17 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||
} | |||
//! packA_kern | |||
static void packA_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
StrategyBase* im2colstrategy, size_t pack_oc_size) { | |||
static void packA_kern( | |||
WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
StrategyBase* im2colstrategy, | |||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc, | |||
size_t pack_oc_size) { | |||
im2colstrategy->packA_kern(bundle, param, matmulparam, matmul_algo, | |||
ncb_index, pack_oc_size); | |||
ncb_index, matmul_desc, pack_oc_size); | |||
} | |||
/*! | |||
@@ -72,7 +75,8 @@ public: | |||
WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
const ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc, | |||
StrategyParam strategyparam, | |||
fallback::ConvBiasImpl::NCBKernIndex ncb_index, | |||
size_t ohw_tile_size, StrategyBase* im2colstrategy) { | |||
@@ -111,7 +115,8 @@ public: | |||
//! 2.packb and matmul compute | |||
im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread, | |||
matmul_param, matmul_algo, ncb_index); | |||
matmul_param, matmul_algo, ncb_index, | |||
matmul_desc); | |||
//! 3.postprocess and copy dst if need | |||
im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread); | |||
@@ -151,7 +156,8 @@ public: | |||
WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
const ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc, | |||
StrategyParam strategyparam, | |||
fallback::ConvBiasImpl::NCBKernIndex ncb_index, | |||
size_t ohw_tile_size, StrategyBase* im2colstrategy) { | |||
@@ -191,7 +197,8 @@ public: | |||
//! 2.packb and matmul compute | |||
im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread, | |||
matmul_param, matmul_algo, ncb_index); | |||
matmul_param, matmul_algo, ncb_index, | |||
matmul_desc); | |||
//! 3.postprocess and copy dst if need | |||
im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread); | |||
@@ -232,7 +239,8 @@ public: | |||
WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
const ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc, | |||
StrategyParam strategyparam, | |||
fallback::ConvBiasImpl::NCBKernIndex ncb_index, | |||
size_t ohw_tile_size, StrategyBase* im2colstrategy) { | |||
@@ -272,7 +280,8 @@ public: | |||
//! 2.packb and matmul compute | |||
im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread, | |||
matmul_param, matmul_algo, ncb_index); | |||
matmul_param, matmul_algo, ncb_index, | |||
matmul_desc); | |||
//! 3.postprocess and copy dst if need | |||
im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread); | |||
@@ -401,13 +410,15 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( | |||
size_t padding = 0, packa_size = 0, packa_group_size = 0; | |||
size_t nr_threads = param.nr_threads; | |||
size_t GROUP = param.filter_meta.group; | |||
bool need_pack = m_matmul_algo->packmode() == Pack_Mode::DEFAULT; | |||
bool only_packA = m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA; | |||
fallback::MatrixMulImpl::AlgoBase::MatmulDescription mdesc = | |||
m_matmul_algo->matmul_description(); | |||
bool need_pack = mdesc.packmode == Pack_Mode::DEFAULT; | |||
bool only_packA = mdesc.packmode == Pack_Mode::ONLY_PACKA; | |||
size_t oc_tile_size = 0, ohw_tile_size = 0; | |||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | |||
mdesc.innerblocksize.m, mdesc.innerblocksize.n, | |||
mdesc.packmode); | |||
if (need_pack || only_packA) { | |||
auto inner_block = m_matmul_algo->get_inner_block_size(); | |||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, inner_block.m, | |||
inner_block.n, m_matmul_algo->packmode()); | |||
auto im2col_kern_param = get_matmul_kern_param( | |||
param, ohw_tile_size, only_packA ? oc_tile_size : OC); | |||
size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | |||
@@ -415,11 +426,6 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( | |||
packa_group_size = only_packA ? oc_parallel_times * wb.get_size(0) | |||
: wb.get_size(0); | |||
} else { //! not support pack,not need pack | |||
size_t nopack_default_blockm = 8; | |||
size_t nopack_default_blockn = 16; | |||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | |||
nopack_default_blockm, nopack_default_blockn, | |||
m_matmul_algo->packmode()); | |||
packa_group_size = 0; | |||
} | |||
@@ -481,23 +487,18 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||
WorkspaceBundle bundle = get_bundle(param); | |||
WorkspaceBundle bundle_thread = {nullptr, {}}; | |||
bool need_padding = (PH != 0 || PW != 0); | |||
Pack_Mode packmode = m_matmul_algo->packmode(); | |||
fallback::MatrixMulImpl::AlgoBase::MatmulDescription mdesc = | |||
m_matmul_algo->matmul_description(); | |||
Pack_Mode packmode = mdesc.packmode; | |||
bool default_pack = packmode == Pack_Mode::DEFAULT; | |||
bool no_pack = packmode == Pack_Mode::NO_PACK; | |||
bool only_packA = packmode == Pack_Mode::ONLY_PACKA; | |||
if (default_pack || only_packA) { | |||
auto inner_block = m_matmul_algo->get_inner_block_size(); | |||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | |||
inner_block.m, inner_block.n, | |||
m_matmul_algo->packmode()); | |||
} else { //! nopack_mode | |||
size_t nopack_default_blockm = 8; | |||
size_t nopack_default_blockn = 16; | |||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | |||
nopack_default_blockm, nopack_default_blockn, | |||
m_matmul_algo->packmode()); | |||
} | |||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | |||
mdesc.innerblocksize.m, mdesc.innerblocksize.n, | |||
mdesc.packmode); | |||
size_t ohw_parallel_times = div_ceil(ohw, ohw_tile_size); | |||
size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | |||
@@ -507,18 +508,17 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||
if (only_packA) { | |||
packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | |||
} else if (default_pack) { | |||
packa_parallel_times = div_ceil<size_t>( | |||
OC, m_matmul_algo->get_inner_block_size().m); | |||
packa_parallel_times = div_ceil<size_t>(OC, mdesc.innerblocksize.m); | |||
} | |||
auto matmul_param = get_matmul_kern_param( | |||
param, ohw_tile_size, only_packA ? oc_tile_size : OC); | |||
if (m_matmul_algo->packmode() == Pack_Mode::DEFAULT) { | |||
if (mdesc.packmode == Pack_Mode::DEFAULT) { | |||
Im2colKerns<Pack_Mode::DEFAULT> defaultkern; | |||
bundle_thread = defaultkern.get_thread_bundle( | |||
param, matmul_param, m_matmul_algo, ohw_tile_size, | |||
oc_tile_size); | |||
} else if (m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA) { | |||
} else if (mdesc.packmode == Pack_Mode::ONLY_PACKA) { | |||
Im2colKerns<Pack_Mode::ONLY_PACKA> onlypackakern; | |||
bundle_thread = onlypackakern.get_thread_bundle( | |||
param, matmul_param, m_matmul_algo, ohw_tile_size, | |||
@@ -559,24 +559,24 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||
auto kern_packA = [bundle, matmul_algo = m_matmul_algo, | |||
matmul_param, im2colstrategy, | |||
pack_oc_size = pack_oc_size]( | |||
const NCBKernParam& param, | |||
const NCBKernIndex& ncb_index) { | |||
pack_oc_size = pack_oc_size, | |||
mdesc = mdesc](const NCBKernParam& param, | |||
const NCBKernIndex& ncb_index) { | |||
packA_kern(bundle, param, matmul_param, matmul_algo, ncb_index, | |||
im2colstrategy, pack_oc_size); | |||
im2colstrategy, mdesc, pack_oc_size); | |||
}; | |||
if (default_pack) { | |||
auto kern_compute_default = | |||
[bundle, bundle_thread, matmul_param, | |||
matmul_algo = m_matmul_algo, | |||
ohw_tile_size = ohw_tile_size, | |||
strategyparam = strategyparam, | |||
strategyparam = strategyparam, matmul_desc = mdesc, | |||
im2colstrategy](const NCBKernParam& param, | |||
const NCBKernIndex& ncb_index) { | |||
Im2colKerns<Pack_Mode::DEFAULT>::kerns( | |||
bundle, bundle_thread, param, matmul_param, | |||
matmul_algo, strategyparam, ncb_index, | |||
ohw_tile_size, im2colstrategy); | |||
matmul_algo, matmul_desc, strategyparam, | |||
ncb_index, ohw_tile_size, im2colstrategy); | |||
}; | |||
ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}}); | |||
@@ -592,13 +592,13 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||
[bundle, bundle_thread, matmul_param, | |||
matmul_algo = m_matmul_algo, | |||
strategyparam = strategyparam, | |||
ohw_tile_size = ohw_tile_size, | |||
ohw_tile_size = ohw_tile_size, matmul_desc = mdesc, | |||
im2colstrategy](const NCBKernParam& param, | |||
const NCBKernIndex& ncb_index) { | |||
Im2colKerns<Pack_Mode::ONLY_PACKA>::kerns( | |||
bundle, bundle_thread, param, matmul_param, | |||
matmul_algo, strategyparam, ncb_index, | |||
ohw_tile_size, im2colstrategy); | |||
matmul_algo, matmul_desc, strategyparam, | |||
ncb_index, ohw_tile_size, im2colstrategy); | |||
}; | |||
ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}}); | |||
if (need_padding) { | |||
@@ -612,13 +612,13 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||
[bundle, bundle_thread, matmul_param, | |||
matmul_algo = m_matmul_algo, | |||
strategyparam = strategyparam, | |||
ohw_tile_size = ohw_tile_size, | |||
ohw_tile_size = ohw_tile_size, matmul_desc = mdesc, | |||
im2colstrategy](const NCBKernParam& param, | |||
const NCBKernIndex& ncb_index) { | |||
Im2colKerns<Pack_Mode::NO_PACK>::kerns( | |||
bundle, bundle_thread, param, matmul_param, | |||
matmul_algo, strategyparam, ncb_index, | |||
ohw_tile_size, im2colstrategy); | |||
matmul_algo, matmul_desc, strategyparam, | |||
ncb_index, ohw_tile_size, im2colstrategy); | |||
}; | |||
if (need_padding) { | |||
@@ -668,10 +668,12 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||
return false; | |||
} | |||
} | |||
fallback::MatrixMulImpl::AlgoBase::MatmulDescription mdesc = | |||
m_matmul_algo->matmul_description(); | |||
if (opr->param().format == param::ConvBias::Format::NCHW44 || | |||
opr->param().format == param::ConvBias::Format::NCHW44_DOT) { | |||
//! current NCHW44 im2col only support DEFAULT mode matmul | |||
if (m_matmul_algo->packmode() != Pack_Mode::DEFAULT) { | |||
if (mdesc.packmode != Pack_Mode::DEFAULT) { | |||
return false; | |||
//! nchw44 hybird mode and channel wise is not support | |||
} else if (param.filter_meta.icpg < 4_z || | |||
@@ -682,22 +684,9 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||
} | |||
size_t oc_tile_size = 0, ohw_tile_size = 0; | |||
Pack_Mode packmode = m_matmul_algo->packmode(); | |||
bool default_pack = packmode == Pack_Mode::DEFAULT; | |||
bool only_packA = packmode == Pack_Mode::ONLY_PACKA; | |||
if (default_pack || only_packA) { | |||
auto inner_block = m_matmul_algo->get_inner_block_size(); | |||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | |||
inner_block.m, inner_block.n, | |||
m_matmul_algo->packmode()); | |||
} else { //! not support pack,not need pack | |||
size_t nopack_default_blockm = 8; | |||
size_t nopack_default_blockn = 16; | |||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | |||
nopack_default_blockm, nopack_default_blockn, | |||
m_matmul_algo->packmode()); | |||
} | |||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | |||
mdesc.innerblocksize.m, mdesc.innerblocksize.n, | |||
m_matmul_algo->packmode()); | |||
fallback::MatrixMulImpl::KernSizeParam matmul_param = | |||
get_matmul_kern_param(param, ohw_tile_size, oc_tile_size); | |||
bool matmulusable = m_matmul_algo->usable(matmul_param); | |||
@@ -58,8 +58,9 @@ public: | |||
WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desec, | |||
size_t pack_size) = 0; | |||
virtual void exec_im2col( | |||
@@ -67,15 +68,17 @@ public: | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) = 0; | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) = 0; | |||
virtual void exec_matmul( | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const StrategyParam& sparam, WorkspaceBundle bundle, | |||
WorkspaceBundle bundle_thread, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) = 0; | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc | |||
) = 0; | |||
virtual void exec_postprocess( | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
@@ -284,26 +287,30 @@ public: | |||
Strategy() = default; | |||
virtual void packA_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t pack_size) override; | |||
virtual void packA_kern( | |||
WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc, | |||
size_t pack_size) override; | |||
virtual void exec_im2col( | |||
WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
void exec_matmul( | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const StrategyParam& sparam, WorkspaceBundle bundle, | |||
WorkspaceBundle bundle_thread, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc | |||
) override; | |||
void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const StrategyParam& sparam, | |||
WorkspaceBundle bundle_thread) override { | |||
@@ -338,7 +345,7 @@ public: | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
}; | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
@@ -359,20 +366,24 @@ public: | |||
Strategy() = default; | |||
void packA_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t pack_size) override; | |||
void packA_kern( | |||
WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& MDsec, | |||
size_t pack_size) override; | |||
void exec_matmul( | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const StrategyParam& sparam, WorkspaceBundle bundle, | |||
WorkspaceBundle bundle_thread, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc | |||
) override; | |||
void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const WorkspaceBundle& bundle_thread, | |||
@@ -382,7 +393,7 @@ public: | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const StrategyParam& sparam, | |||
WorkspaceBundle bundle_thread) override { | |||
@@ -411,26 +422,30 @@ public: | |||
Strategy() = default; | |||
void packA_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
size_t pack_size) override; | |||
void packA_kern( | |||
WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& MDsec, | |||
size_t pack_size) override; | |||
void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
void exec_matmul( | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const StrategyParam& sparam, WorkspaceBundle bundle, | |||
WorkspaceBundle bundle_thread, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc | |||
) override; | |||
void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | |||
const WorkspaceBundle& bundle_thread, | |||
@@ -465,7 +480,7 @@ public: | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
}; | |||
template <typename op_ctype, typename op_dtype, | |||
@@ -487,7 +502,7 @@ public: | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
}; | |||
@@ -510,7 +525,7 @@ public: | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||
}; | |||
#endif | |||
@@ -21,8 +21,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
packA_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& | |||
matmul_desc, | |||
size_t) { | |||
bundle.set(param.workspace_ptr); | |||
fallback::MatrixMulImpl::KernParam matmul_param; | |||
@@ -31,16 +33,16 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
matmulparam; | |||
size_t packA_group_size = matmul_algo->get_bundle(matmul_param).get_size(0); | |||
size_t packed_per_oc_block_size = | |||
round_up(matmul_param.K, matmul_algo->get_inner_block_size().k) * | |||
matmul_algo->get_inner_block_size().m * | |||
matmul_algo->get_packA_type_size(); | |||
round_up(matmul_param.K, matmul_desc.innerblocksize.k) * | |||
matmul_desc.innerblocksize.m * matmul_desc.packa_type_size; | |||
size_t a_panel_offset = ncb_index.ndrange_id[1] * packed_per_oc_block_size; | |||
int8_t* a_panel = static_cast<int8_t*>(bundle.get(BUNDLE_PACKA_INDEX)) + | |||
group_id * packA_group_size + a_panel_offset; | |||
matmul_param.A_ptr = | |||
const_cast<src_ctype*>(param.filter<src_ctype>(group_id)); | |||
matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[1], | |||
matmul_algo->get_inner_block_size().m); | |||
matmul_desc.innerblocksize.m); | |||
} | |||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | |||
@@ -52,7 +54,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||
size_t sh = param.filter_meta.stride[0]; | |||
size_t sw = param.filter_meta.stride[1]; | |||
size_t oc = param.filter_meta.ocpg; | |||
@@ -140,11 +142,13 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
const StrategyParam& sparam, WorkspaceBundle bundle, | |||
WorkspaceBundle bundle_thread, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& | |||
matmul_desc) { | |||
size_t packA_per_oc_block_size = | |||
round_up(matmul_param.K, matmul_algo->get_inner_block_size().k) * | |||
sparam.oc_tile_size * matmul_algo->get_packA_type_size(); | |||
round_up(matmul_param.K, matmul_desc.innerblocksize.k) * | |||
sparam.oc_tile_size * matmul_desc.packa_type_size; | |||
size_t packA_group_size = matmul_algo->get_bundle(matmul_param).get_size(0); | |||
size_t a_panel_offset = ncb_index.ndrange_id[1] * packA_group_size + | |||
ncb_index.ndrange_id[3] * packA_per_oc_block_size; | |||
@@ -33,7 +33,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||
size_t sh = param.filter_meta.stride[0]; | |||
size_t sw = param.filter_meta.stride[1]; | |||
size_t oc = param.filter_meta.ocpg; | |||
@@ -173,7 +173,7 @@ void StrategyFuse4x4x16Nchw44<op_ctype, op_dtype, postprocess_mode>:: | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam, | |||
fallback::MatrixMulImpl::AlgoBase*) { | |||
const fallback::MatrixMulImpl::AlgoBase*) { | |||
size_t ow = param.osz[1]; | |||
size_t ic = param.filter_meta.icpg; | |||
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||
@@ -176,7 +176,7 @@ void StrategyFuse8x12x4Nchw44Dot<op_ctype, op_dtype, postprocess_mode>:: | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam /*matmul_param*/, | |||
fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) { | |||
const fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) { | |||
size_t ow = param.osz[1]; | |||
size_t ic = param.filter_meta.icpg; | |||
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||
@@ -168,7 +168,7 @@ void StrategyFuse8x12x1Nchw44K3x3S2<op_ctype, op_dtype, postprocess_mode>:: | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam /*matmul_param*/, | |||
fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) { | |||
const fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) { | |||
size_t ow = param.osz[1]; | |||
size_t ic = param.filter_meta.icpg; | |||
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | |||
@@ -22,8 +22,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
packA_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const fallback::MatrixMulImpl::AlgoBase:: | |||
MatmulDescription& /*matmul_dsec*/, | |||
size_t) { | |||
MEGDNN_MARK_USED_VAR(bundle); | |||
MEGDNN_MARK_USED_VAR(param); | |||
@@ -62,8 +64,11 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
const StrategyParam& sparam, WorkspaceBundle bundle, | |||
WorkspaceBundle bundle_thread, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const fallback::MatrixMulImpl::AlgoBase:: | |||
MatmulDescription& /*matmul_desc*/ | |||
) { | |||
MEGDNN_MARK_USED_VAR(bundle); | |||
MEGDNN_MARK_USED_VAR(ncb_index); | |||
matmul_param.workspace_ptr = bundle_thread.get(THREAD_BUNDLE_MATCOMP_INDEX); | |||
@@ -95,7 +100,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||
MEGDNN_MARK_USED_VAR(matmul_param); | |||
MEGDNN_MARK_USED_VAR(matmul_algo); | |||
size_t sh = param.filter_meta.stride[0]; | |||
@@ -22,8 +22,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
packA_kern(WorkspaceBundle bundle, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const fallback::MatrixMulImpl::AlgoBase:: | |||
MatmulDescription& /*matmul_desc*/, | |||
size_t) { | |||
bundle.set(param.workspace_ptr); | |||
fallback::MatrixMulImpl::KernParam matmul_param; | |||
@@ -57,8 +59,11 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
const StrategyParam& sparam, WorkspaceBundle bundle, | |||
WorkspaceBundle bundle_thread, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||
const fallback::MatrixMulImpl::AlgoBase:: | |||
MatmulDescription& /*matmul_desc*/ | |||
) { | |||
size_t packA_group_size = | |||
bundle.get_size(BUNDLE_PACKA_INDEX) / param.filter_meta.group; | |||
size_t a_panel_offset = ncb_index.ndrange_id[3] * | |||
@@ -95,7 +100,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||
const StrategyParam& sparam, | |||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||
fallback::MatrixMulImpl::KernParam matmul_param, | |||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||
MEGDNN_MARK_USED_VAR(matmul_param); | |||
MEGDNN_MARK_USED_VAR(matmul_algo); | |||
size_t sh = param.filter_meta.stride[0]; | |||
@@ -37,6 +37,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||
}; | |||
} // namespace fallback | |||
@@ -352,6 +352,15 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K, | |||
DType dtype_c) \ | |||
: A_dtype(dtype_a), B_dtype(dtype_b), C_dtype(dtype_c) {} | |||
#define MEGDNN_OVERRIDE_MATMUL_DESC(_m, _n, _k, _packa_type_size) \ | |||
MatmulDescription matmul_description() const override { \ | |||
MatmulDescription mdesc; \ | |||
mdesc.packmode = packmode(); \ | |||
mdesc.innerblocksize = {_m, _n, _k}; \ | |||
mdesc.packa_type_size = _packa_type_size; \ | |||
return mdesc; \ | |||
} | |||
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL() \ | |||
WorkspaceBundle get_bundle(const KernSizeParam&) const override; \ | |||
kern_naked_t get_kern_naked(const KernSizeParam&) const override; \ | |||
@@ -360,7 +369,7 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K, | |||
void pack_B(const KernParam& kern_param, void* out, size_t x0, \ | |||
size_t xmax) const override; \ | |||
InnerBlockSize get_inner_block_size() const override; \ | |||
size_t get_packA_type_size() const override; | |||
MatmulDescription matmul_description() const override; | |||
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( \ | |||
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \ | |||
@@ -458,8 +467,14 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K, | |||
_strategy::UNROLL_K}; \ | |||
} \ | |||
\ | |||
size_t MatrixMulImpl::_algo_name::get_packA_type_size() const { \ | |||
return sizeof(_packa_type); \ | |||
MatrixMulImpl::_algo_name::MatmulDescription \ | |||
MatrixMulImpl::_algo_name::matmul_description() const { \ | |||
MatmulDescription mdesc; \ | |||
mdesc.packmode = PackMode(); \ | |||
mdesc.innerblocksize = {_strategy::KERNEL_H, _strategy::KERNEL_W, \ | |||
_strategy::UNROLL_K}; \ | |||
mdesc.packa_type_size = sizeof(_packa_type); \ | |||
return mdesc; \ | |||
} | |||
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( \ | |||
@@ -104,6 +104,12 @@ public: | |||
size_t m, n, k; | |||
}; | |||
struct MatmulDescription { | |||
PackMode packmode; | |||
InnerBlockSize innerblocksize; | |||
size_t packa_type_size; | |||
}; | |||
virtual bool usable(const KernSizeParam&) const = 0; | |||
virtual bool preferred(const KernSizeParam&) const { return true; } | |||
virtual size_t get_workspace(const KernSizeParam&) const = 0; | |||
@@ -125,11 +131,11 @@ public: | |||
virtual InnerBlockSize get_inner_block_size() const { | |||
megdnn_assert(0); | |||
}; | |||
virtual size_t get_packA_type_size() const { megdnn_assert(0); }; | |||
bool preferred_reproducible(const KernSizeParam& param, | |||
bool reproducible = true) { | |||
return (!reproducible || is_reproducible()) && preferred(param); | |||
}; | |||
virtual MatmulDescription matmul_description() const = 0; | |||
}; | |||
/** | |||
@@ -27,6 +27,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_x86_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||
}; | |||
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM | |||
@@ -46,7 +47,9 @@ public: | |||
megdnn_assert(0); | |||
}; | |||
WorkspaceBundle get_bundle(const KernSizeParam& param) const override; | |||
InnerBlockSize get_inner_block_size() const override { return {8, 16, 1}; }; | |||
InnerBlockSize get_inner_block_size() const override{ return {8, 16, 1}; }; | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||
}; | |||
#endif | |||
@@ -124,6 +127,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_x86_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 4) | |||
}; | |||
#if MEGDNN_X86_WITH_VNNI | |||
@@ -149,6 +153,7 @@ public: | |||
kern_t get_kern(const KernSizeParam&) const override; | |||
void* type() const override { return sm_x86_algo_type; } | |||
PackMode packmode() const override { return PackMode::NO_PACK; } | |||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||
}; | |||
#endif | |||
} // namespace x86 | |||