From 1cadf9d8d7d9c436457dc9820ea2fd3f0881f3cb Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 1 Apr 2021 16:01:54 +0800 Subject: [PATCH] fix(mgb): add usable-depend-on-shape attr GitOrigin-RevId: 3a14fa6b6f61c30999a9cd7f20f90dccc52e1377 --- dnn/include/megdnn/oprs/base.h | 5 ++ dnn/src/aarch64/matrix_mul/algos.h | 21 ++++--- dnn/src/arm_common/matrix_mul/algos.h | 12 ++-- dnn/src/armv7/matrix_mul/algos.h | 9 ++- dnn/src/common/algo_base.cpp | 3 +- dnn/src/cuda/convolution/backward_data/algo.h | 3 +- dnn/src/cuda/local_share/forward/algo.h | 6 +- dnn/src/cuda/matrix_mul/algos.h | 6 +- dnn/src/rocm/matrix_mul/algos.h | 3 +- dnn/src/x86/matrix_mul/algos.h | 3 +- src/opr/impl/search_policy/algo_chooser.cpp | 67 +++++++++++----------- .../megbrain/opr/search_policy/algo_chooser.h | 10 ++++ 12 files changed, 93 insertions(+), 55 deletions(-) diff --git a/dnn/include/megdnn/oprs/base.h b/dnn/include/megdnn/oprs/base.h index e668be8c..f356d9bb 100644 --- a/dnn/include/megdnn/oprs/base.h +++ b/dnn/include/megdnn/oprs/base.h @@ -122,6 +122,11 @@ public: * these algorithms to speed up fastrun. * */ NAIVE = 1 << 1, + + /** + * \brief whether the algo is usable once shape changed. 
+ * */ + USABLE_DEPEND_ON_SHAPE = 1 << 2, }; /** diff --git a/dnn/src/aarch64/matrix_mul/algos.h b/dnn/src/aarch64/matrix_mul/algos.h index 8d690189..65652acd 100644 --- a/dnn/src/aarch64/matrix_mul/algos.h +++ b/dnn/src/aarch64/matrix_mul/algos.h @@ -35,7 +35,8 @@ public: class MatrixMulImpl::AlgoF32MK4_8x12x1 final : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | + AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "AARCH64_F32_MK4_K8X12X1"; } bool usable(const KernSizeParam&) const override; @@ -146,7 +147,8 @@ public: class MatrixMulImpl::AlgoInt8x8x32MK4_4x4x16 final : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | + AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "AARCH64_INT8X8X32_MK4_4X4X16"; } bool usable(const KernSizeParam&) const override; @@ -220,7 +222,8 @@ public: class MatrixMulImpl::AlgoInt4x4x16K8x8x8 final : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | + AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "AARCH64_INT4X4X16_K8X8X8"; } bool usable(const KernSizeParam&) const override; @@ -235,7 +238,8 @@ public: class MatrixMulImpl::AlgoInt8x8x16MK4_16x12x4 final : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | + AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "AARCH64_INT8X8X16_MK4_16X12X4"; @@ -253,7 +257,8 @@ public: class MatrixMulImpl::AlgoInt8x8x16MK4_K8x8x8 final : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | + 
AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "AARCH64_INT8X8X16_MK4_K8X8X8"; @@ -271,7 +276,8 @@ public: class MatrixMulImpl::AlgoInt8x8x16MK4_4x4x8 final : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | + AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "AARCH64_INT8X8X16_MK4_4X4X8"; } bool usable(const KernSizeParam&) const override; @@ -330,7 +336,8 @@ public: class MatrixMulImpl::AlgoQuint8GemvDotProd final : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | + AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "AARCH64_QUINT8_GEMV_DOTPROD"; } bool usable(const KernSizeParam&) const override; diff --git a/dnn/src/arm_common/matrix_mul/algos.h b/dnn/src/arm_common/matrix_mul/algos.h index fb55f5fe..852e64b2 100644 --- a/dnn/src/arm_common/matrix_mul/algos.h +++ b/dnn/src/arm_common/matrix_mul/algos.h @@ -34,7 +34,8 @@ public: class MatrixMulImpl::AlgoInt8x8x32Gemv : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | + AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV"; } bool usable(const KernSizeParam&) const override; @@ -50,7 +51,8 @@ public: class MatrixMulImpl::AlgoInt8x8x32GemvMK4 : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | + AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4"; } bool usable(const KernSizeParam&) const override; @@ -67,7 +69,8 @@ public: class MatrixMulImpl::AlgoInt8x8x32GemvMK4Dot : public AlgoBase { public: AlgoAttribute 
attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | + AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "ARM_COMMON_INT8X8X32_GEMV_MK4_DOT"; } bool usable(const KernSizeParam&) const override; @@ -102,7 +105,8 @@ public: class MatrixMulImpl::AlgoF32GemvMK4 : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | + AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "ARM_COMMON_F32_GEMV_MK4"; } bool usable(const KernSizeParam&) const override; diff --git a/dnn/src/armv7/matrix_mul/algos.h b/dnn/src/armv7/matrix_mul/algos.h index dcd55463..26176770 100644 --- a/dnn/src/armv7/matrix_mul/algos.h +++ b/dnn/src/armv7/matrix_mul/algos.h @@ -35,7 +35,8 @@ public: class MatrixMulImpl::AlgoF32MK4Pack4x12 final : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | + AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "ARMV7_F32_MK4_PACK_4X12"; } bool usable(const KernSizeParam&) const override; @@ -224,7 +225,8 @@ public: class MatrixMulImpl::AlgoInt8x8x16MK4_8x8x4 final : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | + AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "ARMV7_INT8X8X16_MK4_K8X8X4"; } bool usable(const KernSizeParam&) const override; @@ -266,7 +268,8 @@ public: class MatrixMulImpl::AlgoInt8x8x32MK4_4x2x16 final : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | + AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "ARMV7_INT8X8X32_MK4_4X2X16"; } bool usable(const 
KernSizeParam&) const override; diff --git a/dnn/src/common/algo_base.cpp b/dnn/src/common/algo_base.cpp index 966870ba..21ef9172 100644 --- a/dnn/src/common/algo_base.cpp +++ b/dnn/src/common/algo_base.cpp @@ -18,7 +18,8 @@ using namespace megdnn; #define FOREACH_ALGO_ATTRIBUTE(cb) \ cb(DEFAULT) \ cb(REPRODUCIBLE) \ - cb(NAIVE) + cb(NAIVE) \ + cb(USABLE_DEPEND_ON_SHAPE) namespace { inline const char* attr_str(const AlgoAttribute& attr) { diff --git a/dnn/src/cuda/convolution/backward_data/algo.h b/dnn/src/cuda/convolution/backward_data/algo.h index d216def3..00439d30 100644 --- a/dnn/src/cuda/convolution/backward_data/algo.h +++ b/dnn/src/cuda/convolution/backward_data/algo.h @@ -184,7 +184,8 @@ public: const char* name() const override { return "CHANNEL_WISE_SMALL"; } MEGDNN_DECL_ALGO_TYPE(CUDA_CHANWISE_SMALL) AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | + AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } }; diff --git a/dnn/src/cuda/local_share/forward/algo.h b/dnn/src/cuda/local_share/forward/algo.h index 83a53d4a..44498099 100644 --- a/dnn/src/cuda/local_share/forward/algo.h +++ b/dnn/src/cuda/local_share/forward/algo.h @@ -89,7 +89,8 @@ public: void exec(const ExecArgs& args) const override; AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | + AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { @@ -108,7 +109,8 @@ public: void exec(const ExecArgs& args) const override; AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | + AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { diff --git a/dnn/src/cuda/matrix_mul/algos.h b/dnn/src/cuda/matrix_mul/algos.h index 5bbb9245..fc6394e9 100644 --- a/dnn/src/cuda/matrix_mul/algos.h +++ b/dnn/src/cuda/matrix_mul/algos.h @@ -114,7 +114,8 @@ public: void exec(const ExecArgs& 
args) const override; MEGDNN_DECL_ALGO_TYPE(CUDA_CUBLAS) AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | + AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } }; @@ -231,7 +232,8 @@ public: const char* name() const override { return m_name.c_str(); } void exec(const ExecArgs& args) const override; AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | + AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } MEGDNN_DECL_ALGO_TYPE(CUDA_FLOAT32_SIMT_SPLIT_K) diff --git a/dnn/src/rocm/matrix_mul/algos.h b/dnn/src/rocm/matrix_mul/algos.h index b8354e31..4c8d10cc 100644 --- a/dnn/src/rocm/matrix_mul/algos.h +++ b/dnn/src/rocm/matrix_mul/algos.h @@ -100,7 +100,8 @@ public: const char* name() const override { return "BLAS"; } void exec(const ExecArgs& args) const override; AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | + AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } MEGDNN_DECL_ALGO_TYPE(ROCM_BLAS) }; diff --git a/dnn/src/x86/matrix_mul/algos.h b/dnn/src/x86/matrix_mul/algos.h index a47a0854..7216f22e 100644 --- a/dnn/src/x86/matrix_mul/algos.h +++ b/dnn/src/x86/matrix_mul/algos.h @@ -135,7 +135,8 @@ public: class MatrixMulImpl::AlgoF32MK8_8x8 : public AlgoBase { public: AlgoAttribute attribute() const override { - return AlgoAttribute::REPRODUCIBLE; + return AlgoAttribute::REPRODUCIBLE | + AlgoAttribute::USABLE_DEPEND_ON_SHAPE; } const char* name() const override { return "X86_F32MK8_8X8"; } bool usable(const KernSizeParam&) const override; diff --git a/src/opr/impl/search_policy/algo_chooser.cpp b/src/opr/impl/search_policy/algo_chooser.cpp index 6116e831..7821fb29 100644 --- a/src/opr/impl/search_policy/algo_chooser.cpp +++ b/src/opr/impl/search_policy/algo_chooser.cpp @@ -276,21 +276,6 @@ std::vector flatten_search_space( return ret; } -//! 
return pair -std::pair -extract_algo_attribute_from_execution_strategy( - const ExecutionStrategy& strategy) { - std::pair ret = - std::make_pair(AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT); - if (strategy & ExecutionStrategy::REPRODUCIBLE) { - ret.first |= AlgoAttribute::REPRODUCIBLE; - } - if (strategy & ExecutionStrategy::OPTIMIZED) { - ret.second |= AlgoAttribute::NAIVE; - } - return ret; -} - } // namespace namespace mgb { @@ -303,9 +288,9 @@ void AlgoChooser::profile(ExeContext& ctx, return; AlgoChooserProfileCache::Result prof_rst; - auto target_attr = - extract_algo_attribute_from_execution_strategy(selected_strategy); - std::string layouts_str = format_fixlayouts(ctx.layouts(), arity_in, arity_out); + auto target_attr = ctx.extract_algo_attribute(selected_strategy); + std::string layouts_str = + format_fixlayouts(ctx.layouts(), arity_in, arity_out); double cur_timeout = 0; auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( @@ -558,16 +543,15 @@ AlgoChooser::ExeContext::get_profile_result_from_cache( if (prof.empty()) return {}; - auto attr_from_strategy = - extract_algo_attribute_from_execution_strategy(selected_strategy); + auto target_attr = extract_algo_attribute(selected_strategy); for (auto&& i : prof) { auto attr_of_algo = static_cast(i.attribute); bool contain_attr_all_positive = - (attr_from_strategy.first == - (attr_of_algo & attr_from_strategy.first)); + (target_attr.first == + (attr_of_algo & target_attr.first)); bool contain_attr_any_negative = - static_cast(attr_of_algo & attr_from_strategy.second); + static_cast(attr_of_algo & target_attr.second); if (contain_attr_all_positive && !contain_attr_any_negative) { auto iter = algo_map.find(i.algo); mgb_assert(iter != algo_map.end(), @@ -586,8 +570,8 @@ AlgoChooser::ExeContext::get_profile_result_from_cache( mgb_log_error( "algos read from cache could not satisfy attribute with %s and " "without %s", - Algorithm::attribute_str(attr_from_strategy.first).c_str(), - 
Algorithm::attribute_str(attr_from_strategy.second).c_str()); + Algorithm::attribute_str(target_attr.first).c_str(), + Algorithm::attribute_str(target_attr.second).c_str()); mgb_trap(); MIDOUT_E @@ -606,8 +590,7 @@ AlgoChooser::ExeContext::choose_by_heuristic( } auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( owner_graph(), m_cn, m_execution_policy.workspace_limit); - auto attr = - extract_algo_attribute_from_execution_strategy(selected_strategy); + auto attr = extract_algo_attribute(selected_strategy); ImplExecutionPolicy policy; policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic( @@ -668,9 +651,7 @@ void AlgoChooser::ExeContext::construct_execution_policy( if (retrive_from_cache) { policy.algo = get_profile_result_from_cache(selected_strategy).desc; if (!policy.algo.valid()) { - auto target_attr = - extract_algo_attribute_from_execution_strategy( - selected_strategy); + auto target_attr = extract_algo_attribute(selected_strategy); std::string layouts_str = format_fixlayouts(m_layouts, arity_in, arity_out); std::string msg = ssprintf( @@ -692,8 +673,7 @@ void AlgoChooser::ExeContext::construct_execution_policy( auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( owner_graph(), m_cn, m_execution_policy.workspace_limit); - auto attr = extract_algo_attribute_from_execution_strategy( - selected_strategy); + auto attr = extract_algo_attribute(selected_strategy); policy.algo = APPLY(m_megdnn_opr->get_algorithm_info_heuristic( args..., workspace_limit, attr.first, attr.second), @@ -837,6 +817,24 @@ AlgoChooser::ExeContext::construct_fake_preprocess_filter() const { return result; } +template +std::pair +AlgoChooser::ExeContext::extract_algo_attribute( + const ExecutionStrategy& strategy) const { + std::pair ret = + std::make_pair(AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT); + + //!
from strategy + if (strategy & ExecutionStrategy::REPRODUCIBLE) { + ret.first |= AlgoAttribute::REPRODUCIBLE; + } + if (strategy & ExecutionStrategy::OPTIMIZED) { + ret.second |= AlgoAttribute::NAIVE; + } + + return ret; +} + #define INST(Opr) \ template AlgoChooser::ExeContext::ExeContext( \ const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \ @@ -865,7 +863,10 @@ AlgoChooser::ExeContext::construct_fake_preprocess_filter() const { AlgoChooser::ExeContext::profile_single_algo( \ const typename AlgoChooser::ImplExecutionPolicy& \ policy, \ - double& timeout) const; + double& timeout) const; \ + template std::pair \ + AlgoChooser::ExeContext::extract_algo_attribute( \ + const ExecutionStrategy& strategy) const; MGB_FOREACH_FASTRUN_OPR(INST) diff --git a/src/opr/include/megbrain/opr/search_policy/algo_chooser.h b/src/opr/include/megbrain/opr/search_policy/algo_chooser.h index 8d20d7c8..adf4263f 100644 --- a/src/opr/include/megbrain/opr/search_policy/algo_chooser.h +++ b/src/opr/include/megbrain/opr/search_policy/algo_chooser.h @@ -149,6 +149,16 @@ public: ImplExecutionPolicy& policy, bool retrive_from_cache = true) const; + /** + * \brief extract algo attribute from execution strategy and graph + * option. + * + * \param strategy select algo which matched this strategy + * \return pair of (required attrs, forbidden attrs) + */ + std::pair extract_algo_attribute( + const ExecutionStrategy& strategy) const; + private: Maybe> construct_fake_preprocess_filter() const; };