Browse Source

refactor(mgb): move convolution mixin to search_policy

GitOrigin-RevId: 81e32da034
release-1.2
Megvii Engine Team 4 years ago
parent
commit
d07cfdcbe9
11 changed files with 176 additions and 130 deletions
  1. +2
    -1
      sdk/load-and-run/src/mgblar.cpp
  2. +13
    -12
      src/gopt/impl/inference.cpp
  3. +2
    -1
      src/gopt/include/megbrain/gopt/inference.h
  4. +5
    -18
      src/opr/impl/dnn/convolution.cpp
  5. +2
    -1
      src/opr/impl/search_policy/algo_chooser.cpp
  6. +33
    -0
      src/opr/impl/search_policy/algo_chooser_helper.cpp
  7. +1
    -0
      src/opr/impl/search_policy/profiler.cpp
  8. +32
    -93
      src/opr/include/megbrain/opr/dnn/convolution.h
  9. +1
    -0
      src/opr/include/megbrain/opr/search_policy/algo_chooser.h
  10. +80
    -0
      src/opr/include/megbrain/opr/search_policy/algo_chooser_helper.h
  11. +5
    -4
      src/opr/include/megbrain/opr/search_policy/profiler.h

+ 2
- 1
sdk/load-and-run/src/mgblar.cpp View File

@@ -19,6 +19,7 @@
#include "megbrain/graph/extern_copr_api.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/search_policy/algo_chooser_helper.h"
#include "megbrain/opr/utility.h"
#include "megbrain/plugin/cpu_dispatch_checker.h"
#include "megbrain/plugin/num_range_checker.h"
@@ -691,7 +692,7 @@ void run_test_st(Args &env) {
}

mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, env.workspace_limit);
using S = opr::mixin::Convolution::ExecutionPolicy::Strategy;
using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
S strategy = S::HEURISTIC;
#if MGB_ENABLE_FASTRUN
if (env.use_fast_run) {


+ 13
- 12
src/gopt/impl/inference.cpp View File

@@ -15,6 +15,7 @@
#include "megbrain/graph/event.h"
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/search_policy/algo_chooser_helper.h"
#include "megbrain/utils/shared_set.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/opr/basic_arith.h"
@@ -116,8 +117,8 @@ SymbolVarArray gopt::optimize_for_inference(

namespace {
void modify_conv_strategy(
opr::mixin::Convolution& conv,
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) {
opr::mixin::AlgoChooserHelper& conv,
opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) {
auto policy = conv.execution_policy_transient();
policy.strategy = strategy;
conv.set_execution_policy(policy);
@@ -126,13 +127,13 @@ void modify_conv_strategy(
template <typename Opr>
void inplace_conv_opr_modifier(
OperatorNodeBase& opr,
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) {
opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) {
modify_conv_strategy(
opr.cast_final_safe<Opr>(),
strategy);
}

void modify_conv_policy_workspace_limit(opr::mixin::Convolution& conv,
void modify_conv_policy_workspace_limit(opr::mixin::AlgoChooserHelper& conv,
size_t workspace_limit) {
auto policy = conv.execution_policy_transient();
policy.workspace_limit = workspace_limit;
@@ -159,9 +160,9 @@ void inplace_conv_opr_workspace_limit_modifier(OperatorNodeBase& opr,

void gopt::modify_opr_algo_strategy_inplace(
const VarNodeArrayView& dest_vars,
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) {
opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy) {
#if !MGB_ENABLE_FASTRUN
using S = opr::mixin::Convolution::ExecutionPolicy::Strategy;
using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
if (strategy == S::PROFILE || strategy == S::PROFILE_REPRODUCIBLE) {
mgb_throw(MegBrainError, "fastrun is disabled at compile time");
}
@@ -190,16 +191,16 @@ void gopt::modify_opr_algo_strategy_inplace(

void gopt::enable_opr_algo_profiling_inplace(
const VarNodeArrayView& dest_vars) {
modify_opr_algo_strategy_inplace(dest_vars,
opr::mixin::Convolution::ExecutionPolicy::
Strategy::PROFILE);
modify_opr_algo_strategy_inplace(
dest_vars,
opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy::PROFILE);
}

void gopt::enable_opr_use_profiling_cache_inplace(
const VarNodeArrayView& dest_vars) {
modify_opr_algo_strategy_inplace(dest_vars,
opr::mixin::Convolution::ExecutionPolicy::
Strategy::PROFILE_HEURISTIC);
modify_opr_algo_strategy_inplace(
dest_vars, opr::mixin::AlgoChooserHelper::ExecutionPolicy::
Strategy::PROFILE_HEURISTIC);
}




+ 2
- 1
src/gopt/include/megbrain/gopt/inference.h View File

@@ -14,6 +14,7 @@
#include "megbrain/gopt/framework.h"
#include "megbrain/graph/cg.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/search_policy/algo_chooser_helper.h"

namespace mgb {
namespace gopt {
@@ -342,7 +343,7 @@ namespace gopt {
*/
void modify_opr_algo_strategy_inplace(
const VarNodeArrayView& dest_vars,
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy);
opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy strategy);

/*!
* \brief enable PROFILE execution strategy for oprs with multiple


+ 5
- 18
src/opr/impl/dnn/convolution.cpp View File

@@ -13,7 +13,7 @@
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/search_policy/algo_chooser.h"
#include "megbrain/opr/search_policy/profiler.h"
#include "megbrain/opr/search_policy/algo_chooser_helper.h"

#include "megbrain/graph/grad_impl.h"
#include "megbrain/system.h"
@@ -38,18 +38,9 @@ using intl::WorkspaceLimitGetter;

/* ==================== misc impl ==================== */

mixin::Convolution::~Convolution() = default;

void mixin::Convolution::set_execution_policy(const ExecutionPolicy& policy) {
mgb_throw_if(
m_policy_accessed, InternalError,
"attempt to modify ExecutionPolicy after it has been accessed");
m_policy = policy;
}

template <class MgbOpr, class MegDNNOpr>
void mixin::Convolution::init_output_static_infer_desc_for_bwd_data(
cg::OperatorNodeBase* self) {
void mixin::ConvolutionBackwardDataMixin::
init_output_static_infer_desc_for_bwd_data(cg::OperatorNodeBase* self) {
using namespace cg::static_infer;
auto&& mgr = self->owner_graph()->static_infer_manager();

@@ -93,7 +84,7 @@ void mixin::Convolution::init_output_static_infer_desc_for_bwd_data(
};
inp_deps.push_back({self->output(0), DepType::SHAPE});
auto workspace_dep_var =
WorkspaceLimitGetter::register_to_graph(self->owner_graph());
intl::WorkspaceLimitGetter::register_to_graph(self->owner_graph());
if (workspace_dep_var) {
inp_deps.push_back({workspace_dep_var, DepType::VALUE});
}
@@ -101,11 +92,7 @@ void mixin::Convolution::init_output_static_infer_desc_for_bwd_data(
{SourceType::DEP, inp_deps, infer_wk});
}

#define IMPL_CONV(_cls) \
std::pair<const void*, size_t> _cls::param_blob() const { \
return {&param(), sizeof(Param)}; \
} \
MGB_DYN_TYPE_OBJ_FINAL_IMPL(_cls)
#define IMPL_CONV(_cls) MGB_DYN_TYPE_OBJ_FINAL_IMPL(_cls)

class mixin::WeightPreprocessExecutor::PreprocessedFilterExecDep final
: public cg::GraphExecutable::ExecDependency {


+ 2
- 1
src/opr/impl/search_policy/algo_chooser.cpp View File

@@ -11,6 +11,7 @@
*/

#include "megbrain/opr/search_policy/algo_chooser.h"
#include "megbrain/opr/search_policy/algo_chooser_helper.h"
#include "megbrain/opr/search_policy/profiler.h"

#include "../internal/invoke.h"
@@ -200,7 +201,7 @@ size_t AlgoChooser<Opr>::setup_algo(const TensorLayoutArray& layouts,
template <typename Opr>
typename AlgoChooser<Opr>::ImplAlgo AlgoChooser<Opr>::get_algo(
ExeContext& ctx) {
using S = mixin::Convolution::ExecutionPolicy::Strategy;
using S = mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
MGB_MARK_USED_VAR(TIMEOUT_TOLERANCE);
switch (ctx.mgb_opr()->execution_policy().strategy) {
case S::HEURISTIC:


+ 33
- 0
src/opr/impl/search_policy/algo_chooser_helper.cpp View File

@@ -0,0 +1,33 @@
/**
* \file src/opr/impl/search_policy/algo_chooser_helper.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#include "megbrain/opr/search_policy/algo_chooser_helper.h"
#include "megbrain/opr/search_policy/algo_chooser.h"
#include "megbrain/graph/cg.h"

#include "../internal/megdnn_opr_wrapper.inl"

using namespace mgb;
using namespace opr;
using namespace mixin;
/* ==================== misc impl ==================== */

AlgoChooserHelper::~AlgoChooserHelper() = default;

void AlgoChooserHelper::set_execution_policy(const ExecutionPolicy& policy) {
mgb_throw_if(
m_policy_accessed, InternalError,
"attempt to modify ExecutionPolicy after it has been accessed");
m_policy = policy;
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 1
- 0
src/opr/impl/search_policy/profiler.cpp View File

@@ -13,6 +13,7 @@
#include "megbrain/opr/search_policy/profiler.h"

#include "../internal/invoke.h"
#include "../internal/megdnn_opr_wrapper.inl"

#if MGB_ROCM
#include "hcc_detail/hcc_defs_prologue.h"


+ 32
- 93
src/opr/include/megbrain/opr/dnn/convolution.h View File

@@ -11,6 +11,7 @@
#pragma once

#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/search_policy/algo_chooser_helper.h"
#include "megbrain/utils/persistent_cache.h"
#include "megbrain/opr/param_defs.h"
#include "megdnn/oprs/nn.h"
@@ -19,68 +20,14 @@ namespace mgb {
namespace opr {
namespace mixin {

/*!
* \brief Convolution base class
*/
class Convolution {
public:
using ExecutionPolicy = megdnn::param::ExecutionPolicy;
using AlgorithmInfo = megdnn::detail::Algorithm::Info;
using AlgoChooserHook =
std::function<AlgorithmInfo(const OperatorNodeBase*)>;

const ExecutionPolicy& execution_policy() const {
if (!m_policy_accessed) {
m_policy_accessed = true;
}
return m_policy;
}

/*!
* \brief get current policy without marking it as having been accessed
*
* This is primarily used for getting current policy before calling
* set_execution_policy().
*/
const ExecutionPolicy& execution_policy_transient() const {
return m_policy;
}

/*!
* \brief modify execution policy
*
* Exception would be thrown if execution_policy() has been accessed,
* since it would influence cache and many other decisions.
*/
void set_execution_policy(const ExecutionPolicy& policy);

AlgoChooserProfileCache& profile_cache() const;

virtual std::pair<const void*, size_t> param_blob() const = 0;

/*!
* \brief register a hook to implement custom algo chooser
*/
void setup_algo_chooser(AlgoChooserHook&& func) {
m_algo_chooser = func;
}
AlgoChooserHook algo_chooser() const {
return m_algo_chooser;
}

protected:
~Convolution();

mutable bool m_policy_accessed = false;
ExecutionPolicy m_policy;

AlgoChooserHook m_algo_chooser;
class ConvolutionBackwardDataMixin : public cg::OperatorNodeMixinBase {
protected:
//! init output desc for conv backward data oprs; it handles both grad
//! usage and deconv usage
template <class MgbOpr, class MegDNNOpr>
static void init_output_static_infer_desc_for_bwd_data(
cg::OperatorNodeBase* self);

//! init output desc for conv backward data oprs; it handles both grad
//! usage and deconv usage
template <class MgbOpr, class MegDNNOpr>
static void init_output_static_infer_desc_for_bwd_data(
cg::OperatorNodeBase* self);
};

class WeightPreprocessExecutor : public cg::OperatorNodeMixinBase {
@@ -153,7 +100,7 @@ class ConvolutionTestingPeer;
} // namespace testing

MGB_DEFINE_OPR_CLASS(ConvolutionForward,
intl::ConvolutionForwardBase, public mixin::Convolution) // {
intl::ConvolutionForwardBase, public mixin::AlgoChooserHelper) // {

void init_output_dtype() override;
size_t get_workspace_size_bytes(
@@ -183,12 +130,11 @@ MGB_DEFINE_OPR_CLASS(ConvolutionForward,
const ExecutionPolicy &policy = {},
const OperatorNodeConfig &config = {});

std::pair<const void*, size_t> param_blob() const override;
};
using Convolution = ConvolutionForward;

MGB_DEFINE_OPR_CLASS(ConvBiasForward, intl::ConvBiasForwardBase,
public mixin::Convolution) // {
public mixin::AlgoChooserHelper) // {

void init_output_dtype() override;
size_t get_workspace_size_bytes(
@@ -240,7 +186,6 @@ public:
const ExecutionPolicy& policy = {},
const OperatorNodeConfig& config = {});

std::pair<const void*, size_t> param_blob() const override;

static void check_winograd_param_valid(
const megdnn::ConvBias::WinogradParam& param,
@@ -253,10 +198,12 @@ using ConvBias = ConvBiasForward;
/*!
* \brief Can be used in two ways: compute gradient of conv, or deconv
*/
MGB_DEFINE_OPR_CLASS(ConvolutionBackwardData,
MGB_DEFINE_OPR_CLASS(
ConvolutionBackwardData,
cg::SingleCNOperatorNodeBaseT<
mixin::MegDNNOprHolderImpl<megdnn::ConvolutionBackwardData>>,
public mixin::Convolution) // {
mixin::MegDNNOprHolderImpl<megdnn::ConvolutionBackwardData>>,
public mixin::AlgoChooserHelper,
public mixin::ConvolutionBackwardDataMixin) // {
void init_output_static_infer_desc() override;
void init_output_dtype() override;
void init_output_format() override;
@@ -296,12 +243,11 @@ MGB_DEFINE_OPR_CLASS(ConvolutionBackwardData,
return make(filter, data, param, policy, config);
}

std::pair<const void*, size_t> param_blob() const override;
};

MGB_DEFINE_OPR_CLASS(ConvolutionBackwardFilter,
intl::MegDNNOprWrapperBwd<megdnn::ConvolutionBackwardFilter>,
public mixin::Convolution ) // {
public mixin::AlgoChooserHelper ) // {


size_t get_workspace_size_bytes(
@@ -318,7 +264,6 @@ MGB_DEFINE_OPR_CLASS(ConvolutionBackwardFilter,
const ExecutionPolicy &policy = {},
const OperatorNodeConfig &config = {});

std::pair<const void*, size_t> param_blob() const override;
};

MGB_DEFINE_OPR_CLASS(MaskConvolution,
@@ -350,7 +295,7 @@ public:

MGB_DEFINE_OPR_CLASS(Convolution3DForward,
intl::MegDNNOprWrapperFwd<megdnn::Convolution3DForward>,
public mixin::Convolution) // {
public mixin::AlgoChooserHelper) // {

void init_output_dtype() override;
size_t get_workspace_size_bytes(
@@ -368,17 +313,18 @@ MGB_DEFINE_OPR_CLASS(Convolution3DForward,
const ExecutionPolicy &policy = {},
const OperatorNodeConfig &config = {});

std::pair<const void*, size_t> param_blob() const override;
};
using Convolution3D = Convolution3DForward;

/*!
* \brief Can be used in two ways: compute gradient of conv, or deconv
*/
MGB_DEFINE_OPR_CLASS(Convolution3DBackwardData,
MGB_DEFINE_OPR_CLASS(
Convolution3DBackwardData,
cg::SingleCNOperatorNodeBaseT<
mixin::MegDNNOprHolderImpl<megdnn::Convolution3DBackwardData>>,
public mixin::Convolution) // {
mixin::MegDNNOprHolderImpl<megdnn::Convolution3DBackwardData>>,
public mixin::AlgoChooserHelper,
public mixin::ConvolutionBackwardDataMixin) // {
void init_output_static_infer_desc() override;

void add_input_layout_constraint() override;
@@ -416,12 +362,11 @@ MGB_DEFINE_OPR_CLASS(Convolution3DBackwardData,
return make(filter, data, param, policy, config);
}

std::pair<const void*, size_t> param_blob() const override;
};

MGB_DEFINE_OPR_CLASS(Convolution3DBackwardFilter,
intl::MegDNNOprWrapperBwd<megdnn::Convolution3DBackwardFilter>,
public mixin::Convolution) // {
public mixin::AlgoChooserHelper) // {

size_t get_workspace_size_bytes(
const TensorShapeArray &input_shapes,
@@ -437,12 +382,11 @@ MGB_DEFINE_OPR_CLASS(Convolution3DBackwardFilter,
const ExecutionPolicy &policy = {},
const OperatorNodeConfig &config = {});

std::pair<const void*, size_t> param_blob() const override;
};

MGB_DEFINE_OPR_CLASS(LocalShareForward,
intl::MegDNNOprWrapperFwd<megdnn::LocalShareForward>,
public mixin::Convolution) // {
public mixin::AlgoChooserHelper) // {
void init_output_dtype() override;
void init_output_format() override;

@@ -457,7 +401,6 @@ public:
static SymbolVar make(SymbolVar src, SymbolVar filter, const Param& param = {},
const ExecutionPolicy& policy = {},
const OperatorNodeConfig& config = {});
std::pair<const void*, size_t> param_blob() const override;
};
using LocalShare = LocalShareForward;

@@ -465,7 +408,8 @@ MGB_DEFINE_OPR_CLASS(
LocalShareBackwardData,
cg::SingleCNOperatorNodeBaseT<
mixin::MegDNNOprHolderImpl<megdnn::LocalShareBackwardData>>,
public mixin::Convolution) // {
public mixin::AlgoChooserHelper,
public mixin::ConvolutionBackwardDataMixin) // {
void init_output_static_infer_desc() override;
void init_output_dtype() override;

@@ -485,13 +429,12 @@ public:
const ExecutionPolicy& policy = {},
const OperatorNodeConfig& config = {});

std::pair<const void*, size_t> param_blob() const override;
};

MGB_DEFINE_OPR_CLASS(
LocalShareBackwardFilter,
intl::MegDNNOprWrapperBwd<megdnn::LocalShareBackwardFilter>,
public mixin::Convolution) // {
public mixin::AlgoChooserHelper) // {

size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
@@ -506,12 +449,11 @@ public:
const ExecutionPolicy& policy = {},
const OperatorNodeConfig& config = {});

std::pair<const void*, size_t> param_blob() const override;
};

MGB_DEFINE_OPR_CLASS(DeformableConvForward,
intl::MegDNNOprWrapperFwd<megdnn::DeformableConvForward>,
public mixin::Convolution) // {
public mixin::AlgoChooserHelper) // {
public:
DeformableConvForward(
VarNode *src, VarNode *filter, VarNode *offset, VarNode *mask,
@@ -525,7 +467,6 @@ MGB_DEFINE_OPR_CLASS(DeformableConvForward,
const ExecutionPolicy &policy = {},
const OperatorNodeConfig &config = {});

std::pair<const void*, size_t> param_blob() const override;
private:
void init_output_dtype() override;
void init_output_format() override;
@@ -537,7 +478,8 @@ using DeformableConv = DeformableConvForward;

MGB_DEFINE_OPR_CLASS(DeformableConvBackwardData,
intl::DeformableConvBackwardDataBase,
public mixin::Convolution) // {
public mixin::AlgoChooserHelper,
public mixin::ConvolutionBackwardDataMixin) // {
public:
DeformableConvBackwardData(
VarNode * src, VarNode * filter, VarNode * offset, VarNode * mask,
@@ -557,7 +499,6 @@ public:
const OperatorNodeConfig& config = {});

void scn_do_execute() override;
std::pair<const void*, size_t> param_blob() const override;

private:
void get_output_var_shape(const TensorShapeArray& inp_shape,
@@ -578,7 +519,7 @@ private:
MGB_DEFINE_OPR_CLASS(
DeformableConvBackwardFilter,
intl::MegDNNOprWrapperBwd<megdnn::DeformableConvBackwardFilter>,
public mixin::Convolution) // {
public mixin::AlgoChooserHelper) // {
public:
DeformableConvBackwardFilter(
VarNode * src, VarNode * filter, VarNode * offset, VarNode * mask,
@@ -592,7 +533,6 @@ public:
const OperatorNodeConfig& config = {});

void scn_do_execute() override;
std::pair<const void*, size_t> param_blob() const override;

private:
size_t get_workspace_size_bytes(const TensorShapeArray& input_shapes,
@@ -601,7 +541,7 @@ private:
};

MGB_DEFINE_OPR_CLASS(BatchConvBiasForward, intl::BatchConvBiasForwardBase,
public mixin::Convolution) // {
public mixin::AlgoChooserHelper) // {

void init_output_dtype() override;
size_t get_workspace_size_bytes(
@@ -650,7 +590,6 @@ public:
const ExecutionPolicy& policy = {},
const OperatorNodeConfig& config = {});

std::pair<const void*, size_t> param_blob() const override;
};
using BatchConvBias = BatchConvBiasForward;



+ 1
- 0
src/opr/include/megbrain/opr/search_policy/algo_chooser.h View File

@@ -13,6 +13,7 @@
#pragma once

#include "megbrain/opr/search_policy/profiler.h"
#include "megbrain/opr/dnn/convolution.h"

template <class MegDNNOpr>
struct MegDNNOpr2MGBOpr;


+ 80
- 0
src/opr/include/megbrain/opr/search_policy/algo_chooser_helper.h View File

@@ -0,0 +1,80 @@
/**
* \file src/opr/include/megbrain/opr/search_policy/algo_chooser_helper.h
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/

#pragma once

#include "megbrain/graph/operator_node.h"
#include "megbrain/opr/param_defs.h"
#include "megdnn/oprs/base.h"
#include "megdnn/oprs/nn.h"

namespace mgb {
namespace opr {

namespace mixin {

/*!
* \brief base class for the opr which can be tuning
*/
class AlgoChooserHelper : cg::OperatorNodeMixinBase {
public:
using ExecutionPolicy = megdnn::param::ExecutionPolicy;
using AlgorithmInfo = megdnn::detail::Algorithm::Info;
using AlgoChooserHook =
std::function<AlgorithmInfo(const cg::OperatorNodeBase*)>;

const ExecutionPolicy& execution_policy() const {
if (!m_policy_accessed) {
m_policy_accessed = true;
}
return m_policy;
}

/*!
* \brief get current policy without marking it as having been accessed
*
* This is primarily used for getting current policy before calling
* set_execution_policy().
*/
const ExecutionPolicy& execution_policy_transient() const {
return m_policy;
}

/*!
* \brief modify execution policy
*
* Exception would be thrown if execution_policy() has been accessed,
* since it would influence cache and many other decisions.
*/
void set_execution_policy(const ExecutionPolicy& policy);

/*!
* \brief register a hook to implement custom algo chooser
*/
void setup_algo_chooser(AlgoChooserHook&& func) { m_algo_chooser = func; }
AlgoChooserHook algo_chooser() const { return m_algo_chooser; }

protected:
~AlgoChooserHelper();

mutable bool m_policy_accessed = false;
ExecutionPolicy m_policy;

AlgoChooserHook m_algo_chooser;

};
} // namespace mixin

} // namespace opr
} // namespace mgb

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}

+ 5
- 4
src/opr/include/megbrain/opr/search_policy/profiler.h View File

@@ -12,9 +12,10 @@

#pragma once

#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/utils/hash_ct.h"
#include "megbrain/utils/timer.h"
#include "megbrain/system.h"
#include "megbrain/comp_node.h"

#include "megdnn/basic_types.h"
#include "megdnn/oprs/nn.h"
@@ -127,15 +128,15 @@ class TimedProfiler {
static constexpr int arity_out = OprArityTrait<Opr>::arity_out;
static constexpr int arity = OprArityTrait<Opr>::arity;

using ConvTensorShapes = std::array<TensorShape, arity>;
using TensorShapeArray = std::array<megdnn::TensorShape, arity>;

public:
struct Param {
char algo_name[128];
size_t workspace;
DTypeEnum dtypes[arity];
megdnn::DTypeEnum dtypes[arity];
CompNode::Locator comp_node_loc;
ConvTensorShapes shapes;
TensorShapeArray shapes;
typename Opr::Param opr_param;
bool allow_weight_preprocess;



Loading…
Cancel
Save