From 8e9fa80cc2e65f7b3194100627d8ec71b6d99fd9 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 16 Jun 2020 22:08:48 +0800
Subject: [PATCH] feat(dnn/fallback): add matmul description for im2col
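
The im2col conv_bias algorithm used to query the chosen matmul algo several
times (packmode(), get_inner_block_size(), get_packA_type_size()). This patch
collects those properties into a single MatmulDescription that is fetched once
via matmul_description() and passed down to the packA/im2col/matmul kernels;
MEGDNN_OVERRIDE_MATMUL_DESC lets each matmul algo declare its description in
one line.

Illustrative sketch only -- the struct layout and the macro expansion below are
inferred from how the fields are consumed in this patch (mdesc.packmode,
mdesc.innerblocksize.m/.n/.k, mdesc.packa_type_size), not copied from
gemm_common.h:

    #include <cstddef>

    enum class PackMode { DEFAULT, NO_PACK, ONLY_PACKA };

    // assumed shape of the description added to MatrixMulImpl::AlgoBase
    struct MatmulDescription {
        PackMode packmode;                          // replaces packmode() queries at call sites
        struct { size_t m, n, k; } innerblocksize;  // inner gemm block sizes
        size_t packa_type_size;                     // element size of the packed A panel
    };

    // assumed expansion of the convenience macro used in the algos.h changes
    #define MEGDNN_OVERRIDE_MATMUL_DESC(_m, _n, _k, _packa_type_size) \
        MatmulDescription matmul_description() const {                \
            return {packmode(), {_m, _n, _k}, _packa_type_size};      \
        }

    // minimal usage mirroring e.g. AlgoInt8x8x32Gemv below
    struct DemoAlgo {
        PackMode packmode() const { return PackMode::NO_PACK; }
        MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
    };
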
GitOrigin-RevId: 5bde0b60f0b8102cd8bad14457cf123bc7e6dafa
---
 dnn/src/aarch64/matrix_mul/algos.h                 |   4 +
 dnn/src/arm_common/matrix_mul/algos.h              |   6 +
 dnn/src/armv7/matrix_mul/algos.h                   |   3 +
 dnn/src/fallback/conv_bias/im2col/algos.cpp        | 129 ++++++++++-----------
 dnn/src/fallback/conv_bias/im2col/strategy_base.h  |  85 ++++++++------
 .../fallback/conv_bias/im2col/strategy_default.cpp |  24 ++--
 .../conv_bias/im2col/strategy_default_nchw44.cpp   |   2 +-
 .../conv_bias/im2col/strategy_fuse_nchw44.cpp      |   2 +-
 .../conv_bias/im2col/strategy_fuse_nchw44_dot.cpp  |   2 +-
 .../im2col/strategy_fuse_nchw44_fp32_s2.cpp        |   2 +-
 .../fallback/conv_bias/im2col/strategy_nopack.cpp  |  13 ++-
 .../conv_bias/im2col/strategy_onlypacka.cpp        |  13 ++-
 dnn/src/fallback/matrix_mul/algos.h                |   1 +
 dnn/src/fallback/matrix_mul/gemm_common.h          |  21 +++-
 dnn/src/fallback/matrix_mul/opr_impl.h             |   8 +-
 dnn/src/x86/matrix_mul/algos.h                     |   7 +-
 16 files changed, 190 insertions(+), 132 deletions(-)

diff --git a/dnn/src/aarch64/matrix_mul/algos.h b/dnn/src/aarch64/matrix_mul/algos.h
index 7e5a3613..d4b6ff6b 100644
--- a/dnn/src/aarch64/matrix_mul/algos.h
+++ b/dnn/src/aarch64/matrix_mul/algos.h
@@ -60,6 +60,7 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     void* type() const override { return sm_arm_common_algo_type; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
+    MEGDNN_OVERRIDE_MATMUL_DESC(4, 16, 4, 4)
 };
 
 class MatrixMulImpl::AlgoF32Gemv final
@@ -86,6 +87,7 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     void* type() const override { return sm_arm_common_algo_type; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2)
 };
 #endif
 
@@ -207,6 +209,7 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     void* type() const override { return sm_arm_common_algo_type; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 8, 8, 2)
 };
 
 #if __ARM_FEATURE_DOTPROD
@@ -234,6 +237,7 @@ public:
     void* type() const override { return sm_arm_common_algo_type; }
     AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
 };
 
 #else
diff --git a/dnn/src/arm_common/matrix_mul/algos.h b/dnn/src/arm_common/matrix_mul/algos.h
index a8f7a339..b34512d4 100644
--- a/dnn/src/arm_common/matrix_mul/algos.h
+++ b/dnn/src/arm_common/matrix_mul/algos.h
@@ -12,6 +12,7 @@
 #pragma once
 
 #include "src/arm_common/matrix_mul/opr_impl.h"
+#include "src/fallback/matrix_mul/gemm_common.h"
 
 namespace megdnn {
 namespace arm_common {
@@ -25,6 +26,7 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     void* type() const override { return sm_arm_common_algo_type; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
 };
 
 class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase {
@@ -38,6 +40,7 @@ public:
     void* type() const override { return sm_arm_common_algo_type; }
     AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
 };
 
 class MatrixMulImpl::AlgoF32Gemv : public AlgoBase {
@@ -54,6 +57,7 @@ public:
     void* type() const override { return sm_arm_common_algo_type; }
     AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
 };
 
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -68,6 +72,7 @@ public:
     void* type() const override { return sm_arm_common_algo_type; }
     AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 2)
 };
 #endif
 
@@ -82,6 +87,7 @@ public:
     void* type() const override { return sm_arm_common_algo_type; }
     AlgoSet algoset() const override { return AlgoSet::ALGO_TYPE_GEMV; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
+    MEGDNN_OVERRIDE_MATMUL_DESC(8, 16, 1, 4)
 };
diff --git a/dnn/src/armv7/matrix_mul/algos.h b/dnn/src/armv7/matrix_mul/algos.h
index ae5d0fb7..9a509b46 100644
--- a/dnn/src/armv7/matrix_mul/algos.h
+++ b/dnn/src/armv7/matrix_mul/algos.h
@@ -49,6 +49,7 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     void* type() const override { return sm_arm_common_algo_type; }
    PackMode packmode() const override { return PackMode::NO_PACK; }
+    MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 4, 4)
 };
 
 #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
@@ -71,6 +72,7 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     void* type() const override { return sm_arm_common_algo_type; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
+    MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2)
 };
 #endif
 #if __ARM_FEATURE_DOTPROD
@@ -190,6 +192,7 @@ public:
     kern_t get_kern(const KernSizeParam&) const override;
     void* type() const override { return sm_arm_common_algo_type; }
     PackMode packmode() const override { return PackMode::NO_PACK; }
+    MEGDNN_OVERRIDE_MATMUL_DESC(4, 8, 8, 2)
 };
 
 class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase {
diff --git a/dnn/src/fallback/conv_bias/im2col/algos.cpp b/dnn/src/fallback/conv_bias/im2col/algos.cpp
index 9d4e89e7..051affe1 100644
--- a/dnn/src/fallback/conv_bias/im2col/algos.cpp
+++ b/dnn/src/fallback/conv_bias/im2col/algos.cpp
@@ -47,14 +47,17 @@ static void copy_padding_kern(WorkspaceBundle bundle,
 }
 
 //! packA_kern
-static void packA_kern(WorkspaceBundle bundle,
-                       const fallback::ConvBiasImpl::NCBKernParam& param,
-                       fallback::MatrixMulImpl::KernSizeParam matmulparam,
-                       fallback::MatrixMulImpl::AlgoBase* matmul_algo,
-                       const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
-                       StrategyBase* im2colstrategy, size_t pack_oc_size) {
+static void packA_kern(
+        WorkspaceBundle bundle,
+        const fallback::ConvBiasImpl::NCBKernParam& param,
+        fallback::MatrixMulImpl::KernSizeParam matmulparam,
+        fallback::MatrixMulImpl::AlgoBase* matmul_algo,
+        const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
+        StrategyBase* im2colstrategy,
+        const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
+        size_t pack_oc_size) {
     im2colstrategy->packA_kern(bundle, param, matmulparam, matmul_algo,
-                               ncb_index, pack_oc_size);
+                               ncb_index, matmul_desc, pack_oc_size);
 }
 
 /*!
@@ -72,7 +75,8 @@ public:
             WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
             const ConvBiasImpl::NCBKernParam& param,
             fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param,
-            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
+            const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
+            const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
             StrategyParam strategyparam,
             fallback::ConvBiasImpl::NCBKernIndex ncb_index,
             size_t ohw_tile_size, StrategyBase* im2colstrategy) {
@@ -111,7 +115,8 @@ public:
 
         //! 2.packb and matmul compute
         im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread,
-                                    matmul_param, matmul_algo, ncb_index);
+                                    matmul_param, matmul_algo, ncb_index,
+                                    matmul_desc);
 
         //! 3.postprocess and copy dst if need
         im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread);
@@ -151,7 +156,8 @@ public:
             WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
             const ConvBiasImpl::NCBKernParam& param,
             fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param,
-            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
+            const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
+            const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
             StrategyParam strategyparam,
             fallback::ConvBiasImpl::NCBKernIndex ncb_index,
             size_t ohw_tile_size, StrategyBase* im2colstrategy) {
@@ -191,7 +197,8 @@ public:
 
         //! 2.packb and matmul compute
         im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread,
-                                    matmul_param, matmul_algo, ncb_index);
+                                    matmul_param, matmul_algo, ncb_index,
+                                    matmul_desc);
 
         //! 3.postprocess and copy dst if need
         im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread);
@@ -232,7 +239,8 @@ public:
             WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
             const ConvBiasImpl::NCBKernParam& param,
             fallback::MatrixMulImpl::KernSizeParam matmul_kernsize_param,
-            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
+            const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
+            const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
             StrategyParam strategyparam,
             fallback::ConvBiasImpl::NCBKernIndex ncb_index,
             size_t ohw_tile_size, StrategyBase* im2colstrategy) {
@@ -272,7 +280,8 @@ public:
 
         //! 2.packb and matmul compute
         im2colstrategy->exec_matmul(param, strategyparam, bundle, bundle_thread,
-                                    matmul_param, matmul_algo, ncb_index);
+                                    matmul_param, matmul_algo, ncb_index,
+                                    matmul_desc);
 
         //! 3.postprocess and copy dst if need
         im2colstrategy->exec_postprocess(param, strategyparam, bundle_thread);
@@ -401,13 +410,15 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
     size_t padding = 0, packa_size = 0, packa_group_size = 0;
     size_t nr_threads = param.nr_threads;
     size_t GROUP = param.filter_meta.group;
-    bool need_pack = m_matmul_algo->packmode() == Pack_Mode::DEFAULT;
-    bool only_packA = m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA;
+    fallback::MatrixMulImpl::AlgoBase::MatmulDescription mdesc =
+            m_matmul_algo->matmul_description();
+    bool need_pack = mdesc.packmode == Pack_Mode::DEFAULT;
+    bool only_packA = mdesc.packmode == Pack_Mode::ONLY_PACKA;
     size_t oc_tile_size = 0, ohw_tile_size = 0;
+    choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
+                        mdesc.innerblocksize.m, mdesc.innerblocksize.n,
+                        mdesc.packmode);
     if (need_pack || only_packA) {
-        auto inner_block = m_matmul_algo->get_inner_block_size();
-        choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size, inner_block.m,
-                            inner_block.n, m_matmul_algo->packmode());
         auto im2col_kern_param = get_matmul_kern_param(
                param, ohw_tile_size, only_packA ? oc_tile_size : OC);
        size_t oc_parallel_times = div_ceil(OC, oc_tile_size);
@@ -415,11 +426,6 @@ WorkspaceBundle ConvBiasImpl::AlgoIm2col::get_bundle(
         packa_group_size = only_packA ? oc_parallel_times * wb.get_size(0)
                                       : wb.get_size(0);
     } else {  //! not support pack,not need pack
-        size_t nopack_default_blockm = 8;
-        size_t nopack_default_blockn = 16;
-        choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
-                            nopack_default_blockm, nopack_default_blockn,
-                            m_matmul_algo->packmode());
         packa_group_size = 0;
     }
 
@@ -481,23 +487,18 @@ SmallVector ConvBiasImpl::AlgoIm2col::dispatch_kerns(
     WorkspaceBundle bundle = get_bundle(param);
     WorkspaceBundle bundle_thread = {nullptr, {}};
     bool need_padding = (PH != 0 || PW != 0);
-    Pack_Mode packmode = m_matmul_algo->packmode();
+
+    fallback::MatrixMulImpl::AlgoBase::MatmulDescription mdesc =
+            m_matmul_algo->matmul_description();
+
+    Pack_Mode packmode = mdesc.packmode;
     bool default_pack = packmode == Pack_Mode::DEFAULT;
     bool no_pack = packmode == Pack_Mode::NO_PACK;
     bool only_packA = packmode == Pack_Mode::ONLY_PACKA;
-    if (default_pack || only_packA) {
-        auto inner_block = m_matmul_algo->get_inner_block_size();
-        choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
-                            inner_block.m, inner_block.n,
-                            m_matmul_algo->packmode());
-    } else {  //! nopack_mode
-        size_t nopack_default_blockm = 8;
-        size_t nopack_default_blockn = 16;
-        choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
-                            nopack_default_blockm, nopack_default_blockn,
-                            m_matmul_algo->packmode());
-    }
+    choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
+                        mdesc.innerblocksize.m, mdesc.innerblocksize.n,
+                        mdesc.packmode);
 
     size_t ohw_parallel_times = div_ceil(ohw, ohw_tile_size);
     size_t oc_parallel_times = div_ceil(OC, oc_tile_size);
@@ -507,18 +508,17 @@ SmallVector ConvBiasImpl::AlgoIm2col::dispatch_kerns(
     if (only_packA) {
         packa_parallel_times = div_ceil(OC, oc_tile_size);
     } else if (default_pack) {
-        packa_parallel_times = div_ceil(
-                OC, m_matmul_algo->get_inner_block_size().m);
+        packa_parallel_times = div_ceil(OC, mdesc.innerblocksize.m);
     }
 
     auto matmul_param = get_matmul_kern_param(
            param, ohw_tile_size, only_packA ? oc_tile_size : OC);
-    if (m_matmul_algo->packmode() == Pack_Mode::DEFAULT) {
+    if (mdesc.packmode == Pack_Mode::DEFAULT) {
         Im2colKerns defaultkern;
         bundle_thread = defaultkern.get_thread_bundle(
                 param, matmul_param, m_matmul_algo, ohw_tile_size,
                 oc_tile_size);
-    } else if (m_matmul_algo->packmode() == Pack_Mode::ONLY_PACKA) {
+    } else if (mdesc.packmode == Pack_Mode::ONLY_PACKA) {
         Im2colKerns onlypackakern;
         bundle_thread = onlypackakern.get_thread_bundle(
                 param, matmul_param, m_matmul_algo, ohw_tile_size,
@@ -559,24 +559,24 @@ SmallVector ConvBiasImpl::AlgoIm2col::dispatch_kerns(
 
     auto kern_packA = [bundle, matmul_algo = m_matmul_algo, matmul_param,
                        im2colstrategy,
-                       pack_oc_size = pack_oc_size](
-                              const NCBKernParam& param,
-                              const NCBKernIndex& ncb_index) {
+                       pack_oc_size = pack_oc_size,
+                       mdesc = mdesc](const NCBKernParam& param,
+                                      const NCBKernIndex& ncb_index) {
         packA_kern(bundle, param, matmul_param, matmul_algo, ncb_index,
-                   im2colstrategy, pack_oc_size);
+                   im2colstrategy, mdesc, pack_oc_size);
     };
     if (default_pack) {
         auto kern_compute_default =
                 [bundle, bundle_thread, matmul_param,
                  matmul_algo = m_matmul_algo,
                  ohw_tile_size = ohw_tile_size,
-                 strategyparam = strategyparam,
+                 strategyparam = strategyparam, matmul_desc = mdesc,
                  im2colstrategy](const NCBKernParam& param,
                                  const NCBKernIndex& ncb_index) {
                     Im2colKerns::kerns(
                             bundle, bundle_thread, param, matmul_param,
-                            matmul_algo, strategyparam, ncb_index,
-                            ohw_tile_size, im2colstrategy);
+                            matmul_algo, matmul_desc, strategyparam,
+                            ncb_index, ohw_tile_size, im2colstrategy);
                 };
 
         ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}});
@@ -592,13 +592,13 @@ SmallVector ConvBiasImpl::AlgoIm2col::dispatch_kerns(
                 [bundle, bundle_thread, matmul_param,
                  matmul_algo = m_matmul_algo,
                  strategyparam = strategyparam,
-                 ohw_tile_size = ohw_tile_size,
+                 ohw_tile_size = ohw_tile_size, matmul_desc = mdesc,
                  im2colstrategy](const NCBKernParam& param,
                                  const NCBKernIndex& ncb_index) {
                     Im2colKerns::kerns(
                             bundle, bundle_thread, param, matmul_param,
-                            matmul_algo, strategyparam, ncb_index,
-                            ohw_tile_size, im2colstrategy);
+                            matmul_algo, matmul_desc, strategyparam,
+                            ncb_index, ohw_tile_size, im2colstrategy);
                 };
         ret_kern.push_back({kern_packA, {GROUP, packa_parallel_times}});
         if (need_padding) {
@@ -612,13 +612,13 @@ SmallVector ConvBiasImpl::AlgoIm2col::dispatch_kerns(
                 [bundle, bundle_thread, matmul_param,
                  matmul_algo = m_matmul_algo,
                  strategyparam = strategyparam,
-                 ohw_tile_size = ohw_tile_size,
+                 ohw_tile_size = ohw_tile_size, matmul_desc = mdesc,
                  im2colstrategy](const NCBKernParam& param,
                                  const NCBKernIndex& ncb_index) {
                     Im2colKerns::kerns(
                             bundle, bundle_thread, param, matmul_param,
-                            matmul_algo, strategyparam, ncb_index,
-                            ohw_tile_size, im2colstrategy);
+                            matmul_algo, matmul_desc, strategyparam,
+                            ncb_index, ohw_tile_size, im2colstrategy);
                 };
 
         if (need_padding) {
@@ -668,10 +668,12 @@ bool ConvBiasImpl::AlgoIm2col::usable(
             return false;
         }
     }
+    fallback::MatrixMulImpl::AlgoBase::MatmulDescription mdesc =
+            m_matmul_algo->matmul_description();
     if (opr->param().format == param::ConvBias::Format::NCHW44 ||
         opr->param().format == param::ConvBias::Format::NCHW44_DOT) {
         //! current NCHW44 im2col only support DEFAULT mode matmul
-        if (m_matmul_algo->packmode() != Pack_Mode::DEFAULT) {
+        if (mdesc.packmode != Pack_Mode::DEFAULT) {
            return false;
            //! nchw44 hybird mode and channel wise is not support
         } else if (param.filter_meta.icpg < 4_z ||
@@ -682,22 +684,9 @@ bool ConvBiasImpl::AlgoIm2col::usable(
     }
 
     size_t oc_tile_size = 0, ohw_tile_size = 0;
-    Pack_Mode packmode = m_matmul_algo->packmode();
-    bool default_pack = packmode == Pack_Mode::DEFAULT;
-    bool only_packA = packmode == Pack_Mode::ONLY_PACKA;
-
-    if (default_pack || only_packA) {
-        auto inner_block = m_matmul_algo->get_inner_block_size();
-        choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
-                            inner_block.m, inner_block.n,
-                            m_matmul_algo->packmode());
-    } else {  //! not support pack,not need pack
-        size_t nopack_default_blockm = 8;
-        size_t nopack_default_blockn = 16;
-        choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
-                            nopack_default_blockm, nopack_default_blockn,
-                            m_matmul_algo->packmode());
-    }
+    choice_ohw_oc_block(param, oc_tile_size, ohw_tile_size,
+                        mdesc.innerblocksize.m, mdesc.innerblocksize.n,
+                        m_matmul_algo->packmode());
     fallback::MatrixMulImpl::KernSizeParam matmul_param =
             get_matmul_kern_param(param, ohw_tile_size, oc_tile_size);
     bool matmulusable = m_matmul_algo->usable(matmul_param);
diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_base.h b/dnn/src/fallback/conv_bias/im2col/strategy_base.h
index 0ce050dd..976873ca 100644
--- a/dnn/src/fallback/conv_bias/im2col/strategy_base.h
+++ b/dnn/src/fallback/conv_bias/im2col/strategy_base.h
@@ -58,8 +58,9 @@ public:
             WorkspaceBundle bundle,
             const fallback::ConvBiasImpl::NCBKernParam& param,
             fallback::MatrixMulImpl::KernSizeParam matmulparam,
-            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
+            const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
             const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
+            const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desec,
             size_t pack_size) = 0;
 
     virtual void exec_im2col(
@@ -67,15 +68,17 @@ public:
             const StrategyParam& sparam,
             const fallback::ConvBiasImpl::NCBKernParam& param,
             fallback::MatrixMulImpl::KernParam matmul_param,
-            fallback::MatrixMulImpl::AlgoBase* matmul_algo) = 0;
+            const fallback::MatrixMulImpl::AlgoBase* matmul_algo) = 0;
 
     virtual void exec_matmul(
             const fallback::ConvBiasImpl::NCBKernParam& param,
             const StrategyParam& sparam, WorkspaceBundle bundle,
             WorkspaceBundle bundle_thread,
             fallback::MatrixMulImpl::KernParam matmul_param,
-            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
-            const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) = 0;
+            const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
+            const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
+            const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc
+            ) = 0;
 
     virtual void exec_postprocess(
             const fallback::ConvBiasImpl::NCBKernParam& param,
@@ -284,26 +287,30 @@ public:
 
     Strategy() = default;
 
-    virtual void packA_kern(WorkspaceBundle bundle,
-                            const fallback::ConvBiasImpl::NCBKernParam& param,
-                            fallback::MatrixMulImpl::KernSizeParam matmulparam,
-                            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
-                            const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
-                            size_t pack_size) override;
+    virtual void packA_kern(
+            WorkspaceBundle bundle,
+            const fallback::ConvBiasImpl::NCBKernParam& param,
+            fallback::MatrixMulImpl::KernSizeParam matmulparam,
+            const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
+            const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
+            const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc,
+            size_t pack_size) override;
 
     virtual void exec_im2col(
             WorkspaceBundle bundle, WorkspaceBundle bundle_thread,
             const StrategyParam& sparam,
            const fallback::ConvBiasImpl::NCBKernParam& param,
             fallback::MatrixMulImpl::KernParam matmul_param,
-            fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
+            const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
 
     void exec_matmul(
             const fallback::ConvBiasImpl::NCBKernParam& param,
             const StrategyParam& sparam, WorkspaceBundle bundle,
             WorkspaceBundle bundle_thread,
             fallback::MatrixMulImpl::KernParam matmul_param,
-            fallback::MatrixMulImpl::AlgoBase* matmul_algo,
-            const fallback::ConvBiasImpl::NCBKernIndex& ncb_index) override;
+            const fallback::MatrixMulImpl::AlgoBase* matmul_algo,
+            const fallback::ConvBiasImpl::NCBKernIndex& ncb_index,
+            const fallback::MatrixMulImpl::AlgoBase::MatmulDescription& matmul_desc
+            ) override;
 
     void exec_postprocess(const fallback::ConvBiasImpl::NCBKernParam& param,
                           const StrategyParam& sparam,
                           WorkspaceBundle bundle_thread) override {
@@ -338,7 +345,7 @@ public:
             const StrategyParam& sparam,
             const fallback::ConvBiasImpl::NCBKernParam& param,
             fallback::MatrixMulImpl::KernParam matmul_param,
-            fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
+            const fallback::MatrixMulImpl::AlgoBase* matmul_algo) override;
 };
 
 template get_bundle(matmul_param).get_size(0);
     size_t packed_per_oc_block_size =
-            round_up(matmul_param.K, matmul_algo->get_inner_block_size().k) *
-            matmul_algo->get_inner_block_size().m *
-            matmul_algo->get_packA_type_size();
+            round_up(matmul_param.K, matmul_desc.innerblocksize.k) *
+            matmul_desc.innerblocksize.m * matmul_desc.packa_type_size;
+
     size_t a_panel_offset = ncb_index.ndrange_id[1] * packed_per_oc_block_size;
     int8_t* a_panel = static_cast(bundle.get(BUNDLE_PACKA_INDEX)) +
                       group_id * packA_group_size + a_panel_offset;
     matmul_param.A_ptr = const_cast(param.filter(group_id));
     matmul_algo->pack_A(matmul_param, a_panel, ncb_index.ndrange_id[1],
-                        matmul_algo->get_inner_block_size().m);
+                        matmul_desc.innerblocksize.m);
 }
 
 template get_inner_block_size().k) *
-            sparam.oc_tile_size * matmul_algo->get_packA_type_size();
+            round_up(matmul_param.K, matmul_desc.innerblocksize.k) *
+            sparam.oc_tile_size * matmul_desc.packa_type_size;
     size_t packA_group_size = matmul_algo->get_bundle(matmul_param).get_size(0);
     size_t a_panel_offset = ncb_index.ndrange_id[1] * packA_group_size +
                             ncb_index.ndrange_id[3] * packA_per_oc_block_size;
diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp b/dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp
index dc61b814..b4d869f0 100644
--- a/dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp
+++ b/dnn/src/fallback/conv_bias/im2col/strategy_default_nchw44.cpp
@@ -33,7 +33,7 @@ void Strategy::
             const StrategyParam& sparam,
             const fallback::ConvBiasImpl::NCBKernParam& param,
             fallback::MatrixMulImpl::KernParam,
-            fallback::MatrixMulImpl::AlgoBase*) {
+            const fallback::MatrixMulImpl::AlgoBase*) {
     size_t ow = param.osz[1];
     size_t ic = param.filter_meta.icpg;
     size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot.cpp b/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot.cpp
index ab5ad5fc..eeb34a69 100644
--- a/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot.cpp
+++ b/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_dot.cpp
@@ -176,7 +176,7 @@ void StrategyFuse8x12x4Nchw44Dot::
             const StrategyParam& sparam,
             const fallback::ConvBiasImpl::NCBKernParam& param,
             fallback::MatrixMulImpl::KernParam /*matmul_param*/,
-            fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) {
+            const fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) {
     size_t ow = param.osz[1];
     size_t ic = param.filter_meta.icpg;
     size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp b/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp
index 328cb2d1..d85a0398 100644
--- a/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp
+++ b/dnn/src/fallback/conv_bias/im2col/strategy_fuse_nchw44_fp32_s2.cpp
@@ -168,7 +168,7 @@ void StrategyFuse8x12x1Nchw44K3x3S2::
             const StrategyParam& sparam,
             const fallback::ConvBiasImpl::NCBKernParam& param,
             fallback::MatrixMulImpl::KernParam /*matmul_param*/,
-            fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) {
+            const fallback::MatrixMulImpl::AlgoBase* /*matmul_algo*/) {
     size_t ow = param.osz[1];
     size_t ic = param.filter_meta.icpg;
     size_t ih = param.isz[0] + param.filter_meta.padding[0] * 2;
diff --git a/dnn/src/fallback/conv_bias/im2col/strategy_nopack.cpp b/dnn/src/fallback/conv_bias/im2col/strategy_nopack.cpp
index 35a06742..8a26382b 100644
--- a/dnn/src/fallback/conv_bias/im2col/strategy_nopack.cpp
+++ b/dnn/src/fallback/conv_bias/im2col/strategy_nopack.cpp
@@ -22,8 +22,10 @@ void Strategy