@@ -14,6 +14,7 @@ | |||||
#include "./json_loader.h" | #include "./json_loader.h" | ||||
#include "./npy.h" | #include "./npy.h" | ||||
#include "megbrain/opr/dnn/convolution.h" | |||||
#include "megbrain/utils/debug.h" | #include "megbrain/utils/debug.h" | ||||
#include "megbrain/serialization/serializer.h" | #include "megbrain/serialization/serializer.h" | ||||
#include "megbrain/serialization/extern_c_opr.h" | #include "megbrain/serialization/extern_c_opr.h" | ||||
@@ -144,6 +145,10 @@ R"__usage__( | |||||
R"__usage__( | R"__usage__( | ||||
--fast-run-algo-policy <path> | --fast-run-algo-policy <path> | ||||
It will read the cache file before profile, and save new fastrun in cache file. | It will read the cache file before profile, and save new fastrun in cache file. | ||||
--reproducible | |||||
Enable choose algo which is reproducible. It mainly used for cudnn algos. | |||||
See https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html#reproducibility | |||||
for more details. | |||||
--wait-gdb | --wait-gdb | ||||
Print PID and wait for a line from stdin before starting execution. Useful | Print PID and wait for a line from stdin before starting execution. Useful | ||||
for waiting for gdb attach. | for waiting for gdb attach. | ||||
@@ -467,6 +472,7 @@ struct Args { | |||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
bool use_fast_run = false; | bool use_fast_run = false; | ||||
#endif | #endif | ||||
bool reproducible = false; | |||||
std::string fast_run_cache_path; | std::string fast_run_cache_path; | ||||
bool copy_to_host = false; | bool copy_to_host = false; | ||||
int nr_run = 10; | int nr_run = 10; | ||||
@@ -647,10 +653,24 @@ void run_test_st(Args &env) { | |||||
} | } | ||||
mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, env.workspace_limit); | mgb::gopt::set_opr_algo_workspace_limit_inplace(vars, env.workspace_limit); | ||||
using S = opr::mixin::Convolution::ExecutionPolicy::Strategy; | |||||
S strategy = S::HEURISTIC; | |||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
if (env.use_fast_run) | |||||
mgb::gopt::enable_opr_algo_profiling_inplace(vars); | |||||
if (env.use_fast_run) { | |||||
if (env.reproducible) { | |||||
strategy = S::PROFILE_REPRODUCIBLE; | |||||
} else { | |||||
strategy = S::PROFILE; | |||||
} | |||||
} else if (env.reproducible) { | |||||
strategy = S::HEURISTIC_REPRODUCIBLE; | |||||
} | |||||
#else | |||||
if (env.reproducible) { | |||||
strategy = S::HEURISTIC_REPRODUCIBLE; | |||||
} | |||||
#endif | #endif | ||||
mgb::gopt::modify_opr_algo_strategy_inplace(vars, strategy); | |||||
if (!env.fast_run_cache_path.empty()) { | if (!env.fast_run_cache_path.empty()) { | ||||
#if MGB_ENABLE_FASTRUN | #if MGB_ENABLE_FASTRUN | ||||
if (!access(env.fast_run_cache_path.c_str(), F_OK)) { | if (!access(env.fast_run_cache_path.c_str(), F_OK)) { | ||||
@@ -1149,6 +1169,10 @@ Args Args::from_argv(int argc, char **argv) { | |||||
ret.fast_run_cache_path = argv[i]; | ret.fast_run_cache_path = argv[i]; | ||||
continue; | continue; | ||||
} | } | ||||
if (!strcmp(argv[i], "--reproducible")) { | |||||
ret.reproducible = true; | |||||
continue; | |||||
} | |||||
if (!strcmp(argv[i], "--const-shape")) { | if (!strcmp(argv[i], "--const-shape")) { | ||||
ret.load_config.const_var_shape = true; | ret.load_config.const_var_shape = true; | ||||
continue; | continue; | ||||
@@ -104,25 +104,21 @@ SymbolVarArray gopt::optimize_for_inference( | |||||
} | } | ||||
namespace { | namespace { | ||||
void modify_conv_policy(opr::mixin::Convolution& conv, | |||||
megdnn::param::ExecutionPolicy::Strategy strategy) { | |||||
void modify_conv_strategy( | |||||
opr::mixin::Convolution& conv, | |||||
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) { | |||||
auto policy = conv.execution_policy_transient(); | auto policy = conv.execution_policy_transient(); | ||||
policy.strategy = strategy; | policy.strategy = strategy; | ||||
conv.set_execution_policy(policy); | conv.set_execution_policy(policy); | ||||
} | } | ||||
template <typename Opr> | template <typename Opr> | ||||
void inplace_conv_opr_profile_modifier(OperatorNodeBase& opr) { | |||||
modify_conv_policy( | |||||
void inplace_conv_opr_modifier( | |||||
OperatorNodeBase& opr, | |||||
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) { | |||||
modify_conv_strategy( | |||||
opr.cast_final_safe<Opr>(), | opr.cast_final_safe<Opr>(), | ||||
opr::mixin::Convolution::ExecutionPolicy::Strategy::PROFILE); | |||||
} | |||||
template <typename Opr> | |||||
void inplace_conv_opr_profile_cache_modifier(OperatorNodeBase& opr) { | |||||
modify_conv_policy(opr.cast_final_safe<Opr>(), | |||||
opr::mixin::Convolution::ExecutionPolicy::Strategy:: | |||||
PROFILE_HEURISTIC); | |||||
strategy); | |||||
} | } | ||||
void modify_conv_policy_workspace_limit(opr::mixin::Convolution& conv, | void modify_conv_policy_workspace_limit(opr::mixin::Convolution& conv, | ||||
@@ -150,12 +146,20 @@ void inplace_conv_opr_workspace_limit_modifier(OperatorNodeBase& opr, | |||||
cb(DeformableConvBackwardFilter), cb(DeformableConvBackwardData), \ | cb(DeformableConvBackwardFilter), cb(DeformableConvBackwardData), \ | ||||
cb(BatchConvBiasForward), | cb(BatchConvBiasForward), | ||||
void gopt::enable_opr_algo_profiling_inplace( | |||||
const VarNodeArrayView& dest_vars) { | |||||
#if MGB_ENABLE_FASTRUN | |||||
static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&)> modifiers = | |||||
{ | |||||
#define CONV(t) {opr::t::typeinfo(), &inplace_conv_opr_profile_modifier<opr::t>} | |||||
void gopt::modify_opr_algo_strategy_inplace( | |||||
const VarNodeArrayView& dest_vars, | |||||
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy) { | |||||
#if !MGB_ENABLE_FASTRUN | |||||
using S = opr::mixin::Convolution::ExecutionPolicy::Strategy; | |||||
if (strategy == S::PROFILE || strategy == S::PROFILE_REPRODUCIBLE) { | |||||
mgb_throw(MegBrainError, "fastrun is disabled at compile time"); | |||||
} | |||||
#endif | |||||
const ThinHashMap<Typeinfo*, std::function<void(OperatorNodeBase&)>> | |||||
modifiers = { | |||||
#define CONV(t) \ | |||||
{opr::t::typeinfo(), std::bind(inplace_conv_opr_modifier<opr::t>, \ | |||||
std::placeholders::_1, strategy)} | |||||
MGB_FOREACH_FASTRUN_OPR(CONV) | MGB_FOREACH_FASTRUN_OPR(CONV) | ||||
#undef CONV | #undef CONV | ||||
}; | }; | ||||
@@ -171,34 +175,23 @@ void gopt::enable_opr_algo_profiling_inplace( | |||||
for (auto i : dest_vars) { | for (auto i : dest_vars) { | ||||
dep_iter.add(i); | dep_iter.add(i); | ||||
} | } | ||||
#else | |||||
mgb_throw(MegBrainError, "fastrun is disabled at compile time"); | |||||
#endif | |||||
} | } | ||||
void gopt::enable_opr_use_profiling_cache_inplace( | |||||
void gopt::enable_opr_algo_profiling_inplace( | |||||
const VarNodeArrayView& dest_vars) { | const VarNodeArrayView& dest_vars) { | ||||
static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&)> modifiers = | |||||
{ | |||||
#define CONV(t) \ | |||||
{opr::t::typeinfo(), &inplace_conv_opr_profile_cache_modifier<opr::t>} | |||||
MGB_FOREACH_FASTRUN_OPR(CONV) | |||||
#undef CONV | |||||
}; | |||||
auto on_opr = [&](OperatorNodeBase* opr) { | |||||
auto iter = modifiers.find(opr->dyn_typeinfo()); | |||||
if (iter != modifiers.end()) { | |||||
iter->second(*opr); | |||||
} | |||||
}; | |||||
modify_opr_algo_strategy_inplace(dest_vars, | |||||
opr::mixin::Convolution::ExecutionPolicy:: | |||||
Strategy::PROFILE); | |||||
} | |||||
cg::DepOprIter dep_iter{on_opr}; | |||||
for (auto i : dest_vars) { | |||||
dep_iter.add(i); | |||||
} | |||||
void gopt::enable_opr_use_profiling_cache_inplace( | |||||
const VarNodeArrayView& dest_vars) { | |||||
modify_opr_algo_strategy_inplace(dest_vars, | |||||
opr::mixin::Convolution::ExecutionPolicy:: | |||||
Strategy::PROFILE_HEURISTIC); | |||||
} | } | ||||
void gopt::set_opr_algo_workspace_limit_inplace( | void gopt::set_opr_algo_workspace_limit_inplace( | ||||
const VarNodeArrayView& dest_vars, size_t workspace_limit) { | const VarNodeArrayView& dest_vars, size_t workspace_limit) { | ||||
static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)> | static const ThinHashMap<Typeinfo*, void (*)(OperatorNodeBase&, size_t)> | ||||
@@ -13,6 +13,7 @@ | |||||
#include "megbrain/gopt/framework.h" | #include "megbrain/gopt/framework.h" | ||||
#include "megbrain/graph/cg.h" | #include "megbrain/graph/cg.h" | ||||
#include "megbrain/opr/dnn/convolution.h" | |||||
namespace mgb { | namespace mgb { | ||||
namespace gopt { | namespace gopt { | ||||
@@ -303,6 +304,17 @@ namespace gopt { | |||||
const OptimizeForInferenceOptions& opt = {}); | const OptimizeForInferenceOptions& opt = {}); | ||||
/*! | /*! | ||||
* \brief modify execution strategy for oprs with multiple | |||||
* algorithms | |||||
* | |||||
* This would modify the operators inplace. It can be used for implement | |||||
* the fast-run mode. | |||||
*/ | |||||
void modify_opr_algo_strategy_inplace( | |||||
const VarNodeArrayView& dest_vars, | |||||
opr::mixin::Convolution::ExecutionPolicy::Strategy strategy); | |||||
/*! | |||||
* \brief enable PROFILE execution strategy for oprs with multiple | * \brief enable PROFILE execution strategy for oprs with multiple | ||||
* algorithms | * algorithms | ||||
* | * | ||||
@@ -315,7 +327,7 @@ namespace gopt { | |||||
void enable_opr_algo_profiling_inplace(const VarNodeArrayView& dest_vars); | void enable_opr_algo_profiling_inplace(const VarNodeArrayView& dest_vars); | ||||
/*! | /*! | ||||
* \brief enable opr try profiling cache first, if failed, then try | |||||
* \brief enable opr try profiling cache first, if failed, fallback to | |||||
* heuristic | * heuristic | ||||
* | * | ||||
* This would modify the operators inplace. It is usually used to enable | * This would modify the operators inplace. It is usually used to enable | ||||
@@ -324,7 +336,8 @@ namespace gopt { | |||||
* You may want to implement TimedFuncInvoker::ForkExecImpl and/or | * You may want to implement TimedFuncInvoker::ForkExecImpl and/or | ||||
* PersistentCache for better performance in an SDK. | * PersistentCache for better performance in an SDK. | ||||
*/ | */ | ||||
void enable_opr_use_profiling_cache_inplace(const VarNodeArrayView& dest_vars); | |||||
void enable_opr_use_profiling_cache_inplace( | |||||
const VarNodeArrayView& dest_vars); | |||||
/*! | /*! | ||||
* \brief set workspace_limit for execution strategy for oprs with multiple | * \brief set workspace_limit for execution strategy for oprs with multiple | ||||