GitOrigin-RevId: 5bde0b60f0
release-0.6
@@ -60,6 +60,7 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4) | |||||
}; | }; | ||||
class MatrixMulImpl::AlgoF32Gemv final | class MatrixMulImpl::AlgoF32Gemv final | ||||
@@ -86,6 +87,7 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2) | |||||
}; | }; | ||||
#endif | #endif | ||||
@@ -207,6 +209,7 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2) | |||||
}; | }; | ||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
@@ -234,6 +237,7 @@ public: | |||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||||
}; | }; | ||||
#else | #else | ||||
@@ -12,6 +12,7 @@ | |||||
#pragma once | #pragma once | ||||
#include "src/arm_common/matrix_mul/opr_impl.h" | #include "src/arm_common/matrix_mul/opr_impl.h" | ||||
#include "src/fallback/matrix_mul/gemm_common.h" | |||||
namespace megdnn { | namespace megdnn { | ||||
namespace arm_common { | namespace arm_common { | ||||
@@ -25,6 +26,7 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||||
}; | }; | ||||
class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { | ||||
@@ -38,6 +40,7 @@ public: | |||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||||
}; | }; | ||||
class MatrixMulImpl::AlgoF32Gemv : public AlgoBase { | class MatrixMulImpl::AlgoF32Gemv : public AlgoBase { | ||||
@@ -54,6 +57,7 @@ public: | |||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||||
}; | }; | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
@@ -68,6 +72,7 @@ public: | |||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||||
}; | }; | ||||
#endif | #endif | ||||
@@ -82,6 +87,7 @@ public: | |||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||||
}; | }; | ||||
@@ -49,6 +49,7 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4) | |||||
}; | }; | ||||
#if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC | ||||
@@ -71,6 +72,7 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2) | |||||
}; | }; | ||||
#endif | #endif | ||||
#if __ARM_FEATURE_DOTPROD | #if __ARM_FEATURE_DOTPROD | ||||
@@ -190,6 +192,7 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_arm_common_algo_type; } | void* type() const override { return sm_arm_common_algo_type; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2) | |||||
}; | }; | ||||
class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { | class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { | ||||
@@ -47,14 +47,17 @@ static void copy_padding_kern(WorkspaceBundle bundle, | |||||
} | } | ||||
//! packA_kern | //! packA_kern | ||||
//! Free-function wrapper: forwards all of its arguments to
//! im2colstrategy->packA_kern() and does no other work.
//! NOTE(review): this span is a diff rendering -- the first signature below
//! (without matmul_desc) looks like the pre-change form and the second
//! (with matmul_desc) the post-change form; the same applies to the two
//! trailing argument lines of the forwarded call. Confirm against the repo.
//! NOTE(review): matmul_algo is still a non-const pointer here, while the
//! StrategyBase::packA_kern declaration later in this diff takes a
//! const AlgoBase* -- confirm whether this wrapper should match.
static void packA_kern(WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
StrategyBase* im2colstrategy, size_t pack_oc_size) { | |||||
static void packA_kern( | |||||
WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
StrategyBase* im2colstrategy, | |||||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc, | |||||
size_t pack_oc_size) { | |||||
im2colstrategy->packA_kern(bundle, param, matmulparam, matmul_algo, | im2colstrategy->packA_kern(bundle, param, matmulparam, matmul_algo, | ||||
ncb_index, pack_oc_size); | |||||
ncb_index, matmul_desc, pack_oc_size); | |||||
} | } | ||||
/*! | /*! | ||||
@@ -72,7 +75,8 @@ public: | |||||
WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | ||||
const ConvBiasImpl::NCBKernParam& param, | const ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param, | fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc, | |||||
StrategyParam strategyparam, | StrategyParam strategyparam, | ||||
fallback::ConvBiasImpl::NCBKernIndex ncb_index, | fallback::ConvBiasImpl::NCBKernIndex ncb_index, | ||||
size_t ohw_tile_size, StrategyBase* im2colstrategy) { | size_t ohw_tile_size, StrategyBase* im2colstrategy) { | ||||
@@ -111,7 +115,8 @@ public: | |||||
//! 2.packb and matmul compute | //! 2.packb and matmul compute | ||||
im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread, | im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread, | ||||
matmul_param, matmul_algo, ncb_index); | |||||
matmul_param, matmul_algo, ncb_index, | |||||
matmul_desc); | |||||
//! 3.postprocess and copy dst if need | //! 3.postprocess and copy dst if need | ||||
im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread); | im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread); | ||||
@@ -151,7 +156,8 @@ public: | |||||
WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | ||||
const ConvBiasImpl::NCBKernParam& param, | const ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param, | fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc, | |||||
StrategyParam strategyparam, | StrategyParam strategyparam, | ||||
fallback::ConvBiasImpl::NCBKernIndex ncb_index, | fallback::ConvBiasImpl::NCBKernIndex ncb_index, | ||||
size_t ohw_tile_size, StrategyBase* im2colstrategy) { | size_t ohw_tile_size, StrategyBase* im2colstrategy) { | ||||
@@ -191,7 +197,8 @@ public: | |||||
//! 2.packb and matmul compute | //! 2.packb and matmul compute | ||||
im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread, | im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread, | ||||
matmul_param, matmul_algo, ncb_index); | |||||
matmul_param, matmul_algo, ncb_index, | |||||
matmul_desc); | |||||
//! 3.postprocess and copy dst if need | //! 3.postprocess and copy dst if need | ||||
im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread); | im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread); | ||||
@@ -232,7 +239,8 @@ public: | |||||
WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | ||||
const ConvBiasImpl::NCBKernParam& param, | const ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param, | fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc, | |||||
StrategyParam strategyparam, | StrategyParam strategyparam, | ||||
fallback::ConvBiasImpl::NCBKernIndex ncb_index, | fallback::ConvBiasImpl::NCBKernIndex ncb_index, | ||||
size_t ohw_tile_size, StrategyBase* im2colstrategy) { | size_t ohw_tile_size, StrategyBase* im2colstrategy) { | ||||
@@ -272,7 +280,8 @@ public: | |||||
//! 2.packb and matmul compute | //! 2.packb and matmul compute | ||||
im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread, | im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread, | ||||
matmul_param, matmul_algo, ncb_index); | |||||
matmul_param, matmul_algo, ncb_index, | |||||
matmul_desc); | |||||
//! 3.postprocess and copy dst if need | //! 3.postprocess and copy dst if need | ||||
im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread); | im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread); | ||||
@@ -401,13 +410,15 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( | |||||
size_t padding = 0, packa_size = 0, packa_group_size = 0; | size_t padding = 0, packa_size = 0, packa_group_size = 0; | ||||
size_t nr_threads = param.nr_threads; | size_t nr_threads = param.nr_threads; | ||||
size_t GROUP = param.filter_meta.group; | size_t GROUP = param.filter_meta.group; | ||||
bool need_pack = m_matmul_algo->packmode() == Pack_Mode::DEFAULT; | |||||
bool only_packA = m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA; | |||||
fallback::MatrixMulImpl::AlgoBase::MatmulDescription mdesc = | |||||
m_matmul_algo->matmul_description(); | |||||
bool need_pack = mdesc.packmode == Pack_Mode::DEFAULT; | |||||
bool only_packA = mdesc.packmode == Pack_Mode::ONLY_PACKA; | |||||
size_t oc_tile_size = 0, ohw_tile_size = 0; | size_t oc_tile_size = 0, ohw_tile_size = 0; | ||||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | |||||
mdesc.innerblocksize.m, mdesc.innerblocksize.n, | |||||
mdesc.packmode); | |||||
if (need_pack || only_packA) { | if (need_pack || only_packA) { | ||||
auto inner_block = m_matmul_algo->get_inner_block_size(); | |||||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, inner_block.m, | |||||
inner_block.n, m_matmul_algo->packmode()); | |||||
auto im2col_kern_param = get_matmul_kern_param( | auto im2col_kern_param = get_matmul_kern_param( | ||||
param, ohw_tile_size, only_packA ? oc_tile_size : OC); | param, ohw_tile_size, only_packA ? oc_tile_size : OC); | ||||
size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | ||||
@@ -415,11 +426,6 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle( | |||||
packa_group_size = only_packA ? oc_parallel_times * wb.get_size(0) | packa_group_size = only_packA ? oc_parallel_times * wb.get_size(0) | ||||
: wb.get_size(0); | : wb.get_size(0); | ||||
} else { //! not support pack,not need pack | } else { //! not support pack,not need pack | ||||
size_t nopack_default_blockm = 8; | |||||
size_t nopack_default_blockn = 16; | |||||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | |||||
nopack_default_blockm, nopack_default_blockn, | |||||
m_matmul_algo->packmode()); | |||||
packa_group_size = 0; | packa_group_size = 0; | ||||
} | } | ||||
@@ -481,23 +487,18 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||||
WorkspaceBundle bundle = get_bundle(param); | WorkspaceBundle bundle = get_bundle(param); | ||||
WorkspaceBundle bundle_thread = {nullptr, {}}; | WorkspaceBundle bundle_thread = {nullptr, {}}; | ||||
bool need_padding = (PH != 0 || PW != 0); | bool need_padding = (PH != 0 || PW != 0); | ||||
Pack_Mode packmode = m_matmul_algo->packmode(); | |||||
fallback::MatrixMulImpl::AlgoBase::MatmulDescription mdesc = | |||||
m_matmul_algo->matmul_description(); | |||||
Pack_Mode packmode = mdesc.packmode; | |||||
bool default_pack = packmode == Pack_Mode::DEFAULT; | bool default_pack = packmode == Pack_Mode::DEFAULT; | ||||
bool no_pack = packmode == Pack_Mode::NO_PACK; | bool no_pack = packmode == Pack_Mode::NO_PACK; | ||||
bool only_packA = packmode == Pack_Mode::ONLY_PACKA; | bool only_packA = packmode == Pack_Mode::ONLY_PACKA; | ||||
if (default_pack || only_packA) { | |||||
auto inner_block = m_matmul_algo->get_inner_block_size(); | |||||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | |||||
inner_block.m, inner_block.n, | |||||
m_matmul_algo->packmode()); | |||||
} else { //! nopack_mode | |||||
size_t nopack_default_blockm = 8; | |||||
size_t nopack_default_blockn = 16; | |||||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | |||||
nopack_default_blockm, nopack_default_blockn, | |||||
m_matmul_algo->packmode()); | |||||
} | |||||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | |||||
mdesc.innerblocksize.m, mdesc.innerblocksize.n, | |||||
mdesc.packmode); | |||||
size_t ohw_parallel_times = div_ceil(ohw, ohw_tile_size); | size_t ohw_parallel_times = div_ceil(ohw, ohw_tile_size); | ||||
size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | size_t oc_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | ||||
@@ -507,18 +508,17 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||||
if (only_packA) { | if (only_packA) { | ||||
packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | packa_parallel_times = div_ceil<size_t>(OC, oc_tile_size); | ||||
} else if (default_pack) { | } else if (default_pack) { | ||||
packa_parallel_times = div_ceil<size_t>( | |||||
OC, m_matmul_algo->get_inner_block_size().m); | |||||
packa_parallel_times = div_ceil<size_t>(OC, mdesc.innerblocksize.m); | |||||
} | } | ||||
auto matmul_param = get_matmul_kern_param( | auto matmul_param = get_matmul_kern_param( | ||||
param, ohw_tile_size, only_packA ? oc_tile_size : OC); | param, ohw_tile_size, only_packA ? oc_tile_size : OC); | ||||
if (m_matmul_algo->packmode() == Pack_Mode::DEFAULT) { | |||||
if (mdesc.packmode == Pack_Mode::DEFAULT) { | |||||
Im2colKerns<Pack_Mode::DEFAULT> defaultkern; | Im2colKerns<Pack_Mode::DEFAULT> defaultkern; | ||||
bundle_thread = defaultkern.get_thread_bundle( | bundle_thread = defaultkern.get_thread_bundle( | ||||
param, matmul_param, m_matmul_algo, ohw_tile_size, | param, matmul_param, m_matmul_algo, ohw_tile_size, | ||||
oc_tile_size); | oc_tile_size); | ||||
} else if (m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA) { | |||||
} else if (mdesc.packmode == Pack_Mode::ONLY_PACKA) { | |||||
Im2colKerns<Pack_Mode::ONLY_PACKA> onlypackakern; | Im2colKerns<Pack_Mode::ONLY_PACKA> onlypackakern; | ||||
bundle_thread = onlypackakern.get_thread_bundle( | bundle_thread = onlypackakern.get_thread_bundle( | ||||
param, matmul_param, m_matmul_algo, ohw_tile_size, | param, matmul_param, m_matmul_algo, ohw_tile_size, | ||||
@@ -559,24 +559,24 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||||
auto kern_packA = [bundle, matmul_algo = m_matmul_algo, | auto kern_packA = [bundle, matmul_algo = m_matmul_algo, | ||||
matmul_param, im2colstrategy, | matmul_param, im2colstrategy, | ||||
pack_oc_size = pack_oc_size]( | |||||
const NCBKernParam& param, | |||||
const NCBKernIndex& ncb_index) { | |||||
pack_oc_size = pack_oc_size, | |||||
mdesc = mdesc](const NCBKernParam& param, | |||||
const NCBKernIndex& ncb_index) { | |||||
packA_kern(bundle, param, matmul_param, matmul_algo, ncb_index, | packA_kern(bundle, param, matmul_param, matmul_algo, ncb_index, | ||||
im2colstrategy, pack_oc_size); | |||||
im2colstrategy, mdesc, pack_oc_size); | |||||
}; | }; | ||||
if (default_pack) { | if (default_pack) { | ||||
auto kern_compute_default = | auto kern_compute_default = | ||||
[bundle, bundle_thread, matmul_param, | [bundle, bundle_thread, matmul_param, | ||||
matmul_algo = m_matmul_algo, | matmul_algo = m_matmul_algo, | ||||
ohw_tile_size = ohw_tile_size, | ohw_tile_size = ohw_tile_size, | ||||
strategyparam = strategyparam, | |||||
strategyparam = strategyparam, matmul_desc = mdesc, | |||||
im2colstrategy](const NCBKernParam& param, | im2colstrategy](const NCBKernParam& param, | ||||
const NCBKernIndex& ncb_index) { | const NCBKernIndex& ncb_index) { | ||||
Im2colKerns<Pack_Mode::DEFAULT>::kerns( | Im2colKerns<Pack_Mode::DEFAULT>::kerns( | ||||
bundle, bundle_thread, param, matmul_param, | bundle, bundle_thread, param, matmul_param, | ||||
matmul_algo, strategyparam, ncb_index, | |||||
ohw_tile_size, im2colstrategy); | |||||
matmul_algo, matmul_desc, strategyparam, | |||||
ncb_index, ohw_tile_size, im2colstrategy); | |||||
}; | }; | ||||
ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}}); | ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}}); | ||||
@@ -592,13 +592,13 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||||
[bundle, bundle_thread, matmul_param, | [bundle, bundle_thread, matmul_param, | ||||
matmul_algo = m_matmul_algo, | matmul_algo = m_matmul_algo, | ||||
strategyparam = strategyparam, | strategyparam = strategyparam, | ||||
ohw_tile_size = ohw_tile_size, | |||||
ohw_tile_size = ohw_tile_size, matmul_desc = mdesc, | |||||
im2colstrategy](const NCBKernParam& param, | im2colstrategy](const NCBKernParam& param, | ||||
const NCBKernIndex& ncb_index) { | const NCBKernIndex& ncb_index) { | ||||
Im2colKerns<Pack_Mode::ONLY_PACKA>::kerns( | Im2colKerns<Pack_Mode::ONLY_PACKA>::kerns( | ||||
bundle, bundle_thread, param, matmul_param, | bundle, bundle_thread, param, matmul_param, | ||||
matmul_algo, strategyparam, ncb_index, | |||||
ohw_tile_size, im2colstrategy); | |||||
matmul_algo, matmul_desc, strategyparam, | |||||
ncb_index, ohw_tile_size, im2colstrategy); | |||||
}; | }; | ||||
ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}}); | ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}}); | ||||
if (need_padding) { | if (need_padding) { | ||||
@@ -612,13 +612,13 @@ SmallVector<ConvBiasImpl::NCBKern> ConvBiasImpl::AlgoIm2col::dispatch_kerns( | |||||
[bundle, bundle_thread, matmul_param, | [bundle, bundle_thread, matmul_param, | ||||
matmul_algo = m_matmul_algo, | matmul_algo = m_matmul_algo, | ||||
strategyparam = strategyparam, | strategyparam = strategyparam, | ||||
ohw_tile_size = ohw_tile_size, | |||||
ohw_tile_size = ohw_tile_size, matmul_desc = mdesc, | |||||
im2colstrategy](const NCBKernParam& param, | im2colstrategy](const NCBKernParam& param, | ||||
const NCBKernIndex& ncb_index) { | const NCBKernIndex& ncb_index) { | ||||
Im2colKerns<Pack_Mode::NO_PACK>::kerns( | Im2colKerns<Pack_Mode::NO_PACK>::kerns( | ||||
bundle, bundle_thread, param, matmul_param, | bundle, bundle_thread, param, matmul_param, | ||||
matmul_algo, strategyparam, ncb_index, | |||||
ohw_tile_size, im2colstrategy); | |||||
matmul_algo, matmul_desc, strategyparam, | |||||
ncb_index, ohw_tile_size, im2colstrategy); | |||||
}; | }; | ||||
if (need_padding) { | if (need_padding) { | ||||
@@ -668,10 +668,12 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
fallback::MatrixMulImpl::AlgoBase::MatmulDescription mdesc = | |||||
m_matmul_algo->matmul_description(); | |||||
if (opr->param().format == param::ConvBias::Format::NCHW44 || | if (opr->param().format == param::ConvBias::Format::NCHW44 || | ||||
opr->param().format == param::ConvBias::Format::NCHW44_DOT) { | opr->param().format == param::ConvBias::Format::NCHW44_DOT) { | ||||
//! current NCHW44 im2col only supports DEFAULT mode matmul | //! current NCHW44 im2col only supports DEFAULT mode matmul | ||||
if (m_matmul_algo->packmode() != Pack_Mode::DEFAULT) { | |||||
if (mdesc.packmode != Pack_Mode::DEFAULT) { | |||||
return false; | return false; | ||||
//! nchw44 hybrid mode and channel wise are not supported | //! nchw44 hybrid mode and channel wise are not supported | ||||
} else if (param.filter_meta.icpg < 4_z || | } else if (param.filter_meta.icpg < 4_z || | ||||
@@ -682,22 +684,9 @@ bool ConvBiasImpl::AlgoIm2col::usable( | |||||
} | } | ||||
size_t oc_tile_size = 0, ohw_tile_size = 0; | size_t oc_tile_size = 0, ohw_tile_size = 0; | ||||
Pack_Mode packmode = m_matmul_algo->packmode(); | |||||
bool default_pack = packmode == Pack_Mode::DEFAULT; | |||||
bool only_packA = packmode == Pack_Mode::ONLY_PACKA; | |||||
if (default_pack || only_packA) { | |||||
auto inner_block = m_matmul_algo->get_inner_block_size(); | |||||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | |||||
inner_block.m, inner_block.n, | |||||
m_matmul_algo->packmode()); | |||||
} else { //! not support pack,not need pack | |||||
size_t nopack_default_blockm = 8; | |||||
size_t nopack_default_blockn = 16; | |||||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | |||||
nopack_default_blockm, nopack_default_blockn, | |||||
m_matmul_algo->packmode()); | |||||
} | |||||
choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, | |||||
mdesc.innerblocksize.m, mdesc.innerblocksize.n, | |||||
m_matmul_algo->packmode()); | |||||
fallback::MatrixMulImpl::KernSizeParam matmul_param = | fallback::MatrixMulImpl::KernSizeParam matmul_param = | ||||
get_matmul_kern_param(param, ohw_tile_size, oc_tile_size); | get_matmul_kern_param(param, ohw_tile_size, oc_tile_size); | ||||
bool matmulusable = m_matmul_algo->usable(matmul_param); | bool matmulusable = m_matmul_algo->usable(matmul_param); | ||||
@@ -58,8 +58,9 @@ public: | |||||
WorkspaceBundle bundle, | WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | fallback::MatrixMulImpl::KernSizeParam matmulparam, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | ||||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desec, | |||||
size_t pack_size) = 0; | size_t pack_size) = 0; | ||||
virtual void exec_im2col( | virtual void exec_im2col( | ||||
@@ -67,15 +68,17 @@ public: | |||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) = 0; | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) = 0; | |||||
virtual void exec_matmul( | virtual void exec_matmul( | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const StrategyParam& sparam, WorkspaceBundle bundle, | const StrategyParam& sparam, WorkspaceBundle bundle, | ||||
WorkspaceBundle bundle_thread, | WorkspaceBundle bundle_thread, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) = 0; | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc | |||||
) = 0; | |||||
virtual void exec_postprocess( | virtual void exec_postprocess( | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
@@ -284,26 +287,30 @@ public: | |||||
Strategy() = default; | Strategy() = default; | ||||
virtual void packA_kern(WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t pack_size) override; | |||||
virtual void packA_kern( | |||||
WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc, | |||||
size_t pack_size) override; | |||||
virtual void exec_im2col( | virtual void exec_im2col( | ||||
WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | ||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||||
void exec_matmul( | void exec_matmul( | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const StrategyParam& sparam, WorkspaceBundle bundle, | const StrategyParam& sparam, WorkspaceBundle bundle, | ||||
WorkspaceBundle bundle_thread, | WorkspaceBundle bundle_thread, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc | |||||
) override; | |||||
void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
WorkspaceBundle bundle_thread) override { | WorkspaceBundle bundle_thread) override { | ||||
@@ -338,7 +345,7 @@ public: | |||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||||
}; | }; | ||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
@@ -359,20 +366,24 @@ public: | |||||
Strategy() = default; | Strategy() = default; | ||||
void packA_kern(WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t pack_size) override; | |||||
void packA_kern( | |||||
WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& MDsec, | |||||
size_t pack_size) override; | |||||
void exec_matmul( | void exec_matmul( | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const StrategyParam& sparam, WorkspaceBundle bundle, | const StrategyParam& sparam, WorkspaceBundle bundle, | ||||
WorkspaceBundle bundle_thread, | WorkspaceBundle bundle_thread, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc | |||||
) override; | |||||
void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const WorkspaceBundle& bundle_thread, | const WorkspaceBundle& bundle_thread, | ||||
@@ -382,7 +393,7 @@ public: | |||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||||
void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
WorkspaceBundle bundle_thread) override { | WorkspaceBundle bundle_thread) override { | ||||
@@ -411,26 +422,30 @@ public: | |||||
Strategy() = default; | Strategy() = default; | ||||
void packA_kern(WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
size_t pack_size) override; | |||||
void packA_kern( | |||||
WorkspaceBundle bundle, | |||||
const fallback::ConvBiasImpl::NCBKernParam& param, | |||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& MDsec, | |||||
size_t pack_size) override; | |||||
void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | void exec_im2col(WorkspaceBundle bundle, WorkspaceBundle bundle_thread, | ||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||||
void exec_matmul( | void exec_matmul( | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const StrategyParam& sparam, WorkspaceBundle bundle, | const StrategyParam& sparam, WorkspaceBundle bundle, | ||||
WorkspaceBundle bundle_thread, | WorkspaceBundle bundle_thread, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override; | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc | |||||
) override; | |||||
void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | void* get_matmul_dst_ptr(const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
const WorkspaceBundle& bundle_thread, | const WorkspaceBundle& bundle_thread, | ||||
@@ -465,7 +480,7 @@ public: | |||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||||
}; | }; | ||||
template <typename op_ctype, typename op_dtype, | template <typename op_ctype, typename op_dtype, | ||||
@@ -487,7 +502,7 @@ public: | |||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||||
}; | }; | ||||
@@ -510,7 +525,7 @@ public: | |||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override; | |||||
}; | }; | ||||
#endif | #endif | ||||
@@ -21,8 +21,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
packA_kern(WorkspaceBundle bundle, | packA_kern(WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | fallback::MatrixMulImpl::KernSizeParam matmulparam, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | ||||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& | |||||
matmul_desc, | |||||
size_t) { | size_t) { | ||||
bundle.set(param.workspace_ptr); | bundle.set(param.workspace_ptr); | ||||
fallback::MatrixMulImpl::KernParam matmul_param; | fallback::MatrixMulImpl::KernParam matmul_param; | ||||
@@ -31,16 +33,16 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
matmulparam; | matmulparam; | ||||
size_t packA_group_size = matmul_algo->get_bundle(matmul_param).get_size(0); | size_t packA_group_size = matmul_algo->get_bundle(matmul_param).get_size(0); | ||||
size_t packed_per_oc_block_size = | size_t packed_per_oc_block_size = | ||||
round_up(matmul_param.K, matmul_algo->get_inner_block_size().k) * | |||||
matmul_algo->get_inner_block_size().m * | |||||
matmul_algo->get_packA_type_size(); | |||||
round_up(matmul_param.K, matmul_desc.innerblocksize.k) * | |||||
matmul_desc.innerblocksize.m * matmul_desc.packa_type_size; | |||||
size_t a_panel_offset = ncb_index.ndrange_id[1] * packed_per_oc_block_size; | size_t a_panel_offset = ncb_index.ndrange_id[1] * packed_per_oc_block_size; | ||||
int8_t* a_panel = static_cast<int8_t*>(bundle.get(BUNDLE_PACKA_INDEX)) + | int8_t* a_panel = static_cast<int8_t*>(bundle.get(BUNDLE_PACKA_INDEX)) + | ||||
group_id * packA_group_size + a_panel_offset; | group_id * packA_group_size + a_panel_offset; | ||||
matmul_param.A_ptr = | matmul_param.A_ptr = | ||||
const_cast<src_ctype*>(param.filter<src_ctype>(group_id)); | const_cast<src_ctype*>(param.filter<src_ctype>(group_id)); | ||||
matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[1], | matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[1], | ||||
matmul_algo->get_inner_block_size().m); | |||||
matmul_desc.innerblocksize.m); | |||||
} | } | ||||
template <typename src_ctype, typename bias_ctype, typename dst_ctype, | template <typename src_ctype, typename bias_ctype, typename dst_ctype, | ||||
@@ -52,7 +54,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||||
size_t sh = param.filter_meta.stride[0]; | size_t sh = param.filter_meta.stride[0]; | ||||
size_t sw = param.filter_meta.stride[1]; | size_t sw = param.filter_meta.stride[1]; | ||||
size_t oc = param.filter_meta.ocpg; | size_t oc = param.filter_meta.ocpg; | ||||
@@ -140,11 +142,13 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
const StrategyParam& sparam, WorkspaceBundle bundle, | const StrategyParam& sparam, WorkspaceBundle bundle, | ||||
WorkspaceBundle bundle_thread, | WorkspaceBundle bundle_thread, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& | |||||
matmul_desc) { | |||||
size_t packA_per_oc_block_size = | size_t packA_per_oc_block_size = | ||||
round_up(matmul_param.K, matmul_algo->get_inner_block_size().k) * | |||||
sparam.oc_tile_size * matmul_algo->get_packA_type_size(); | |||||
round_up(matmul_param.K, matmul_desc.innerblocksize.k) * | |||||
sparam.oc_tile_size * matmul_desc.packa_type_size; | |||||
size_t packA_group_size = matmul_algo->get_bundle(matmul_param).get_size(0); | size_t packA_group_size = matmul_algo->get_bundle(matmul_param).get_size(0); | ||||
size_t a_panel_offset = ncb_index.ndrange_id[1] * packA_group_size + | size_t a_panel_offset = ncb_index.ndrange_id[1] * packA_group_size + | ||||
ncb_index.ndrange_id[3] * packA_per_oc_block_size; | ncb_index.ndrange_id[3] * packA_per_oc_block_size; | ||||
@@ -33,7 +33,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||||
size_t sh = param.filter_meta.stride[0]; | size_t sh = param.filter_meta.stride[0]; | ||||
size_t sw = param.filter_meta.stride[1]; | size_t sw = param.filter_meta.stride[1]; | ||||
size_t oc = param.filter_meta.ocpg; | size_t oc = param.filter_meta.ocpg; | ||||
@@ -173,7 +173,7 @@ void StrategyFuse4x4x16Nchw44<op_ctype, op_dtype, postprocess_mode>:: | |||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernParam, | fallback::MatrixMulImpl::KernParam, | ||||
fallback::MatrixMulImpl::AlgoBase*) { | |||||
const fallback::MatrixMulImpl::AlgoBase*) { | |||||
size_t ow = param.osz[1]; | size_t ow = param.osz[1]; | ||||
size_t ic = param.filter_meta.icpg; | size_t ic = param.filter_meta.icpg; | ||||
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | ||||
@@ -176,7 +176,7 @@ void StrategyFuse8x12x4Nchw44Dot<op_ctype, op_dtype, postprocess_mode>:: | |||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernParam /*matmul_param*/, | fallback::MatrixMulImpl::KernParam /*matmul_param*/, | ||||
fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) { | |||||
const fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) { | |||||
size_t ow = param.osz[1]; | size_t ow = param.osz[1]; | ||||
size_t ic = param.filter_meta.icpg; | size_t ic = param.filter_meta.icpg; | ||||
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | ||||
@@ -168,7 +168,7 @@ void StrategyFuse8x12x1Nchw44K3x3S2<op_ctype, op_dtype, postprocess_mode>:: | |||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernParam /*matmul_param*/, | fallback::MatrixMulImpl::KernParam /*matmul_param*/, | ||||
fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) { | |||||
const fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) { | |||||
size_t ow = param.osz[1]; | size_t ow = param.osz[1]; | ||||
size_t ic = param.filter_meta.icpg; | size_t ic = param.filter_meta.icpg; | ||||
size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2; | ||||
@@ -22,8 +22,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
packA_kern(WorkspaceBundle bundle, | packA_kern(WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | fallback::MatrixMulImpl::KernSizeParam matmulparam, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | ||||
const fallback::MatrixMulImpl::AlgoBase:: | |||||
MatmulDescription& /*matmul_dsec*/, | |||||
size_t) { | size_t) { | ||||
MEGDNN_MARK_USED_VAR(bundle); | MEGDNN_MARK_USED_VAR(bundle); | ||||
MEGDNN_MARK_USED_VAR(param); | MEGDNN_MARK_USED_VAR(param); | ||||
@@ -62,8 +64,11 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
const StrategyParam& sparam, WorkspaceBundle bundle, | const StrategyParam& sparam, WorkspaceBundle bundle, | ||||
WorkspaceBundle bundle_thread, | WorkspaceBundle bundle_thread, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const fallback::MatrixMulImpl::AlgoBase:: | |||||
MatmulDescription& /*matmul_desc*/ | |||||
) { | |||||
MEGDNN_MARK_USED_VAR(bundle); | MEGDNN_MARK_USED_VAR(bundle); | ||||
MEGDNN_MARK_USED_VAR(ncb_index); | MEGDNN_MARK_USED_VAR(ncb_index); | ||||
matmul_param.workspace_ptr = bundle_thread.get(THREAD_BUNDLE_MATCOMP_INDEX); | matmul_param.workspace_ptr = bundle_thread.get(THREAD_BUNDLE_MATCOMP_INDEX); | ||||
@@ -95,7 +100,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||||
MEGDNN_MARK_USED_VAR(matmul_param); | MEGDNN_MARK_USED_VAR(matmul_param); | ||||
MEGDNN_MARK_USED_VAR(matmul_algo); | MEGDNN_MARK_USED_VAR(matmul_algo); | ||||
size_t sh = param.filter_meta.stride[0]; | size_t sh = param.filter_meta.stride[0]; | ||||
@@ -22,8 +22,10 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
packA_kern(WorkspaceBundle bundle, | packA_kern(WorkspaceBundle bundle, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernSizeParam matmulparam, | fallback::MatrixMulImpl::KernSizeParam matmulparam, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | ||||
const fallback::MatrixMulImpl::AlgoBase:: | |||||
MatmulDescription& /*matmul_desc*/, | |||||
size_t) { | size_t) { | ||||
bundle.set(param.workspace_ptr); | bundle.set(param.workspace_ptr); | ||||
fallback::MatrixMulImpl::KernParam matmul_param; | fallback::MatrixMulImpl::KernParam matmul_param; | ||||
@@ -57,8 +59,11 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
const StrategyParam& sparam, WorkspaceBundle bundle, | const StrategyParam& sparam, WorkspaceBundle bundle, | ||||
WorkspaceBundle bundle_thread, | WorkspaceBundle bundle_thread, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) { | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo, | |||||
const fallback::ConvBiasImpl::NCBKernIndex& ncb_index, | |||||
const fallback::MatrixMulImpl::AlgoBase:: | |||||
MatmulDescription& /*matmul_desc*/ | |||||
) { | |||||
size_t packA_group_size = | size_t packA_group_size = | ||||
bundle.get_size(BUNDLE_PACKA_INDEX) / param.filter_meta.group; | bundle.get_size(BUNDLE_PACKA_INDEX) / param.filter_meta.group; | ||||
size_t a_panel_offset = ncb_index.ndrange_id[3] * | size_t a_panel_offset = ncb_index.ndrange_id[3] * | ||||
@@ -95,7 +100,7 @@ void Strategy<src_ctype, bias_ctype, dst_ctype, op_ctype, op_dtype, | |||||
const StrategyParam& sparam, | const StrategyParam& sparam, | ||||
const fallback::ConvBiasImpl::NCBKernParam& param, | const fallback::ConvBiasImpl::NCBKernParam& param, | ||||
fallback::MatrixMulImpl::KernParam matmul_param, | fallback::MatrixMulImpl::KernParam matmul_param, | ||||
fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||||
const fallback::MatrixMulImpl::AlgoBase* matmul_algo) { | |||||
MEGDNN_MARK_USED_VAR(matmul_param); | MEGDNN_MARK_USED_VAR(matmul_param); | ||||
MEGDNN_MARK_USED_VAR(matmul_algo); | MEGDNN_MARK_USED_VAR(matmul_algo); | ||||
size_t sh = param.filter_meta.stride[0]; | size_t sh = param.filter_meta.stride[0]; | ||||
@@ -37,6 +37,7 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||||
}; | }; | ||||
} // namespace fallback | } // namespace fallback | ||||
@@ -352,6 +352,15 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K, | |||||
DType dtype_c) \ | DType dtype_c) \ | ||||
: A_dtype(dtype_a), B_dtype(dtype_b), C_dtype(dtype_c) {} | : A_dtype(dtype_a), B_dtype(dtype_b), C_dtype(dtype_c) {} | ||||
#define MEGDNN_OVERRIDE_MATMUL_DESC(_m, _n, _k, _packa_type_size) \ | |||||
MatmulDescription matmul_description() const override { \ | |||||
MatmulDescription mdesc; \ | |||||
mdesc.packmode = packmode(); \ | |||||
mdesc.innerblocksize = {_m, _n, _k}; \ | |||||
mdesc.packa_type_size = _packa_type_size; \ | |||||
return mdesc; \ | |||||
} | |||||
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL() \ | #define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL() \ | ||||
WorkspaceBundle get_bundle(const KernSizeParam&) const override; \ | WorkspaceBundle get_bundle(const KernSizeParam&) const override; \ | ||||
kern_naked_t get_kern_naked(const KernSizeParam&) const override; \ | kern_naked_t get_kern_naked(const KernSizeParam&) const override; \ | ||||
@@ -360,7 +369,7 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K, | |||||
void pack_B(const KernParam& kern_param, void* out, size_t x0, \ | void pack_B(const KernParam& kern_param, void* out, size_t x0, \ | ||||
size_t xmax) const override; \ | size_t xmax) const override; \ | ||||
InnerBlockSize get_inner_block_size() const override; \ | InnerBlockSize get_inner_block_size() const override; \ | ||||
size_t get_packA_type_size() const override; | |||||
MatmulDescription matmul_description() const override; | |||||
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( \ | #define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL_DETAIL( \ | ||||
_algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \ | _algo_name, _midout_name, _mid_index, _strategy, _i_type, _c_type, \ | ||||
@@ -458,8 +467,14 @@ void gemm_kern(const Tin* packA, const Tin* packB, size_t M, size_t N, size_t K, | |||||
_strategy::UNROLL_K}; \ | _strategy::UNROLL_K}; \ | ||||
} \ | } \ | ||||
\ | \ | ||||
size_t MatrixMulImpl::_algo_name::get_packA_type_size() const { \ | |||||
return sizeof(_packa_type); \ | |||||
MatrixMulImpl::_algo_name::MatmulDescription \ | |||||
MatrixMulImpl::_algo_name::matmul_description() const { \ | |||||
MatmulDescription mdesc; \ | |||||
mdesc.packmode = PackMode(); \ | |||||
mdesc.innerblocksize = {_strategy::KERNEL_H, _strategy::KERNEL_W, \ | |||||
_strategy::UNROLL_K}; \ | |||||
mdesc.packa_type_size = sizeof(_packa_type); \ | |||||
return mdesc; \ | |||||
} | } | ||||
#define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( \ | #define MEGDNN_REG_GEMM_FUNC_FOR_IM2COL_IMPL( \ | ||||
@@ -104,6 +104,12 @@ public: | |||||
size_t m, n, k; | size_t m, n, k; | ||||
}; | }; | ||||
struct MatmulDescription { | |||||
PackMode packmode; | |||||
InnerBlockSize innerblocksize; | |||||
size_t packa_type_size; | |||||
}; | |||||
virtual bool usable(const KernSizeParam&) const = 0; | virtual bool usable(const KernSizeParam&) const = 0; | ||||
virtual bool preferred(const KernSizeParam&) const { return true; } | virtual bool preferred(const KernSizeParam&) const { return true; } | ||||
virtual size_t get_workspace(const KernSizeParam&) const = 0; | virtual size_t get_workspace(const KernSizeParam&) const = 0; | ||||
@@ -125,11 +131,11 @@ public: | |||||
virtual InnerBlockSize get_inner_block_size() const { | virtual InnerBlockSize get_inner_block_size() const { | ||||
megdnn_assert(0); | megdnn_assert(0); | ||||
}; | }; | ||||
virtual size_t get_packA_type_size() const { megdnn_assert(0); }; | |||||
bool preferred_reproducible(const KernSizeParam& param, | bool preferred_reproducible(const KernSizeParam& param, | ||||
bool reproducible = true) { | bool reproducible = true) { | ||||
return (!reproducible || is_reproducible()) && preferred(param); | return (!reproducible || is_reproducible()) && preferred(param); | ||||
}; | }; | ||||
virtual MatmulDescription matmul_description() const = 0; | |||||
}; | }; | ||||
/** | /** | ||||
@@ -27,6 +27,7 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_x86_algo_type; } | void* type() const override { return sm_x86_algo_type; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||||
}; | }; | ||||
#if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM | #if MEGDNN_X86_WITH_MKL && SUPPORT_MKL_PACKED_GEMM | ||||
@@ -46,7 +47,9 @@ public: | |||||
megdnn_assert(0); | megdnn_assert(0); | ||||
}; | }; | ||||
WorkspaceBundle get_bundle(const KernSizeParam& param) const override; | WorkspaceBundle get_bundle(const KernSizeParam& param) const override; | ||||
InnerBlockSize get_inner_block_size() const override { return {8, 16, 1}; }; | |||||
InnerBlockSize get_inner_block_size() const override{ return {8, 16, 1}; }; | |||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4) | |||||
}; | }; | ||||
#endif | #endif | ||||
@@ -124,6 +127,7 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_x86_algo_type; } | void* type() const override { return sm_x86_algo_type; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 4) | |||||
}; | }; | ||||
#if MEGDNN_X86_WITH_VNNI | #if MEGDNN_X86_WITH_VNNI | ||||
@@ -149,6 +153,7 @@ public: | |||||
kern_t get_kern(const KernSizeParam&) const override; | kern_t get_kern(const KernSizeParam&) const override; | ||||
void* type() const override { return sm_x86_algo_type; } | void* type() const override { return sm_x86_algo_type; } | ||||
PackMode packmode() const override { return PackMode::NO_PACK; } | PackMode packmode() const override { return PackMode::NO_PACK; } | ||||
MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2) | |||||
}; | }; | ||||
#endif | #endif | ||||
} // namespace x86 | } // namespace x86 | ||||