GitOrigin-RevId: 81e32da034
release-1.2
@@ -19,6 +19,7 @@ | |||
#include "megbrain/graph/extern_copr_api.h" | |||
#include "megbrain/opr/dnn/convolution.h" | |||
#include "megbrain/opr/io.h" | |||
#include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||
#include "megbrain/opr/utility.h" | |||
#include "megbrain/plugin/cpu_dispatch_checker.h" | |||
#include "megbrain/plugin/num_range_checker.h" | |||
@@ -691,7 +692,7 @@ void run_test_st(Args &env) { | |||
} | |||
mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, env.workspace_limit); | |||
using S = opr::mixin::Convolution::ExecutionPolicy::Strategy; | |||
using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | |||
S strategy = S::HEURISTIC; | |||
#if MGB_ENABLE_FASTRUN | |||
if (env.use_fast_run) { | |||
@@ -15,6 +15,7 @@ | |||
#include "megbrain/graph/event.h" | |||
#include "megbrain/opr/dnn/batch_norm.h" | |||
#include "megbrain/opr/dnn/local.h" | |||
#include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||
#include "megbrain/utils/shared_set.h" | |||
#include "megbrain/serialization/opr_shallow_copy.h" | |||
#include "megbrain/opr/basic_arith.h" | |||
@@ -116,8 +117,8 @@ SymbolVarArray gopt::optimize_for_inference( | |||
namespace { | |||
void modify_conv_strategy( | |||
opr::mixin::Convolution& conv, | |||
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) { | |||
opr::mixin::AlgoChooserHelper& conv, | |||
opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) { | |||
auto policy = conv.execution_policy_transient(); | |||
policy.strategy = strategy; | |||
conv.set_execution_policy(policy); | |||
@@ -126,13 +127,13 @@ void modify_conv_strategy( | |||
template <typename Opr> | |||
void inplace_conv_opr_modifier( | |||
OperatorNodeBase& opr, | |||
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) { | |||
opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) { | |||
modify_conv_strategy( | |||
opr.cast_final_safe<Opr>(), | |||
strategy); | |||
} | |||
void modify_conv_policy_workspace_limit(opr::mixin::Convolution& conv, | |||
void modify_conv_policy_workspace_limit(opr::mixin::AlgoChooserHelper& conv, | |||
size_t workspace_limit) { | |||
auto policy = conv.execution_policy_transient(); | |||
policy.workspace_limit = workspace_limit; | |||
@@ -159,9 +160,9 @@ void inplace_conv_opr_workspace_limit_modifier(OperatorNodeBase& opr, | |||
void gopt::modify_opr_algo_strategy_inplace( | |||
const VarNodeArrayView& dest_vars, | |||
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) { | |||
opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) { | |||
#if !MGB_ENABLE_FASTRUN | |||
using S = opr::mixin::Convolution::ExecutionPolicy::Strategy; | |||
using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | |||
if (strategy == S::PROFILE || strategy == S::PROFILE_REPRODUCIBLE) { | |||
mgb_throw(MegBrainError, "fastrun is disabled at compile time"); | |||
} | |||
@@ -190,16 +191,16 @@ void gopt::modify_opr_algo_strategy_inplace( | |||
void gopt::enable_opr_algo_profiling_inplace( | |||
const VarNodeArrayView& dest_vars) { | |||
modify_opr_algo_strategy_inplace(dest_vars, | |||
opr::mixin::Convolution::ExecutionPolicy:: | |||
Strategy::PROFILE); | |||
modify_opr_algo_strategy_inplace( | |||
dest_vars, | |||
opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy::PROFILE); | |||
} | |||
void gopt::enable_opr_use_profiling_cache_inplace( | |||
const VarNodeArrayView& dest_vars) { | |||
modify_opr_algo_strategy_inplace(dest_vars, | |||
opr::mixin::Convolution::ExecutionPolicy:: | |||
Strategy::PROFILE_HEURISTIC); | |||
modify_opr_algo_strategy_inplace( | |||
dest_vars, opr::mixin::AlgoChooserHelper::ExecutionPolicy:: | |||
Strategy::PROFILE_HEURISTIC); | |||
} | |||
@@ -14,6 +14,7 @@ | |||
#include "megbrain/gopt/framework.h" | |||
#include "megbrain/graph/cg.h" | |||
#include "megbrain/opr/dnn/convolution.h" | |||
#include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||
namespace mgb { | |||
namespace gopt { | |||
@@ -342,7 +343,7 @@ namespace gopt { | |||
*/ | |||
void modify_opr_algo_strategy_inplace( | |||
const VarNodeArrayView& dest_vars, | |||
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy); | |||
opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy); | |||
/*! | |||
* \brief enable PROFILE execution strategy for oprs with multiple | |||
@@ -13,7 +13,7 @@ | |||
#include "megbrain/opr/dnn/convolution.h" | |||
#include "megbrain/opr/io.h" | |||
#include "megbrain/opr/search_policy/algo_chooser.h" | |||
#include "megbrain/opr/search_policy/profiler.h" | |||
#include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||
#include "megbrain/graph/grad_impl.h" | |||
#include "megbrain/system.h" | |||
@@ -38,18 +38,9 @@ using intl::WorkspaceLimitGetter; | |||
/* ==================== misc impl ==================== */ | |||
mixin::Convolution::~Convolution() = default; | |||
void mixin::Convolution::set_execution_policy(const ExecutionPolicy& policy) { | |||
mgb_throw_if( | |||
m_policy_accessed, InternalError, | |||
"attempt to modify ExecutionPolicy after it has been accessed"); | |||
m_policy = policy; | |||
} | |||
template <class MgbOpr, class MegDNNOpr> | |||
void mixin::Convolution::init_output_static_infer_desc_for_bwd_data( | |||
cg::OperatorNodeBase* self) { | |||
void mixin::ConvolutionBackwardDataMixin:: | |||
init_output_static_infer_desc_for_bwd_data(cg::OperatorNodeBase* self) { | |||
using namespace cg::static_infer; | |||
auto&& mgr = self->owner_graph()->static_infer_manager(); | |||
@@ -93,7 +84,7 @@ void mixin::Convolution::init_output_static_infer_desc_for_bwd_data( | |||
}; | |||
inp_deps.push_back({self->output(0), DepType::SHAPE}); | |||
auto workspace_dep_var = | |||
WorkspaceLimitGetter::register_to_graph(self->owner_graph()); | |||
intl::WorkspaceLimitGetter::register_to_graph(self->owner_graph()); | |||
if (workspace_dep_var) { | |||
inp_deps.push_back({workspace_dep_var, DepType::VALUE}); | |||
} | |||
@@ -101,11 +92,7 @@ void mixin::Convolution::init_output_static_infer_desc_for_bwd_data( | |||
{SourceType::DEP, inp_deps, infer_wk}); | |||
} | |||
#define IMPL_CONV(_cls) \ | |||
std::pair<const void*, size_t> _cls::param_blob() const { \ | |||
return {¶m(), sizeof(Param)}; \ | |||
} \ | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(_cls) | |||
#define IMPL_CONV(_cls) MGB_DYN_TYPE_OBJ_FINAL_IMPL(_cls) | |||
class mixin::WeightPreprocessExecutor::PreprocessedFilterExecDep final | |||
: public cg::GraphExecutable::ExecDependency { | |||
@@ -11,6 +11,7 @@ | |||
*/ | |||
#include "megbrain/opr/search_policy/algo_chooser.h" | |||
#include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||
#include "megbrain/opr/search_policy/profiler.h" | |||
#include "../internal/invoke.h" | |||
@@ -200,7 +201,7 @@ size_t AlgoChooser<Opr>::setup_algo(const TensorLayoutArray& layouts, | |||
template <typename Opr> | |||
typename AlgoChooser<Opr>::ImplAlgo AlgoChooser<Opr>::get_algo( | |||
ExeContext& ctx) { | |||
using S = mixin::Convolution::ExecutionPolicy::Strategy; | |||
using S = mixin::AlgoChooserHelper::ExecutionPolicy::Strategy; | |||
MGB_MARK_USED_VAR(TIMEOUT_TOLERANCE); | |||
switch (ctx.mgb_opr()->execution_policy().strategy) { | |||
case S::HEURISTIC: | |||
@@ -0,0 +1,33 @@ | |||
/** | |||
* \file src/opr/impl/search_policy/algo_chooser_helper.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||
#include "megbrain/opr/search_policy/algo_chooser.h" | |||
#include "megbrain/graph/cg.h" | |||
#include "../internal/megdnn_opr_wrapper.inl" | |||
using namespace mgb; | |||
using namespace opr; | |||
using namespace mixin; | |||
/* ==================== misc impl ==================== */ | |||
AlgoChooserHelper::~AlgoChooserHelper() = default; | |||
void AlgoChooserHelper::set_execution_policy(const ExecutionPolicy& policy) { | |||
mgb_throw_if( | |||
m_policy_accessed, InternalError, | |||
"attempt to modify ExecutionPolicy after it has been accessed"); | |||
m_policy = policy; | |||
} | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -13,6 +13,7 @@ | |||
#include "megbrain/opr/search_policy/profiler.h" | |||
#include "../internal/invoke.h" | |||
#include "../internal/megdnn_opr_wrapper.inl" | |||
#if MGB_ROCM | |||
#include "hcc_detail/hcc_defs_prologue.h" | |||
@@ -11,6 +11,7 @@ | |||
#pragma once | |||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
#include "megbrain/opr/search_policy/algo_chooser_helper.h" | |||
#include "megbrain/utils/persistent_cache.h" | |||
#include "megbrain/opr/param_defs.h" | |||
#include "megdnn/oprs/nn.h" | |||
@@ -19,68 +20,14 @@ namespace mgb { | |||
namespace opr { | |||
namespace mixin { | |||
/*! | |||
* \brief Convolution base class | |||
*/ | |||
class Convolution { | |||
public: | |||
using ExecutionPolicy = megdnn::param::ExecutionPolicy; | |||
using AlgorithmInfo = megdnn::detail::Algorithm::Info; | |||
using AlgoChooserHook = | |||
std::function<AlgorithmInfo(const OperatorNodeBase*)>; | |||
const ExecutionPolicy& execution_policy() const { | |||
if (!m_policy_accessed) { | |||
m_policy_accessed = true; | |||
} | |||
return m_policy; | |||
} | |||
/*! | |||
* \brief get current policy without marking it as having been accessed | |||
* | |||
* This is primarily used for getting current policy before calling | |||
* set_execution_policy(). | |||
*/ | |||
const ExecutionPolicy& execution_policy_transient() const { | |||
return m_policy; | |||
} | |||
/*! | |||
* \brief modify execution policy | |||
* | |||
* Exception would be thrown if execution_policy() has been accessed, | |||
* since it would influence cache and many other decisions. | |||
*/ | |||
void set_execution_policy(const ExecutionPolicy& policy); | |||
AlgoChooserProfileCache& profile_cache() const; | |||
virtual std::pair<const void*, size_t> param_blob() const = 0; | |||
/*! | |||
* \brief register a hook to implement custom algo chooser | |||
*/ | |||
void setup_algo_chooser(AlgoChooserHook&& func) { | |||
m_algo_chooser = func; | |||
} | |||
AlgoChooserHook algo_chooser() const { | |||
return m_algo_chooser; | |||
} | |||
protected: | |||
~Convolution(); | |||
mutable bool m_policy_accessed = false; | |||
ExecutionPolicy m_policy; | |||
AlgoChooserHook m_algo_chooser; | |||
class ConvolutionBackwardDataMixin : public cg::OperatorNodeMixinBase { | |||
protected: | |||
//! init output desc for conv backward data oprs; it handles both grad | |||
//! usage and deconv usage | |||
template <class MgbOpr, class MegDNNOpr> | |||
static void init_output_static_infer_desc_for_bwd_data( | |||
cg::OperatorNodeBase* self); | |||
//! init output desc for conv backward data oprs; it handles both grad | |||
//! usage and deconv usage | |||
template <class MgbOpr, class MegDNNOpr> | |||
static void init_output_static_infer_desc_for_bwd_data( | |||
cg::OperatorNodeBase* self); | |||
}; | |||
class WeightPreprocessExecutor : public cg::OperatorNodeMixinBase { | |||
@@ -153,7 +100,7 @@ class ConvolutionTestingPeer; | |||
} // namespace testing | |||
MGB_DEFINE_OPR_CLASS(ConvolutionForward, | |||
intl::ConvolutionForwardBase, public mixin::Convolution) // { | |||
intl::ConvolutionForwardBase, public mixin::AlgoChooserHelper) // { | |||
void init_output_dtype() override; | |||
size_t get_workspace_size_bytes( | |||
@@ -183,12 +130,11 @@ MGB_DEFINE_OPR_CLASS(ConvolutionForward, | |||
const ExecutionPolicy &policy = {}, | |||
const OperatorNodeConfig &config = {}); | |||
std::pair<const void*, size_t> param_blob() const override; | |||
}; | |||
using Convolution = ConvolutionForward; | |||
MGB_DEFINE_OPR_CLASS(ConvBiasForward, intl::ConvBiasForwardBase, | |||
public mixin::Convolution) // { | |||
public mixin::AlgoChooserHelper) // { | |||
void init_output_dtype() override; | |||
size_t get_workspace_size_bytes( | |||
@@ -240,7 +186,6 @@ public: | |||
const ExecutionPolicy& policy = {}, | |||
const OperatorNodeConfig& config = {}); | |||
std::pair<const void*, size_t> param_blob() const override; | |||
static void check_winograd_param_valid( | |||
const megdnn::ConvBias::WinogradParam& param, | |||
@@ -253,10 +198,12 @@ using ConvBias = ConvBiasForward; | |||
/*! | |||
* \brief Can be used in two ways: compute gradient of conv, or deconv | |||
*/ | |||
MGB_DEFINE_OPR_CLASS(ConvolutionBackwardData, | |||
MGB_DEFINE_OPR_CLASS( | |||
ConvolutionBackwardData, | |||
cg::SingleCNOperatorNodeBaseT< | |||
mixin::MegDNNOprHolderImpl<megdnn::ConvolutionBackwardData>>, | |||
public mixin::Convolution) // { | |||
mixin::MegDNNOprHolderImpl<megdnn::ConvolutionBackwardData>>, | |||
public mixin::AlgoChooserHelper, | |||
public mixin::ConvolutionBackwardDataMixin) // { | |||
void init_output_static_infer_desc() override; | |||
void init_output_dtype() override; | |||
void init_output_format() override; | |||
@@ -296,12 +243,11 @@ MGB_DEFINE_OPR_CLASS(ConvolutionBackwardData, | |||
return make(filter, data, param, policy, config); | |||
} | |||
std::pair<const void*, size_t> param_blob() const override; | |||
}; | |||
MGB_DEFINE_OPR_CLASS(ConvolutionBackwardFilter, | |||
intl::MegDNNOprWrapperBwd<megdnn::ConvolutionBackwardFilter>, | |||
public mixin::Convolution ) // { | |||
public mixin::AlgoChooserHelper ) // { | |||
size_t get_workspace_size_bytes( | |||
@@ -318,7 +264,6 @@ MGB_DEFINE_OPR_CLASS(ConvolutionBackwardFilter, | |||
const ExecutionPolicy &policy = {}, | |||
const OperatorNodeConfig &config = {}); | |||
std::pair<const void*, size_t> param_blob() const override; | |||
}; | |||
MGB_DEFINE_OPR_CLASS(MaskConvolution, | |||
@@ -350,7 +295,7 @@ public: | |||
MGB_DEFINE_OPR_CLASS(Convolution3DForward, | |||
intl::MegDNNOprWrapperFwd<megdnn::Convolution3DForward>, | |||
public mixin::Convolution) // { | |||
public mixin::AlgoChooserHelper) // { | |||
void init_output_dtype() override; | |||
size_t get_workspace_size_bytes( | |||
@@ -368,17 +313,18 @@ MGB_DEFINE_OPR_CLASS(Convolution3DForward, | |||
const ExecutionPolicy &policy = {}, | |||
const OperatorNodeConfig &config = {}); | |||
std::pair<const void*, size_t> param_blob() const override; | |||
}; | |||
using Convolution3D = Convolution3DForward; | |||
/*! | |||
* \brief Can be used in two ways: compute gradient of conv, or deconv | |||
*/ | |||
MGB_DEFINE_OPR_CLASS(Convolution3DBackwardData, | |||
MGB_DEFINE_OPR_CLASS( | |||
Convolution3DBackwardData, | |||
cg::SingleCNOperatorNodeBaseT< | |||
mixin::MegDNNOprHolderImpl<megdnn::Convolution3DBackwardData>>, | |||
public mixin::Convolution) // { | |||
mixin::MegDNNOprHolderImpl<megdnn::Convolution3DBackwardData>>, | |||
public mixin::AlgoChooserHelper, | |||
public mixin::ConvolutionBackwardDataMixin) // { | |||
void init_output_static_infer_desc() override; | |||
void add_input_layout_constraint() override; | |||
@@ -416,12 +362,11 @@ MGB_DEFINE_OPR_CLASS(Convolution3DBackwardData, | |||
return make(filter, data, param, policy, config); | |||
} | |||
std::pair<const void*, size_t> param_blob() const override; | |||
}; | |||
MGB_DEFINE_OPR_CLASS(Convolution3DBackwardFilter, | |||
intl::MegDNNOprWrapperBwd<megdnn::Convolution3DBackwardFilter>, | |||
public mixin::Convolution) // { | |||
public mixin::AlgoChooserHelper) // { | |||
size_t get_workspace_size_bytes( | |||
const TensorShapeArray &input_shapes, | |||
@@ -437,12 +382,11 @@ MGB_DEFINE_OPR_CLASS(Convolution3DBackwardFilter, | |||
const ExecutionPolicy &policy = {}, | |||
const OperatorNodeConfig &config = {}); | |||
std::pair<const void*, size_t> param_blob() const override; | |||
}; | |||
MGB_DEFINE_OPR_CLASS(LocalShareForward, | |||
intl::MegDNNOprWrapperFwd<megdnn::LocalShareForward>, | |||
public mixin::Convolution) // { | |||
public mixin::AlgoChooserHelper) // { | |||
void init_output_dtype() override; | |||
void init_output_format() override; | |||
@@ -457,7 +401,6 @@ public: | |||
static SymbolVar make(SymbolVar src, SymbolVar filter, const Param& param = {}, | |||
const ExecutionPolicy& policy = {}, | |||
const OperatorNodeConfig& config = {}); | |||
std::pair<const void*, size_t> param_blob() const override; | |||
}; | |||
using LocalShare = LocalShareForward; | |||
@@ -465,7 +408,8 @@ MGB_DEFINE_OPR_CLASS( | |||
LocalShareBackwardData, | |||
cg::SingleCNOperatorNodeBaseT< | |||
mixin::MegDNNOprHolderImpl<megdnn::LocalShareBackwardData>>, | |||
public mixin::Convolution) // { | |||
public mixin::AlgoChooserHelper, | |||
public mixin::ConvolutionBackwardDataMixin) // { | |||
void init_output_static_infer_desc() override; | |||
void init_output_dtype() override; | |||
@@ -485,13 +429,12 @@ public: | |||
const ExecutionPolicy& policy = {}, | |||
const OperatorNodeConfig& config = {}); | |||
std::pair<const void*, size_t> param_blob() const override; | |||
}; | |||
MGB_DEFINE_OPR_CLASS( | |||
LocalShareBackwardFilter, | |||
intl::MegDNNOprWrapperBwd<megdnn::LocalShareBackwardFilter>, | |||
public mixin::Convolution) // { | |||
public mixin::AlgoChooserHelper) // { | |||
size_t get_workspace_size_bytes( | |||
const TensorShapeArray& input_shapes, | |||
@@ -506,12 +449,11 @@ public: | |||
const ExecutionPolicy& policy = {}, | |||
const OperatorNodeConfig& config = {}); | |||
std::pair<const void*, size_t> param_blob() const override; | |||
}; | |||
MGB_DEFINE_OPR_CLASS(DeformableConvForward, | |||
intl::MegDNNOprWrapperFwd<megdnn::DeformableConvForward>, | |||
public mixin::Convolution) // { | |||
public mixin::AlgoChooserHelper) // { | |||
public: | |||
DeformableConvForward( | |||
VarNode *src, VarNode *filter, VarNode *offset, VarNode *mask, | |||
@@ -525,7 +467,6 @@ MGB_DEFINE_OPR_CLASS(DeformableConvForward, | |||
const ExecutionPolicy &policy = {}, | |||
const OperatorNodeConfig &config = {}); | |||
std::pair<const void*, size_t> param_blob() const override; | |||
private: | |||
void init_output_dtype() override; | |||
void init_output_format() override; | |||
@@ -537,7 +478,8 @@ using DeformableConv = DeformableConvForward; | |||
MGB_DEFINE_OPR_CLASS(DeformableConvBackwardData, | |||
intl::DeformableConvBackwardDataBase, | |||
public mixin::Convolution) // { | |||
public mixin::AlgoChooserHelper, | |||
public mixin::ConvolutionBackwardDataMixin) // { | |||
public: | |||
DeformableConvBackwardData( | |||
VarNode * src, VarNode * filter, VarNode * offset, VarNode * mask, | |||
@@ -557,7 +499,6 @@ public: | |||
const OperatorNodeConfig& config = {}); | |||
void scn_do_execute() override; | |||
std::pair<const void*, size_t> param_blob() const override; | |||
private: | |||
void get_output_var_shape(const TensorShapeArray& inp_shape, | |||
@@ -578,7 +519,7 @@ private: | |||
MGB_DEFINE_OPR_CLASS( | |||
DeformableConvBackwardFilter, | |||
intl::MegDNNOprWrapperBwd<megdnn::DeformableConvBackwardFilter>, | |||
public mixin::Convolution) // { | |||
public mixin::AlgoChooserHelper) // { | |||
public: | |||
DeformableConvBackwardFilter( | |||
VarNode * src, VarNode * filter, VarNode * offset, VarNode * mask, | |||
@@ -592,7 +533,6 @@ public: | |||
const OperatorNodeConfig& config = {}); | |||
void scn_do_execute() override; | |||
std::pair<const void*, size_t> param_blob() const override; | |||
private: | |||
size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes, | |||
@@ -601,7 +541,7 @@ private: | |||
}; | |||
MGB_DEFINE_OPR_CLASS(BatchConvBiasForward, intl::BatchConvBiasForwardBase, | |||
public mixin::Convolution) // { | |||
public mixin::AlgoChooserHelper) // { | |||
void init_output_dtype() override; | |||
size_t get_workspace_size_bytes( | |||
@@ -650,7 +590,6 @@ public: | |||
const ExecutionPolicy& policy = {}, | |||
const OperatorNodeConfig& config = {}); | |||
std::pair<const void*, size_t> param_blob() const override; | |||
}; | |||
using BatchConvBias = BatchConvBiasForward; | |||
@@ -13,6 +13,7 @@ | |||
#pragma once | |||
#include "megbrain/opr/search_policy/profiler.h" | |||
#include "megbrain/opr/dnn/convolution.h" | |||
template <class MegDNNOpr> | |||
struct MegDNNOpr2MGBOpr; | |||
@@ -0,0 +1,80 @@ | |||
/** | |||
* \file src/opr/include/megbrain/opr/search_policy/algo_chooser_helper.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or | |||
* implied. | |||
*/ | |||
#pragma once | |||
#include "megbrain/graph/operator_node.h" | |||
#include "megbrain/opr/param_defs.h" | |||
#include "megdnn/oprs/base.h" | |||
#include "megdnn/oprs/nn.h" | |||
namespace mgb { | |||
namespace opr { | |||
namespace mixin { | |||
/*! | |||
* \brief base class for the opr which can be tuning | |||
*/ | |||
class AlgoChooserHelper : cg::OperatorNodeMixinBase { | |||
public: | |||
using ExecutionPolicy = megdnn::param::ExecutionPolicy; | |||
using AlgorithmInfo = megdnn::detail::Algorithm::Info; | |||
using AlgoChooserHook = | |||
std::function<AlgorithmInfo(const cg::OperatorNodeBase*)>; | |||
const ExecutionPolicy& execution_policy() const { | |||
if (!m_policy_accessed) { | |||
m_policy_accessed = true; | |||
} | |||
return m_policy; | |||
} | |||
/*! | |||
* \brief get current policy without marking it as having been accessed | |||
* | |||
* This is primarily used for getting current policy before calling | |||
* set_execution_policy(). | |||
*/ | |||
const ExecutionPolicy& execution_policy_transient() const { | |||
return m_policy; | |||
} | |||
/*! | |||
* \brief modify execution policy | |||
* | |||
* Exception would be thrown if execution_policy() has been accessed, | |||
* since it would influence cache and many other decisions. | |||
*/ | |||
void set_execution_policy(const ExecutionPolicy& policy); | |||
/*! | |||
* \brief register a hook to implement custom algo chooser | |||
*/ | |||
void setup_algo_chooser(AlgoChooserHook&& func) { m_algo_chooser = func; } | |||
AlgoChooserHook algo_chooser() const { return m_algo_chooser; } | |||
protected: | |||
~AlgoChooserHelper(); | |||
mutable bool m_policy_accessed = false; | |||
ExecutionPolicy m_policy; | |||
AlgoChooserHook m_algo_chooser; | |||
}; | |||
} // namespace mixin | |||
} // namespace opr | |||
} // namespace mgb | |||
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} |
@@ -12,9 +12,10 @@ | |||
#pragma once | |||
#include "megbrain/opr/dnn/convolution.h" | |||
#include "megbrain/utils/hash_ct.h" | |||
#include "megbrain/utils/timer.h" | |||
#include "megbrain/system.h" | |||
#include "megbrain/comp_node.h" | |||
#include "megdnn/basic_types.h" | |||
#include "megdnn/oprs/nn.h" | |||
@@ -127,15 +128,15 @@ class TimedProfiler { | |||
static constexpr int arity_out = OprArityTrait<Opr>::arity_out; | |||
static constexpr int arity = OprArityTrait<Opr>::arity; | |||
using ConvTensorShapes = std::array<TensorShape, arity>; | |||
using TensorShapeArray = std::array<megdnn::TensorShape, arity>; | |||
public: | |||
struct Param { | |||
char algo_name[128]; | |||
size_t workspace; | |||
DTypeEnum dtypes[arity]; | |||
megdnn::DTypeEnum dtypes[arity]; | |||
CompNode::Locator comp_node_loc; | |||
ConvTensorShapes shapes; | |||
TensorShapeArray shapes; | |||
typename Opr::Param opr_param; | |||
bool allow_weight_preprocess; | |||