GitOrigin-RevId: 81e32da034
release-1.2
@@ -19,6 +19,7 @@ | |||||
#include "megbrain/graph/extern_copr_api.h" | #include "megbrain/graph/extern_copr_api.h" | ||||
#include "megbrain/opr/dnn/convolution.h" | #include "megbrain/opr/dnn/convolution.h" | ||||
#include "megbrain/opr/io.h" | #include "megbrain/opr/io.h" | ||||
#include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||||
#include "megbrain/opr/utility.h" | #include "megbrain/opr/utility.h" | ||||
#include "megbrain/plugin/cpu_dispatch_checker.h" | #include "megbrain/plugin/cpu_dispatch_checker.h" | ||||
#include "megbrain/plugin/num_range_checker.h" | #include "megbrain/plugin/num_range_checker.h" | ||||
@@ -691,7 +692,7 @@ void run_test_st(Args &env) { | |||||
} | } | ||||
mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, env.workspace_limit); | mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, env.workspace_limit); | ||||
using S = opr::mixin::Convolution::ExecutionPolicy::Strategy; | |||||
using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | |||||
S strategy = S::HEURISTIC; | S strategy = S::HEURISTIC; | ||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
if (env.use_fast_run) { | if (env.use_fast_run) { | ||||
@@ -15,6 +15,7 @@ | |||||
#include "megbrain/graph/event.h" | #include "megbrain/graph/event.h" | ||||
#include "megbrain/opr/dnn/batch_norm.h" | #include "megbrain/opr/dnn/batch_norm.h" | ||||
#include "megbrain/opr/dnn/local.h" | #include "megbrain/opr/dnn/local.h" | ||||
#include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||||
#include "megbrain/utils/shared_set.h" | #include "megbrain/utils/shared_set.h" | ||||
#include "megbrain/serialization/opr_shallow_copy.h" | #include "megbrain/serialization/opr_shallow_copy.h" | ||||
#include "megbrain/opr/basic_arith.h" | #include "megbrain/opr/basic_arith.h" | ||||
@@ -116,8 +117,8 @@ SymbolVarArray gopt::optimize_for_inference( | |||||
namespace { | namespace { | ||||
void modify_conv_strategy( | void modify_conv_strategy( | ||||
opr::mixin::Convolution& conv, | |||||
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) { | |||||
opr::mixin::AlgoChooserHelper& conv, | |||||
opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) { | |||||
auto policy = conv.execution_policy_transient(); | auto policy = conv.execution_policy_transient(); | ||||
policy.strategy = strategy; | policy.strategy = strategy; | ||||
conv.set_execution_policy(policy); | conv.set_execution_policy(policy); | ||||
@@ -126,13 +127,13 @@ void modify_conv_strategy( | |||||
template <typename Opr> | template <typename Opr> | ||||
void inplace_conv_opr_modifier( | void inplace_conv_opr_modifier( | ||||
OperatorNodeBase& opr, | OperatorNodeBase& opr, | ||||
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) { | |||||
opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) { | |||||
modify_conv_strategy( | modify_conv_strategy( | ||||
opr.cast_final_safe<Opr>(), | opr.cast_final_safe<Opr>(), | ||||
strategy); | strategy); | ||||
} | } | ||||
void modify_conv_policy_workspace_limit(opr::mixin::Convolution& conv, | |||||
void modify_conv_policy_workspace_limit(opr::mixin::AlgoChooserHelper& conv, | |||||
size_t workspace_limit) { | size_t workspace_limit) { | ||||
auto policy = conv.execution_policy_transient(); | auto policy = conv.execution_policy_transient(); | ||||
policy.workspace_limit = workspace_limit; | policy.workspace_limit = workspace_limit; | ||||
@@ -159,9 +160,9 @@ void inplace_conv_opr_workspace_limit_modifier(OperatorNodeBase& opr, | |||||
void gopt::modify_opr_algo_strategy_inplace( | void gopt::modify_opr_algo_strategy_inplace( | ||||
const VarNodeArrayView& dest_vars, | const VarNodeArrayView& dest_vars, | ||||
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) { | |||||
opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) { | |||||
#if !MGB_ENABLE_FASTRUN | #if !MGB_ENABLE_FASTRUN | ||||
using S = opr::mixin::Convolution::ExecutionPolicy::Strategy; | |||||
using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | |||||
if (strategy == S::PROFILE || strategy == S::PROFILE_REPRODUCIBLE) { | if (strategy == S::PROFILE || strategy == S::PROFILE_REPRODUCIBLE) { | ||||
mgb_throw(MegBrainError, "fastrun is disabled at compile time"); | mgb_throw(MegBrainError, "fastrun is disabled at compile time"); | ||||
} | } | ||||
@@ -190,16 +191,16 @@ void gopt::modify_opr_algo_strategy_inplace( | |||||
void gopt::enable_opr_algo_profiling_inplace( | void gopt::enable_opr_algo_profiling_inplace( | ||||
const VarNodeArrayView& dest_vars) { | const VarNodeArrayView& dest_vars) { | ||||
modify_opr_algo_strategy_inplace(dest_vars, | |||||
opr::mixin::Convolution::ExecutionPolicy:: | |||||
Strategy::PROFILE); | |||||
modify_opr_algo_strategy_inplace( | |||||
dest_vars, | |||||
opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy::PROFILE); | |||||
} | } | ||||
void gopt::enable_opr_use_profiling_cache_inplace( | void gopt::enable_opr_use_profiling_cache_inplace( | ||||
const VarNodeArrayView& dest_vars) { | const VarNodeArrayView& dest_vars) { | ||||
modify_opr_algo_strategy_inplace(dest_vars, | |||||
opr::mixin::Convolution::ExecutionPolicy:: | |||||
Strategy::PROFILE_HEURISTIC); | |||||
modify_opr_algo_strategy_inplace( | |||||
dest_vars, opr::mixin::AlgoChooserHelper::ExecutionPolicy:: | |||||
Strategy::PROFILE_HEURISTIC); | |||||
} | } | ||||
@@ -14,6 +14,7 @@ | |||||
#include "megbrain/gopt/framework.h" | #include "megbrain/gopt/framework.h" | ||||
#include "megbrain/graph/cg.h" | #include "megbrain/graph/cg.h" | ||||
#include "megbrain/opr/dnn/convolution.h" | #include "megbrain/opr/dnn/convolution.h" | ||||
#include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||||
namespace mgb { | namespace mgb { | ||||
namespace gopt { | namespace gopt { | ||||
@@ -342,7 +343,7 @@ namespace gopt { | |||||
*/ | */ | ||||
void modify_opr_algo_strategy_inplace( | void modify_opr_algo_strategy_inplace( | ||||
const VarNodeArrayView& dest_vars, | const VarNodeArrayView& dest_vars, | ||||
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy); | |||||
opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy); | |||||
/*! | /*! | ||||
* \brief enable PROFILE execution strategy for oprs with multiple | * \brief enable PROFILE execution strategy for oprs with multiple | ||||
@@ -13,7 +13,7 @@ | |||||
#include "megbrain/opr/dnn/convolution.h" | #include "megbrain/opr/dnn/convolution.h" | ||||
#include "megbrain/opr/io.h" | #include "megbrain/opr/io.h" | ||||
#include "megbrain/opr/search_policy/algo_chooser.h" | #include "megbrain/opr/search_policy/algo_chooser.h" | ||||
#include "megbrain/opr/search_policy/profiler.h" | |||||
#include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||||
#include "megbrain/graph/grad_impl.h" | #include "megbrain/graph/grad_impl.h" | ||||
#include "megbrain/system.h" | #include "megbrain/system.h" | ||||
@@ -38,18 +38,9 @@ using intl::WorkspaceLimitGetter; | |||||
/* ==================== misc impl ==================== */ | /* ==================== misc impl ==================== */ | ||||
mixin::Convolution::~Convolution() = default; | |||||
void mixin::Convolution::set_execution_policy(const ExecutionPolicy& policy) { | |||||
mgb_throw_if( | |||||
m_policy_accessed, InternalError, | |||||
"attempt to modify ExecutionPolicy after it has been accessed"); | |||||
m_policy = policy; | |||||
} | |||||
template <class MgbOpr, class MegDNNOpr> | template <class MgbOpr, class MegDNNOpr> | ||||
void mixin::Convolution::init_output_static_infer_desc_for_bwd_data( | |||||
cg::OperatorNodeBase* self) { | |||||
void mixin::ConvolutionBackwardDataMixin:: | |||||
init_output_static_infer_desc_for_bwd_data(cg::OperatorNodeBase* self) { | |||||
using namespace cg::static_infer; | using namespace cg::static_infer; | ||||
auto&& mgr = self->owner_graph()->static_infer_manager(); | auto&& mgr = self->owner_graph()->static_infer_manager(); | ||||
@@ -93,7 +84,7 @@ void mixin::Convolution::init_output_static_infer_desc_for_bwd_data( | |||||
}; | }; | ||||
inp_deps.push_back({self->output(0), DepType::SHAPE}); | inp_deps.push_back({self->output(0), DepType::SHAPE}); | ||||
auto workspace_dep_var = | auto workspace_dep_var = | ||||
WorkspaceLimitGetter::register_to_graph(self->owner_graph()); | |||||
intl::WorkspaceLimitGetter::register_to_graph(self->owner_graph()); | |||||
if (workspace_dep_var) { | if (workspace_dep_var) { | ||||
inp_deps.push_back({workspace_dep_var, DepType::VALUE}); | inp_deps.push_back({workspace_dep_var, DepType::VALUE}); | ||||
} | } | ||||
@@ -101,11 +92,7 @@ void mixin::Convolution::init_output_static_infer_desc_for_bwd_data( | |||||
{SourceType::DEP, inp_deps, infer_wk}); | {SourceType::DEP, inp_deps, infer_wk}); | ||||
} | } | ||||
#define IMPL_CONV(_cls) \ | |||||
std::pair<const void*, size_t> _cls::param_blob() const { \ | |||||
return {¶m(), sizeof(Param)}; \ | |||||
} \ | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(_cls) | |||||
#define IMPL_CONV(_cls) MGB_DYN_TYPE_OBJ_FINAL_IMPL(_cls) | |||||
class mixin::WeightPreprocessExecutor::PreprocessedFilterExecDep final | class mixin::WeightPreprocessExecutor::PreprocessedFilterExecDep final | ||||
: public cg::GraphExecutable::ExecDependency { | : public cg::GraphExecutable::ExecDependency { | ||||
@@ -11,6 +11,7 @@ | |||||
*/ | */ | ||||
#include "megbrain/opr/search_policy/algo_chooser.h" | #include "megbrain/opr/search_policy/algo_chooser.h" | ||||
#include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||||
#include "megbrain/opr/search_policy/profiler.h" | #include "megbrain/opr/search_policy/profiler.h" | ||||
#include "../internal/invoke.h" | #include "../internal/invoke.h" | ||||
@@ -200,7 +201,7 @@ size_t AlgoChooser<Opr>::setup_algo(const TensorLayoutArray& layouts, | |||||
template <typename Opr> | template <typename Opr> | ||||
typename AlgoChooser<Opr>::ImplAlgo AlgoChooser<Opr>::get_algo( | typename AlgoChooser<Opr>::ImplAlgo AlgoChooser<Opr>::get_algo( | ||||
ExeContext& ctx) { | ExeContext& ctx) { | ||||
using S = mixin::Convolution::ExecutionPolicy::Strategy; | |||||
using S = mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | |||||
MGB_MARK_USED_VAR(TIMEOUT_TOLERANCE); | MGB_MARK_USED_VAR(TIMEOUT_TOLERANCE); | ||||
switch (ctx.mgb_opr()->execution_policy().strategy) { | switch (ctx.mgb_opr()->execution_policy().strategy) { | ||||
case S::HEURISTIC: | case S::HEURISTIC: | ||||
@@ -0,0 +1,33 @@ | |||||
/** | |||||
* \file src/opr/impl/search_policy/algo_chooser_helper.cpp | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||||
#include "megbrain/opr/search_policy/algo_chooser.h" | |||||
#include "megbrain/graph/cg.h" | |||||
#include "../internal/megdnn_opr_wrapper.inl" | |||||
using namespace mgb; | |||||
using namespace opr; | |||||
using namespace mixin; | |||||
/* ==================== misc impl ==================== */ | |||||
AlgoChooserHelper::~AlgoChooserHelper() = default; | |||||
void AlgoChooserHelper::set_execution_policy(const ExecutionPolicy& policy) { | |||||
mgb_throw_if( | |||||
m_policy_accessed, InternalError, | |||||
"attempt to modify ExecutionPolicy after it has been accessed"); | |||||
m_policy = policy; | |||||
} | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -13,6 +13,7 @@ | |||||
#include "megbrain/opr/search_policy/profiler.h" | #include "megbrain/opr/search_policy/profiler.h" | ||||
#include "../internal/invoke.h" | #include "../internal/invoke.h" | ||||
#include "../internal/megdnn_opr_wrapper.inl" | |||||
#if MGB_ROCM | #if MGB_ROCM | ||||
#include "hcc_detail/hcc_defs_prologue.h" | #include "hcc_detail/hcc_defs_prologue.h" | ||||
@@ -11,6 +11,7 @@ | |||||
#pragma once | #pragma once | ||||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | #include "megbrain/opr/internal/megdnn_opr_wrapper.h" | ||||
#include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||||
#include "megbrain/utils/persistent_cache.h" | #include "megbrain/utils/persistent_cache.h" | ||||
#include "megbrain/opr/param_defs.h" | #include "megbrain/opr/param_defs.h" | ||||
#include "megdnn/oprs/nn.h" | #include "megdnn/oprs/nn.h" | ||||
@@ -19,68 +20,14 @@ namespace mgb { | |||||
namespace opr { | namespace opr { | ||||
namespace mixin { | namespace mixin { | ||||
/*! | |||||
* \brief Convolution base class | |||||
*/ | |||||
class Convolution { | |||||
public: | |||||
using ExecutionPolicy = megdnn::param::ExecutionPolicy; | |||||
using AlgorithmInfo = megdnn::detail::Algorithm::Info; | |||||
using AlgoChooserHook = | |||||
std::function<AlgorithmInfo(const OperatorNodeBase*)>; | |||||
const ExecutionPolicy& execution_policy() const { | |||||
if (!m_policy_accessed) { | |||||
m_policy_accessed = true; | |||||
} | |||||
return m_policy; | |||||
} | |||||
/*! | |||||
* \brief get current policy without marking it as having been accessed | |||||
* | |||||
* This is primarily used for getting current policy before calling | |||||
* set_execution_policy(). | |||||
*/ | |||||
const ExecutionPolicy& execution_policy_transient() const { | |||||
return m_policy; | |||||
} | |||||
/*! | |||||
* \brief modify execution policy | |||||
* | |||||
* Exception would be thrown if execution_policy() has been accessed, | |||||
* since it would influence cache and many other decisions. | |||||
*/ | |||||
void set_execution_policy(const ExecutionPolicy& policy); | |||||
AlgoChooserProfileCache& profile_cache() const; | |||||
virtual std::pair<const void*, size_t> param_blob() const = 0; | |||||
/*! | |||||
* \brief register a hook to implement custom algo chooser | |||||
*/ | |||||
void setup_algo_chooser(AlgoChooserHook&& func) { | |||||
m_algo_chooser = func; | |||||
} | |||||
AlgoChooserHook algo_chooser() const { | |||||
return m_algo_chooser; | |||||
} | |||||
protected: | |||||
~Convolution(); | |||||
mutable bool m_policy_accessed = false; | |||||
ExecutionPolicy m_policy; | |||||
AlgoChooserHook m_algo_chooser; | |||||
class ConvolutionBackwardDataMixin : public cg::OperatorNodeMixinBase { | |||||
protected: | |||||
//! init output desc for conv backward data oprs; it handles both grad | |||||
//! usage and deconv usage | |||||
template <class MgbOpr, class MegDNNOpr> | |||||
static void init_output_static_infer_desc_for_bwd_data( | |||||
cg::OperatorNodeBase* self); | |||||
//! init output desc for conv backward data oprs; it handles both grad | |||||
//! usage and deconv usage | |||||
template <class MgbOpr, class MegDNNOpr> | |||||
static void init_output_static_infer_desc_for_bwd_data( | |||||
cg::OperatorNodeBase* self); | |||||
}; | }; | ||||
class WeightPreprocessExecutor : public cg::OperatorNodeMixinBase { | class WeightPreprocessExecutor : public cg::OperatorNodeMixinBase { | ||||
@@ -153,7 +100,7 @@ class ConvolutionTestingPeer; | |||||
} // namespace testing | } // namespace testing | ||||
MGB_DEFINE_OPR_CLASS(ConvolutionForward, | MGB_DEFINE_OPR_CLASS(ConvolutionForward, | ||||
intl::ConvolutionForwardBase, public mixin::Convolution) // { | |||||
intl::ConvolutionForwardBase, public mixin::AlgoChooserHelper) // { | |||||
void init_output_dtype() override; | void init_output_dtype() override; | ||||
size_t get_workspace_size_bytes( | size_t get_workspace_size_bytes( | ||||
@@ -183,12 +130,11 @@ MGB_DEFINE_OPR_CLASS(ConvolutionForward, | |||||
const ExecutionPolicy &policy = {}, | const ExecutionPolicy &policy = {}, | ||||
const OperatorNodeConfig &config = {}); | const OperatorNodeConfig &config = {}); | ||||
std::pair<const void*, size_t> param_blob() const override; | |||||
}; | }; | ||||
using Convolution = ConvolutionForward; | using Convolution = ConvolutionForward; | ||||
MGB_DEFINE_OPR_CLASS(ConvBiasForward, intl::ConvBiasForwardBase, | MGB_DEFINE_OPR_CLASS(ConvBiasForward, intl::ConvBiasForwardBase, | ||||
public mixin::Convolution) // { | |||||
public mixin::AlgoChooserHelper) // { | |||||
void init_output_dtype() override; | void init_output_dtype() override; | ||||
size_t get_workspace_size_bytes( | size_t get_workspace_size_bytes( | ||||
@@ -240,7 +186,6 @@ public: | |||||
const ExecutionPolicy& policy = {}, | const ExecutionPolicy& policy = {}, | ||||
const OperatorNodeConfig& config = {}); | const OperatorNodeConfig& config = {}); | ||||
std::pair<const void*, size_t> param_blob() const override; | |||||
static void check_winograd_param_valid( | static void check_winograd_param_valid( | ||||
const megdnn::ConvBias::WinogradParam& param, | const megdnn::ConvBias::WinogradParam& param, | ||||
@@ -253,10 +198,12 @@ using ConvBias = ConvBiasForward; | |||||
/*! | /*! | ||||
* \brief Can be used in two ways: compute gradient of conv, or deconv | * \brief Can be used in two ways: compute gradient of conv, or deconv | ||||
*/ | */ | ||||
MGB_DEFINE_OPR_CLASS(ConvolutionBackwardData, | |||||
MGB_DEFINE_OPR_CLASS( | |||||
ConvolutionBackwardData, | |||||
cg::SingleCNOperatorNodeBaseT< | cg::SingleCNOperatorNodeBaseT< | ||||
mixin::MegDNNOprHolderImpl<megdnn::ConvolutionBackwardData>>, | |||||
public mixin::Convolution) // { | |||||
mixin::MegDNNOprHolderImpl<megdnn::ConvolutionBackwardData>>, | |||||
public mixin::AlgoChooserHelper, | |||||
public mixin::ConvolutionBackwardDataMixin) // { | |||||
void init_output_static_infer_desc() override; | void init_output_static_infer_desc() override; | ||||
void init_output_dtype() override; | void init_output_dtype() override; | ||||
void init_output_format() override; | void init_output_format() override; | ||||
@@ -296,12 +243,11 @@ MGB_DEFINE_OPR_CLASS(ConvolutionBackwardData, | |||||
return make(filter, data, param, policy, config); | return make(filter, data, param, policy, config); | ||||
} | } | ||||
std::pair<const void*, size_t> param_blob() const override; | |||||
}; | }; | ||||
MGB_DEFINE_OPR_CLASS(ConvolutionBackwardFilter, | MGB_DEFINE_OPR_CLASS(ConvolutionBackwardFilter, | ||||
intl::MegDNNOprWrapperBwd<megdnn::ConvolutionBackwardFilter>, | intl::MegDNNOprWrapperBwd<megdnn::ConvolutionBackwardFilter>, | ||||
public mixin::Convolution ) // { | |||||
public mixin::AlgoChooserHelper ) // { | |||||
size_t get_workspace_size_bytes( | size_t get_workspace_size_bytes( | ||||
@@ -318,7 +264,6 @@ MGB_DEFINE_OPR_CLASS(ConvolutionBackwardFilter, | |||||
const ExecutionPolicy &policy = {}, | const ExecutionPolicy &policy = {}, | ||||
const OperatorNodeConfig &config = {}); | const OperatorNodeConfig &config = {}); | ||||
std::pair<const void*, size_t> param_blob() const override; | |||||
}; | }; | ||||
MGB_DEFINE_OPR_CLASS(MaskConvolution, | MGB_DEFINE_OPR_CLASS(MaskConvolution, | ||||
@@ -350,7 +295,7 @@ public: | |||||
MGB_DEFINE_OPR_CLASS(Convolution3DForward, | MGB_DEFINE_OPR_CLASS(Convolution3DForward, | ||||
intl::MegDNNOprWrapperFwd<megdnn::Convolution3DForward>, | intl::MegDNNOprWrapperFwd<megdnn::Convolution3DForward>, | ||||
public mixin::Convolution) // { | |||||
public mixin::AlgoChooserHelper) // { | |||||
void init_output_dtype() override; | void init_output_dtype() override; | ||||
size_t get_workspace_size_bytes( | size_t get_workspace_size_bytes( | ||||
@@ -368,17 +313,18 @@ MGB_DEFINE_OPR_CLASS(Convolution3DForward, | |||||
const ExecutionPolicy &policy = {}, | const ExecutionPolicy &policy = {}, | ||||
const OperatorNodeConfig &config = {}); | const OperatorNodeConfig &config = {}); | ||||
std::pair<const void*, size_t> param_blob() const override; | |||||
}; | }; | ||||
using Convolution3D = Convolution3DForward; | using Convolution3D = Convolution3DForward; | ||||
/*! | /*! | ||||
* \brief Can be used in two ways: compute gradient of conv, or deconv | * \brief Can be used in two ways: compute gradient of conv, or deconv | ||||
*/ | */ | ||||
MGB_DEFINE_OPR_CLASS(Convolution3DBackwardData, | |||||
MGB_DEFINE_OPR_CLASS( | |||||
Convolution3DBackwardData, | |||||
cg::SingleCNOperatorNodeBaseT< | cg::SingleCNOperatorNodeBaseT< | ||||
mixin::MegDNNOprHolderImpl<megdnn::Convolution3DBackwardData>>, | |||||
public mixin::Convolution) // { | |||||
mixin::MegDNNOprHolderImpl<megdnn::Convolution3DBackwardData>>, | |||||
public mixin::AlgoChooserHelper, | |||||
public mixin::ConvolutionBackwardDataMixin) // { | |||||
void init_output_static_infer_desc() override; | void init_output_static_infer_desc() override; | ||||
void add_input_layout_constraint() override; | void add_input_layout_constraint() override; | ||||
@@ -416,12 +362,11 @@ MGB_DEFINE_OPR_CLASS(Convolution3DBackwardData, | |||||
return make(filter, data, param, policy, config); | return make(filter, data, param, policy, config); | ||||
} | } | ||||
std::pair<const void*, size_t> param_blob() const override; | |||||
}; | }; | ||||
MGB_DEFINE_OPR_CLASS(Convolution3DBackwardFilter, | MGB_DEFINE_OPR_CLASS(Convolution3DBackwardFilter, | ||||
intl::MegDNNOprWrapperBwd<megdnn::Convolution3DBackwardFilter>, | intl::MegDNNOprWrapperBwd<megdnn::Convolution3DBackwardFilter>, | ||||
public mixin::Convolution) // { | |||||
public mixin::AlgoChooserHelper) // { | |||||
size_t get_workspace_size_bytes( | size_t get_workspace_size_bytes( | ||||
const TensorShapeArray &input_shapes, | const TensorShapeArray &input_shapes, | ||||
@@ -437,12 +382,11 @@ MGB_DEFINE_OPR_CLASS(Convolution3DBackwardFilter, | |||||
const ExecutionPolicy &policy = {}, | const ExecutionPolicy &policy = {}, | ||||
const OperatorNodeConfig &config = {}); | const OperatorNodeConfig &config = {}); | ||||
std::pair<const void*, size_t> param_blob() const override; | |||||
}; | }; | ||||
MGB_DEFINE_OPR_CLASS(LocalShareForward, | MGB_DEFINE_OPR_CLASS(LocalShareForward, | ||||
intl::MegDNNOprWrapperFwd<megdnn::LocalShareForward>, | intl::MegDNNOprWrapperFwd<megdnn::LocalShareForward>, | ||||
public mixin::Convolution) // { | |||||
public mixin::AlgoChooserHelper) // { | |||||
void init_output_dtype() override; | void init_output_dtype() override; | ||||
void init_output_format() override; | void init_output_format() override; | ||||
@@ -457,7 +401,6 @@ public: | |||||
static SymbolVar make(SymbolVar src, SymbolVar filter, const Param& param = {}, | static SymbolVar make(SymbolVar src, SymbolVar filter, const Param& param = {}, | ||||
const ExecutionPolicy& policy = {}, | const ExecutionPolicy& policy = {}, | ||||
const OperatorNodeConfig& config = {}); | const OperatorNodeConfig& config = {}); | ||||
std::pair<const void*, size_t> param_blob() const override; | |||||
}; | }; | ||||
using LocalShare = LocalShareForward; | using LocalShare = LocalShareForward; | ||||
@@ -465,7 +408,8 @@ MGB_DEFINE_OPR_CLASS( | |||||
LocalShareBackwardData, | LocalShareBackwardData, | ||||
cg::SingleCNOperatorNodeBaseT< | cg::SingleCNOperatorNodeBaseT< | ||||
mixin::MegDNNOprHolderImpl<megdnn::LocalShareBackwardData>>, | mixin::MegDNNOprHolderImpl<megdnn::LocalShareBackwardData>>, | ||||
public mixin::Convolution) // { | |||||
public mixin::AlgoChooserHelper, | |||||
public mixin::ConvolutionBackwardDataMixin) // { | |||||
void init_output_static_infer_desc() override; | void init_output_static_infer_desc() override; | ||||
void init_output_dtype() override; | void init_output_dtype() override; | ||||
@@ -485,13 +429,12 @@ public: | |||||
const ExecutionPolicy& policy = {}, | const ExecutionPolicy& policy = {}, | ||||
const OperatorNodeConfig& config = {}); | const OperatorNodeConfig& config = {}); | ||||
std::pair<const void*, size_t> param_blob() const override; | |||||
}; | }; | ||||
MGB_DEFINE_OPR_CLASS( | MGB_DEFINE_OPR_CLASS( | ||||
LocalShareBackwardFilter, | LocalShareBackwardFilter, | ||||
intl::MegDNNOprWrapperBwd<megdnn::LocalShareBackwardFilter>, | intl::MegDNNOprWrapperBwd<megdnn::LocalShareBackwardFilter>, | ||||
public mixin::Convolution) // { | |||||
public mixin::AlgoChooserHelper) // { | |||||
size_t get_workspace_size_bytes( | size_t get_workspace_size_bytes( | ||||
const TensorShapeArray& input_shapes, | const TensorShapeArray& input_shapes, | ||||
@@ -506,12 +449,11 @@ public: | |||||
const ExecutionPolicy& policy = {}, | const ExecutionPolicy& policy = {}, | ||||
const OperatorNodeConfig& config = {}); | const OperatorNodeConfig& config = {}); | ||||
std::pair<const void*, size_t> param_blob() const override; | |||||
}; | }; | ||||
MGB_DEFINE_OPR_CLASS(DeformableConvForward, | MGB_DEFINE_OPR_CLASS(DeformableConvForward, | ||||
intl::MegDNNOprWrapperFwd<megdnn::DeformableConvForward>, | intl::MegDNNOprWrapperFwd<megdnn::DeformableConvForward>, | ||||
public mixin::Convolution) // { | |||||
public mixin::AlgoChooserHelper) // { | |||||
public: | public: | ||||
DeformableConvForward( | DeformableConvForward( | ||||
VarNode *src, VarNode *filter, VarNode *offset, VarNode *mask, | VarNode *src, VarNode *filter, VarNode *offset, VarNode *mask, | ||||
@@ -525,7 +467,6 @@ MGB_DEFINE_OPR_CLASS(DeformableConvForward, | |||||
const ExecutionPolicy &policy = {}, | const ExecutionPolicy &policy = {}, | ||||
const OperatorNodeConfig &config = {}); | const OperatorNodeConfig &config = {}); | ||||
std::pair<const void*, size_t> param_blob() const override; | |||||
private: | private: | ||||
void init_output_dtype() override; | void init_output_dtype() override; | ||||
void init_output_format() override; | void init_output_format() override; | ||||
@@ -537,7 +478,8 @@ using DeformableConv = DeformableConvForward; | |||||
MGB_DEFINE_OPR_CLASS(DeformableConvBackwardData, | MGB_DEFINE_OPR_CLASS(DeformableConvBackwardData, | ||||
intl::DeformableConvBackwardDataBase, | intl::DeformableConvBackwardDataBase, | ||||
public mixin::Convolution) // { | |||||
public mixin::AlgoChooserHelper, | |||||
public mixin::ConvolutionBackwardDataMixin) // { | |||||
public: | public: | ||||
DeformableConvBackwardData( | DeformableConvBackwardData( | ||||
VarNode * src, VarNode * filter, VarNode * offset, VarNode * mask, | VarNode * src, VarNode * filter, VarNode * offset, VarNode * mask, | ||||
@@ -557,7 +499,6 @@ public: | |||||
const OperatorNodeConfig& config = {}); | const OperatorNodeConfig& config = {}); | ||||
void scn_do_execute() override; | void scn_do_execute() override; | ||||
std::pair<const void*, size_t> param_blob() const override; | |||||
private: | private: | ||||
void get_output_var_shape(const TensorShapeArray& inp_shape, | void get_output_var_shape(const TensorShapeArray& inp_shape, | ||||
@@ -578,7 +519,7 @@ private: | |||||
MGB_DEFINE_OPR_CLASS( | MGB_DEFINE_OPR_CLASS( | ||||
DeformableConvBackwardFilter, | DeformableConvBackwardFilter, | ||||
intl::MegDNNOprWrapperBwd<megdnn::DeformableConvBackwardFilter>, | intl::MegDNNOprWrapperBwd<megdnn::DeformableConvBackwardFilter>, | ||||
public mixin::Convolution) // { | |||||
public mixin::AlgoChooserHelper) // { | |||||
public: | public: | ||||
DeformableConvBackwardFilter( | DeformableConvBackwardFilter( | ||||
VarNode * src, VarNode * filter, VarNode * offset, VarNode * mask, | VarNode * src, VarNode * filter, VarNode * offset, VarNode * mask, | ||||
@@ -592,7 +533,6 @@ public: | |||||
const OperatorNodeConfig& config = {}); | const OperatorNodeConfig& config = {}); | ||||
void scn_do_execute() override; | void scn_do_execute() override; | ||||
std::pair<const void*, size_t> param_blob() const override; | |||||
private: | private: | ||||
size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes, | size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes, | ||||
@@ -601,7 +541,7 @@ private: | |||||
}; | }; | ||||
MGB_DEFINE_OPR_CLASS(BatchConvBiasForward, intl::BatchConvBiasForwardBase, | MGB_DEFINE_OPR_CLASS(BatchConvBiasForward, intl::BatchConvBiasForwardBase, | ||||
public mixin::Convolution) // { | |||||
public mixin::AlgoChooserHelper) // { | |||||
void init_output_dtype() override; | void init_output_dtype() override; | ||||
size_t get_workspace_size_bytes( | size_t get_workspace_size_bytes( | ||||
@@ -650,7 +590,6 @@ public: | |||||
const ExecutionPolicy& policy = {}, | const ExecutionPolicy& policy = {}, | ||||
const OperatorNodeConfig& config = {}); | const OperatorNodeConfig& config = {}); | ||||
std::pair<const void*, size_t> param_blob() const override; | |||||
}; | }; | ||||
using BatchConvBias = BatchConvBiasForward; | using BatchConvBias = BatchConvBiasForward; | ||||
@@ -13,6 +13,7 @@ | |||||
#pragma once | #pragma once | ||||
#include "megbrain/opr/search_policy/profiler.h" | #include "megbrain/opr/search_policy/profiler.h" | ||||
#include "megbrain/opr/dnn/convolution.h" | |||||
template <class MegDNNOpr> | template <class MegDNNOpr> | ||||
struct MegDNNOpr2MGBOpr; | struct MegDNNOpr2MGBOpr; | ||||
@@ -0,0 +1,80 @@ | |||||
/** | |||||
* \file src/opr/include/megbrain/opr/search_policy/algo_chooser_helper.h | |||||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
* | |||||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, | |||||
* software distributed under the License is distributed on an | |||||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||||
* implied. | |||||
*/ | |||||
#pragma once | |||||
#include "megbrain/graph/operator_node.h" | |||||
#include "megbrain/opr/param_defs.h" | |||||
#include "megdnn/oprs/base.h" | |||||
#include "megdnn/oprs/nn.h" | |||||
namespace mgb { | |||||
namespace opr { | |||||
namespace mixin { | |||||
/*! | |||||
* \brief base class for the opr which can be tuning | |||||
*/ | |||||
class AlgoChooserHelper : cg::OperatorNodeMixinBase { | |||||
public: | |||||
using ExecutionPolicy = megdnn::param::ExecutionPolicy; | |||||
using AlgorithmInfo = megdnn::detail::Algorithm::Info; | |||||
using AlgoChooserHook = | |||||
std::function<AlgorithmInfo(const cg::OperatorNodeBase*)>; | |||||
const ExecutionPolicy& execution_policy() const { | |||||
if (!m_policy_accessed) { | |||||
m_policy_accessed = true; | |||||
} | |||||
return m_policy; | |||||
} | |||||
/*! | |||||
* \brief get current policy without marking it as having been accessed | |||||
* | |||||
* This is primarily used for getting current policy before calling | |||||
* set_execution_policy(). | |||||
*/ | |||||
const ExecutionPolicy& execution_policy_transient() const { | |||||
return m_policy; | |||||
} | |||||
/*! | |||||
* \brief modify execution policy | |||||
* | |||||
* Exception would be thrown if execution_policy() has been accessed, | |||||
* since it would influence cache and many other decisions. | |||||
*/ | |||||
void set_execution_policy(const ExecutionPolicy& policy); | |||||
/*! | |||||
* \brief register a hook to implement custom algo chooser | |||||
*/ | |||||
void setup_algo_chooser(AlgoChooserHook&& func) { m_algo_chooser = func; } | |||||
AlgoChooserHook algo_chooser() const { return m_algo_chooser; } | |||||
protected: | |||||
~AlgoChooserHelper(); | |||||
mutable bool m_policy_accessed = false; | |||||
ExecutionPolicy m_policy; | |||||
AlgoChooserHook m_algo_chooser; | |||||
}; | |||||
} // namespace mixin | |||||
} // namespace opr | |||||
} // namespace mgb | |||||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -12,9 +12,10 @@ | |||||
#pragma once | #pragma once | ||||
#include "megbrain/opr/dnn/convolution.h" | |||||
#include "megbrain/utils/hash_ct.h" | #include "megbrain/utils/hash_ct.h" | ||||
#include "megbrain/utils/timer.h" | #include "megbrain/utils/timer.h" | ||||
#include "megbrain/system.h" | |||||
#include "megbrain/comp_node.h" | |||||
#include "megdnn/basic_types.h" | #include "megdnn/basic_types.h" | ||||
#include "megdnn/oprs/nn.h" | #include "megdnn/oprs/nn.h" | ||||
@@ -127,15 +128,15 @@ class TimedProfiler { | |||||
static constexpr int arity_out = OprArityTrait<Opr>::arity_out; | static constexpr int arity_out = OprArityTrait<Opr>::arity_out; | ||||
static constexpr int arity = OprArityTrait<Opr>::arity; | static constexpr int arity = OprArityTrait<Opr>::arity; | ||||
using ConvTensorShapes = std::array<TensorShape, arity>; | |||||
using TensorShapeArray = std::array<megdnn::TensorShape, arity>; | |||||
public: | public: | ||||
struct Param { | struct Param { | ||||
char algo_name[128]; | char algo_name[128]; | ||||
size_t workspace; | size_t workspace; | ||||
DTypeEnum dtypes[arity]; | |||||
megdnn::DTypeEnum dtypes[arity]; | |||||
CompNode::Locator comp_node_loc; | CompNode::Locator comp_node_loc; | ||||
ConvTensorShapes shapes; | |||||
TensorShapeArray shapes; | |||||
typename Opr::Param opr_param; | typename Opr::Param opr_param; | ||||
bool allow_weight_preprocess; | bool allow_weight_preprocess; | ||||