|
|
@@ -243,31 +243,33 @@ typename opr::AlgoChooser<Opr>::FixedTensorLayouts to_fixed_layouts( |
|
|
|
*/ |
|
|
|
template <typename Opr> |
|
|
|
std::vector<megdnn::Algorithm::SearchItem> flatten_search_space( |
|
|
|
const typename opr::AlgoChooser<Opr>::ExeContext& ctx, |
|
|
|
const typename opr::AlgoChooser<Opr>::AlgoChooserHelper& helper, |
|
|
|
CircularDepsChecker& checker) { |
|
|
|
auto&& search_item = megdnn::Algorithm::SearchItem{ |
|
|
|
OprTypeFromOprTrait<Opr>::opr_type, ctx.param(), |
|
|
|
to_layout_array<Opr>(ctx.layouts())}; |
|
|
|
OprTypeFromOprTrait<Opr>::opr_type, helper.param(), |
|
|
|
to_layout_array<Opr>(helper.layouts())}; |
|
|
|
checker.put(search_item); |
|
|
|
std::vector<megdnn::Algorithm::SearchItem> ret; |
|
|
|
for (auto algo_info : ctx.get_all_candidates()) { |
|
|
|
megdnn::Algorithm* algo = ctx.get_algorithm_from_desc(algo_info.desc); |
|
|
|
for (auto algo_info : helper.get_all_candidates()) { |
|
|
|
megdnn::Algorithm* algo = |
|
|
|
helper.get_algorithm_from_desc(algo_info.desc); |
|
|
|
mgb_assert(algo, "Unknown algo description"); |
|
|
|
std::vector<megdnn::Algorithm::SearchItem>&& sub_items = |
|
|
|
algo->get_subopr_list(to_layout_array<Opr>(ctx.layouts()), |
|
|
|
ctx.megdnn_opr()); |
|
|
|
algo->get_subopr_list(to_layout_array<Opr>(helper.layouts()), |
|
|
|
helper.megdnn_opr()); |
|
|
|
|
|
|
|
FOREACH_OPR_TYPE_DISPATCH(sub_items, { |
|
|
|
auto&& megdnn_opr = |
|
|
|
opr::intl::create_megdnn_opr<_Opr>(ctx.comp_node()); |
|
|
|
opr::intl::create_megdnn_opr<_Opr>(helper.comp_node()); |
|
|
|
megdnn_opr->param() = |
|
|
|
Algorithm::deserialize_read_pod<typename _Opr::Param>( |
|
|
|
_item.param); |
|
|
|
typename opr::AlgoChooser<_Opr>::ExeContext sub_ctx( |
|
|
|
typename opr::AlgoChooser<_Opr>::AlgoChooserHelper sub_helper( |
|
|
|
to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), |
|
|
|
_item.param, ctx.mgb_opr(), ctx.comp_node(), |
|
|
|
ctx.execution_policy(), ctx.allow_weight_preprocess()); |
|
|
|
auto space = flatten_search_space<_Opr>(sub_ctx, checker); |
|
|
|
_item.param, helper.mgb_opr(), helper.comp_node(), |
|
|
|
helper.execution_policy(), |
|
|
|
helper.allow_weight_preprocess()); |
|
|
|
auto space = flatten_search_space<_Opr>(sub_helper, checker); |
|
|
|
ret.insert(ret.end(), space.begin(), space.end()); |
|
|
|
}); |
|
|
|
} |
|
|
@@ -280,255 +282,113 @@ std::vector<megdnn::Algorithm::SearchItem> flatten_search_space( |
|
|
|
|
|
|
|
namespace mgb { |
|
|
|
namespace opr { |
|
|
|
|
|
|
|
///////////////////////////// AlgoChooserHelper ////////////////////////// |
|
|
|
template <typename Opr> |
|
|
|
void AlgoChooser<Opr>::profile(ExeContext& ctx, |
|
|
|
ExecutionStrategy selected_strategy) { |
|
|
|
if (ctx.get_profile_result_from_cache(selected_strategy).valid()) |
|
|
|
return; |
|
|
|
AlgoChooserProfileCache::Result prof_rst; |
|
|
|
|
|
|
|
auto target_attr = ctx.extract_algo_attribute(selected_strategy); |
|
|
|
std::string layouts_str = |
|
|
|
format_fixlayouts<Opr>(ctx.layouts(), arity_in, arity_out); |
|
|
|
double cur_timeout = 0; |
|
|
|
AlgoChooser<Opr>::AlgoChooserHelper::AlgoChooserHelper( |
|
|
|
const FixedTensorLayouts& layouts, Opr* megdnn_opr, |
|
|
|
const std::string& param_str, const cg::OperatorNodeBase* mgb_opr, |
|
|
|
const CompNode& cn, |
|
|
|
const megdnn::param::ExecutionPolicy& execution_policy, |
|
|
|
bool allow_weight_preprocess) |
|
|
|
: m_layouts{layouts}, |
|
|
|
m_megdnn_opr{megdnn_opr}, |
|
|
|
m_param{param_str}, |
|
|
|
m_base_mgb_opr{mgb_opr}, |
|
|
|
m_cn{cn}, |
|
|
|
m_execution_policy{execution_policy}, |
|
|
|
m_allow_weight_preprocess{allow_weight_preprocess} { |
|
|
|
mgb_assert(m_layouts.size() == layouts.size()); |
|
|
|
static_assert(std::tuple_size<FixedTensorLayouts>::value == 3 || |
|
|
|
std::tuple_size<FixedTensorLayouts>::value == 5 || |
|
|
|
std::tuple_size<FixedTensorLayouts>::value == 8, |
|
|
|
"Convolution AlgoChooser assumes arity = 3 , 5 or 8 (for " |
|
|
|
"deformable conv)"); |
|
|
|
} |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
typename AlgoChooser<Opr>::ImplExecutionPolicy |
|
|
|
AlgoChooser<Opr>::AlgoChooserHelper::choose_by_heuristic( |
|
|
|
const ExecutionStrategy& selected_strategy) const { |
|
|
|
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("choose_by_heuristic"))) |
|
|
|
ImplExecutionPolicy policy; |
|
|
|
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( |
|
|
|
ctx.owner_graph(), ctx.comp_node(), |
|
|
|
ctx.execution_policy().workspace_limit); |
|
|
|
RealTimer timer; |
|
|
|
for (auto algo : ctx.get_all_candidates()) { |
|
|
|
Maybe<AlgoChooserProfileCache::ResultEntry> cur_rst; |
|
|
|
|
|
|
|
ImplExecutionPolicy policy; |
|
|
|
policy.algo = algo.desc; |
|
|
|
|
|
|
|
//! check negative attribute : skip negative attribute |
|
|
|
auto palgo = ctx.megdnn_opr()->get_algorithm_from_desc(policy.algo); |
|
|
|
if (palgo->contain_attribute_any(target_attr.second)) { |
|
|
|
mgb_log_debug( |
|
|
|
"skip algo %s with attribute(%s), which is not match the " |
|
|
|
"profile strategy required contain attribute(%s) and not " |
|
|
|
"contain attribute(%s).", |
|
|
|
algo.desc.name.c_str(), |
|
|
|
Algorithm::attribute_str(palgo->attribute()).c_str(), |
|
|
|
Algorithm::attribute_str(target_attr.first).c_str(), |
|
|
|
Algorithm::attribute_str(target_attr.second).c_str()); |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
//! check workspace limit |
|
|
|
ctx.construct_execution_policy(selected_strategy, policy); |
|
|
|
if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
owner_graph(), m_cn, m_execution_policy.workspace_limit); |
|
|
|
auto attr = extract_algo_attribute(selected_strategy); |
|
|
|
policy.algo = |
|
|
|
APPLY(m_megdnn_opr->get_algorithm_info_heuristic( |
|
|
|
args..., workspace_limit, attr.first, attr.second), |
|
|
|
m_layouts) |
|
|
|
.desc; |
|
|
|
|
|
|
|
std::string msg = ssprintf("profiling %s algorithm %s %s", |
|
|
|
ctx.mgb_opr()->dyn_typeinfo()->name, |
|
|
|
algo.desc.name.c_str(), layouts_str.c_str()); |
|
|
|
timer.reset(); |
|
|
|
MGB_TRY { cur_rst = ctx.profile_single_algo(policy, cur_timeout); } |
|
|
|
MGB_CATCH(std::exception & exc, { |
|
|
|
mgb_log_warn("caught exception during %s: %s", msg.c_str(), |
|
|
|
exc.what()); |
|
|
|
continue; |
|
|
|
}) |
|
|
|
MGB_CATCH(..., { |
|
|
|
mgb_log_warn("caught exception during %s", msg.c_str()); |
|
|
|
continue; |
|
|
|
}) |
|
|
|
if (!cur_rst.valid()) { |
|
|
|
mgb_log_warn("timeout when %s; timeout setting: %.3fsec", |
|
|
|
msg.c_str(), cur_timeout); |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (!cur_timeout) { |
|
|
|
cur_timeout = timer.get_secs() + TIMEOUT_TOLERANCE; |
|
|
|
} else { |
|
|
|
cur_timeout = |
|
|
|
std::min(cur_timeout, timer.get_secs() + TIMEOUT_TOLERANCE); |
|
|
|
} |
|
|
|
auto&& rst = cur_rst.val(); |
|
|
|
mgb_log_debug("%s: workspace: %zu; time: %.3gsec", msg.c_str(), |
|
|
|
rst.workspace, rst.time); |
|
|
|
prof_rst.push_back(rst); |
|
|
|
} |
|
|
|
std::string msg = ssprintf( |
|
|
|
"no usable %s algorithm %s without attribute(%s) or could not meet " |
|
|
|
"workspace limite requirement(%zu)", |
|
|
|
ctx.mgb_opr()->dyn_typeinfo()->name, layouts_str.c_str(), |
|
|
|
Algorithm::attribute_str(target_attr.second).c_str(), |
|
|
|
workspace_limit); |
|
|
|
mgb_assert(!prof_rst.empty(), "%s", msg.c_str()); |
|
|
|
Algorithm* algo = m_megdnn_opr->get_algorithm_from_desc(policy.algo); |
|
|
|
mgb_assert(algo, "Unknown algo description"); |
|
|
|
std::vector<Algorithm::SearchItem>&& sub_items = algo->get_subopr_list( |
|
|
|
to_layout_array<Opr>(m_layouts), m_megdnn_opr); |
|
|
|
|
|
|
|
FixedTensorLayouts origin_layouts = ctx.layouts(); |
|
|
|
typename Opr::Param origin_param = ctx.megdnn_opr()->param(); |
|
|
|
AlgoChooserProfileCache::Key cache_key{origin_layouts.data(), |
|
|
|
origin_layouts.size(), &origin_param, |
|
|
|
sizeof(origin_param)}; |
|
|
|
FOREACH_OPR_TYPE_DISPATCH(sub_items, { |
|
|
|
auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(m_cn); |
|
|
|
megdnn_opr->param() = |
|
|
|
Algorithm::deserialize_read_pod<typename _Opr::Param>( |
|
|
|
_item.param); |
|
|
|
typename AlgoChooser<_Opr>::AlgoChooserHelper sub_helper( |
|
|
|
to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), |
|
|
|
_item.param, m_base_mgb_opr, m_cn, m_execution_policy, |
|
|
|
m_allow_weight_preprocess); |
|
|
|
policy.sub_policy.push_back( |
|
|
|
sub_helper.choose_by_heuristic(selected_strategy)); |
|
|
|
}); |
|
|
|
|
|
|
|
AlgoChooserProfileCache cache(ctx.comp_node(), |
|
|
|
profile_name(ctx.megdnn_opr()).c_str()); |
|
|
|
cache.put(cache_key, prof_rst); |
|
|
|
return policy; |
|
|
|
MIDOUT_E |
|
|
|
} |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
typename AlgoChooser<Opr>::ImplExecutionPolicy |
|
|
|
AlgoChooser<Opr>::choose_by_profile(ExeContext& ctx, |
|
|
|
ExecutionStrategy selected_strategy, |
|
|
|
bool enable_update) { |
|
|
|
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("AlgoChooser::choose_by_profile"))) |
|
|
|
if (ctx.owner_graph()->options().no_profiling_on_shape_change) { |
|
|
|
auto policy = ctx.megdnn_opr()->execution_policy(); |
|
|
|
if (policy.algo.valid()){ |
|
|
|
AlgoChooser<Opr>::AlgoChooserHelper::choose_by_profile( |
|
|
|
const ExecutionStrategy& selected_strategy, bool enable_update) const { |
|
|
|
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("choose_by_profile"))) |
|
|
|
if (owner_graph()->options().no_profiling_on_shape_change) { |
|
|
|
auto policy = m_megdnn_opr->execution_policy(); |
|
|
|
if (policy.algo.valid()) { |
|
|
|
return policy; |
|
|
|
} |
|
|
|
if (!algo_usable_on_shape_change<Opr>()) { |
|
|
|
mgb_log_warn( |
|
|
|
"choose algo by heuristic, which may cause performance " |
|
|
|
"regression."); |
|
|
|
return ctx.choose_by_heuristic(selected_strategy); |
|
|
|
return choose_by_heuristic(selected_strategy); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
if (enable_update) { |
|
|
|
CircularDepsChecker circular_deps_checker; |
|
|
|
auto&& search_items = |
|
|
|
flatten_search_space<Opr>(ctx, circular_deps_checker); |
|
|
|
flatten_search_space<Opr>(*this, circular_deps_checker); |
|
|
|
FOREACH_OPR_TYPE_DISPATCH(search_items, { |
|
|
|
auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(ctx.comp_node()); |
|
|
|
auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(m_cn); |
|
|
|
megdnn_opr->param() = |
|
|
|
Algorithm::deserialize_read_pod<typename _Opr::Param>( |
|
|
|
_item.param); |
|
|
|
typename AlgoChooser<_Opr>::ExeContext sub_ctx( |
|
|
|
typename AlgoChooser<_Opr>::AlgoChooserHelper sub_helper( |
|
|
|
to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), |
|
|
|
_item.param, ctx.mgb_opr(), ctx.comp_node(), |
|
|
|
ctx.execution_policy(), ctx.allow_weight_preprocess()); |
|
|
|
AlgoChooser<_Opr>::profile(sub_ctx, selected_strategy); |
|
|
|
_item.param, m_base_mgb_opr, m_cn, m_execution_policy, |
|
|
|
m_allow_weight_preprocess); |
|
|
|
sub_helper.profile(selected_strategy); |
|
|
|
}); |
|
|
|
} |
|
|
|
|
|
|
|
typename AlgoChooser<Opr>::ImplExecutionPolicy policy; |
|
|
|
ctx.construct_execution_policy(selected_strategy, policy); |
|
|
|
construct_execution_policy(selected_strategy, true, policy); |
|
|
|
return policy; |
|
|
|
MIDOUT_E |
|
|
|
} |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts, |
|
|
|
Opr* megdnn_opr, const MGBOpr* mgb_opr, |
|
|
|
bool allow_weight_preprocess) { |
|
|
|
if (WorkspaceLimitGetter::is_prealloc_run(mgb_opr->owner_graph())) { |
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
|
std::string param_str; |
|
|
|
Algorithm::serialize_write_pod(megdnn_opr->param(), param_str); |
|
|
|
ExeContext ctx(layouts, megdnn_opr, param_str, mgb_opr, |
|
|
|
mgb_opr->comp_node(), mgb_opr->execution_policy(), |
|
|
|
allow_weight_preprocess); |
|
|
|
|
|
|
|
ImplExecutionPolicy policy; |
|
|
|
if (auto algo_choose_hook = mgb_opr->algo_chooser()) { |
|
|
|
policy = algo_choose_hook(mgb_opr); |
|
|
|
ctx.construct_execution_policy((ExecutionStrategy::HEURISTIC | |
|
|
|
ExecutionStrategy::REPRODUCIBLE), |
|
|
|
policy, false); |
|
|
|
} |
|
|
|
if (!policy.algo.valid()) { |
|
|
|
policy = get_policy(ctx); |
|
|
|
} |
|
|
|
size_t workspace = ctx.get_workspace_size_bytes(policy); |
|
|
|
|
|
|
|
std::string ret; |
|
|
|
ret.append(mgb_opr->dyn_typeinfo()->name); |
|
|
|
ret += format_fixlayouts<Opr>(layouts, arity_in, arity_out); |
|
|
|
Algorithm* palgo = megdnn_opr->get_algorithm_from_desc(policy.algo); |
|
|
|
mgb_assert(palgo, "Unknown algo description"); |
|
|
|
ret.append("): algo=" + std::string(palgo->name())); |
|
|
|
ret.append(ssprintf(" workspace=%.2fMiB attirbute(%s)", |
|
|
|
workspace / (1024 * 1024.0), |
|
|
|
Algorithm::attribute_str(palgo->attribute()).c_str())); |
|
|
|
mgb_log_debug("%s", ret.c_str()); |
|
|
|
|
|
|
|
megdnn_opr->execution_policy() = policy; |
|
|
|
return workspace; |
|
|
|
} |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy( |
|
|
|
ExeContext& ctx) { |
|
|
|
MGB_MARK_USED_VAR(TIMEOUT_TOLERANCE); |
|
|
|
auto opr_strategy = ctx.execution_policy().strategy; |
|
|
|
if ((opr_strategy & ExecutionStrategy::HEURISTIC) && |
|
|
|
(opr_strategy & ExecutionStrategy::PROFILE)) { |
|
|
|
ImplExecutionPolicy policy = |
|
|
|
choose_by_profile(ctx, opr_strategy, false); |
|
|
|
if (!policy.algo.valid()) |
|
|
|
policy = ctx.choose_by_heuristic(opr_strategy); |
|
|
|
return policy; |
|
|
|
} else if (!static_cast<int>(opr_strategy) || |
|
|
|
(opr_strategy & ExecutionStrategy::HEURISTIC)) { |
|
|
|
return ctx.choose_by_heuristic(opr_strategy); |
|
|
|
} |
|
|
|
#if MGB_ENABLE_FASTRUN |
|
|
|
else if (opr_strategy & ExecutionStrategy::PROFILE) { |
|
|
|
return choose_by_profile(ctx, opr_strategy); |
|
|
|
} |
|
|
|
#endif |
|
|
|
else { |
|
|
|
mgb_throw(GraphError, "bad ExecutionPolicy strategy"); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
#define INST(Opr) \ |
|
|
|
template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \ |
|
|
|
AlgoChooser<megdnn::Opr>::get_policy(ExeContext& ctx); \ |
|
|
|
template void AlgoChooser<megdnn::Opr>::profile(ExeContext& ctx, \ |
|
|
|
ExecutionStrategy); \ |
|
|
|
template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \ |
|
|
|
AlgoChooser<megdnn::Opr>::choose_by_profile( \ |
|
|
|
ExeContext& ctx, ExecutionStrategy, bool enable_update); \ |
|
|
|
template size_t AlgoChooser<megdnn::Opr>::setup_algo( \ |
|
|
|
const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \ |
|
|
|
const MGBOpr* mgb_opr, bool allow_weight_preprocess); |
|
|
|
|
|
|
|
MGB_FOREACH_FASTRUN_OPR(INST) |
|
|
|
|
|
|
|
#undef INST |
|
|
|
|
|
|
|
//////////////////////////////// ExeContext ///////////////////////////// |
|
|
|
template <typename Opr> |
|
|
|
AlgoChooser<Opr>::ExeContext::ExeContext( |
|
|
|
const FixedTensorLayouts& layouts, Opr* megdnn_opr, |
|
|
|
const std::string& param_str, const cg::OperatorNodeBase* mgb_opr, |
|
|
|
const CompNode& cn, |
|
|
|
const megdnn::param::ExecutionPolicy& execution_policy, |
|
|
|
bool allow_weight_preprocess) |
|
|
|
: m_layouts{layouts}, |
|
|
|
m_megdnn_opr{megdnn_opr}, |
|
|
|
m_param{param_str}, |
|
|
|
m_base_mgb_opr{mgb_opr}, |
|
|
|
m_cn{cn}, |
|
|
|
m_execution_policy{execution_policy}, |
|
|
|
m_allow_weight_preprocess{allow_weight_preprocess} { |
|
|
|
mgb_assert(m_layouts.size() == layouts.size()); |
|
|
|
static_assert(std::tuple_size<FixedTensorLayouts>::value == 3 || |
|
|
|
std::tuple_size<FixedTensorLayouts>::value == 5 || |
|
|
|
std::tuple_size<FixedTensorLayouts>::value == 8, |
|
|
|
"Convolution AlgoChooser assumes arity = 3 , 5 or 8 (for " |
|
|
|
"deformable conv)"); |
|
|
|
} |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
typename AlgoChooser<Opr>::ImplAlgo |
|
|
|
AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( |
|
|
|
ExecutionStrategy selected_strategy) const { |
|
|
|
MIDOUT_B(Opr, |
|
|
|
midout_iv(MGB_HASH_STR( |
|
|
|
"AlgoChooser::ExeContext::get_profile_result_from_cache"))) |
|
|
|
AlgoChooserProfileCache cache(m_cn, |
|
|
|
profile_name(m_megdnn_opr).c_str()); |
|
|
|
AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache( |
|
|
|
const ExecutionStrategy& selected_strategy) const { |
|
|
|
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_profile_result_from_cache"))) |
|
|
|
AlgoChooserProfileCache cache(m_cn, profile_name(m_megdnn_opr).c_str()); |
|
|
|
|
|
|
|
typename Opr::Param origin_param = m_megdnn_opr->param(); |
|
|
|
AlgoChooserProfileCache::Key cache_key{m_layouts.data(), m_layouts.size(), |
|
|
@@ -538,23 +398,22 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( |
|
|
|
return {}; |
|
|
|
|
|
|
|
auto&& prof = rst.val(); |
|
|
|
if (prof.empty()) |
|
|
|
return {}; |
|
|
|
|
|
|
|
std::unordered_map<std::string, ImplAlgo> algo_map; |
|
|
|
for (auto i : get_all_candidates()) { |
|
|
|
auto ins = algo_map.emplace(i.desc.name.c_str(), i); |
|
|
|
mgb_assert(ins.second, "duplicated algo name: %s", i.desc.name.c_str()); |
|
|
|
} |
|
|
|
|
|
|
|
if (prof.empty()) |
|
|
|
return {}; |
|
|
|
|
|
|
|
auto target_attr = extract_algo_attribute(selected_strategy); |
|
|
|
bool skip_by_negative = false; |
|
|
|
for (auto&& i : prof) { |
|
|
|
auto attr_of_algo = |
|
|
|
static_cast<megdnn::Algorithm::Attribute>(i.attribute); |
|
|
|
bool contain_attr_all_positive = |
|
|
|
(target_attr.first == |
|
|
|
(attr_of_algo & target_attr.first)); |
|
|
|
(target_attr.first == (attr_of_algo & target_attr.first)); |
|
|
|
bool contain_attr_any_negative = |
|
|
|
static_cast<bool>(attr_of_algo & target_attr.second); |
|
|
|
if (contain_attr_all_positive) { |
|
|
@@ -578,13 +437,14 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( |
|
|
|
|
|
|
|
if (skip_by_negative) { |
|
|
|
mgb_log_error( |
|
|
|
"No usable algo. Only navie algos are available, but negative " |
|
|
|
"stategy is %s.", |
|
|
|
"No usable algo. There are available algos match positive " |
|
|
|
"strategy(%s), but filtered by negative stategy(%s).", |
|
|
|
Algorithm::attribute_str(target_attr.first).c_str(), |
|
|
|
Algorithm::attribute_str(target_attr.second).c_str()); |
|
|
|
} else { |
|
|
|
mgb_log_error( |
|
|
|
"No usable algo. algos read from cache could not satisfy " |
|
|
|
"attribute with %s", |
|
|
|
"positive strategy(%s)", |
|
|
|
Algorithm::attribute_str(target_attr.first).c_str()); |
|
|
|
} |
|
|
|
|
|
|
@@ -593,75 +453,10 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( |
|
|
|
} |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
typename AlgoChooser<Opr>::ImplExecutionPolicy |
|
|
|
AlgoChooser<Opr>::ExeContext::choose_by_heuristic( |
|
|
|
ExecutionStrategy selected_strategy) const { |
|
|
|
if (m_execution_policy.workspace_limit != |
|
|
|
std::numeric_limits<decltype( |
|
|
|
m_execution_policy.workspace_limit)>::max()) { |
|
|
|
mgb_log_warn( |
|
|
|
"workspace_limit should not be setted if choose algo by " |
|
|
|
"heuristic"); |
|
|
|
} |
|
|
|
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( |
|
|
|
owner_graph(), m_cn, m_execution_policy.workspace_limit); |
|
|
|
auto attr = extract_algo_attribute(selected_strategy); |
|
|
|
ImplExecutionPolicy policy; |
|
|
|
policy.algo = |
|
|
|
APPLY(m_megdnn_opr->get_algorithm_info_heuristic( |
|
|
|
args..., workspace_limit, attr.first, attr.second), |
|
|
|
m_layouts) |
|
|
|
.desc; |
|
|
|
|
|
|
|
Algorithm* algo = m_megdnn_opr->get_algorithm_from_desc(policy.algo); |
|
|
|
mgb_assert(algo, "Unknown algo description"); |
|
|
|
std::vector<Algorithm::SearchItem>&& sub_items = algo->get_subopr_list( |
|
|
|
to_layout_array<Opr>(m_layouts), m_megdnn_opr); |
|
|
|
|
|
|
|
FOREACH_OPR_TYPE_DISPATCH(sub_items, { |
|
|
|
auto&& megdnn_opr = intl::create_megdnn_opr<_Opr>(m_cn); |
|
|
|
megdnn_opr->param() = |
|
|
|
Algorithm::deserialize_read_pod<typename _Opr::Param>( |
|
|
|
_item.param); |
|
|
|
typename AlgoChooser<_Opr>::ExeContext sub_ctx( |
|
|
|
to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), |
|
|
|
_item.param, m_base_mgb_opr, m_cn, m_execution_policy, |
|
|
|
m_allow_weight_preprocess); |
|
|
|
policy.sub_policy.push_back( |
|
|
|
sub_ctx.choose_by_heuristic(selected_strategy)); |
|
|
|
}); |
|
|
|
|
|
|
|
return policy; |
|
|
|
} |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
std::vector<typename AlgoChooser<Opr>::ImplAlgo> |
|
|
|
AlgoChooser<Opr>::ExeContext::get_all_candidates() const { |
|
|
|
auto heu = choose_by_heuristic(ExecutionStrategy::HEURISTIC); |
|
|
|
auto&& ret = APPLY(m_megdnn_opr->get_all_algorithms_info(args...), m_layouts); |
|
|
|
bool found = false; |
|
|
|
for (size_t i = 0; i < ret.size(); ++i) { |
|
|
|
if (ret[i].desc == heu.algo) { |
|
|
|
found = true; |
|
|
|
std::swap(ret[i], ret[0]); |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
Algorithm* palgo = m_megdnn_opr->get_algorithm_from_desc(heu.algo); |
|
|
|
mgb_assert(palgo, "Unknown algo description"); |
|
|
|
mgb_assert(found, |
|
|
|
"algo %s got by heuristic not found in " |
|
|
|
"candidate list", |
|
|
|
palgo->name()); |
|
|
|
return std::move(ret); |
|
|
|
} |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
void AlgoChooser<Opr>::ExeContext::construct_execution_policy( |
|
|
|
ExecutionStrategy selected_strategy, |
|
|
|
typename AlgoChooser<Opr>::ImplExecutionPolicy& policy, |
|
|
|
bool retrive_from_cache) const { |
|
|
|
void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy( |
|
|
|
const ExecutionStrategy& selected_strategy, bool retrive_from_cache, |
|
|
|
typename AlgoChooser<Opr>::ImplExecutionPolicy& policy) const { |
|
|
|
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("construct_execution_policy"))) |
|
|
|
if (!policy.algo.valid()) { |
|
|
|
if (retrive_from_cache) { |
|
|
|
policy.algo = get_profile_result_from_cache(selected_strategy).desc; |
|
|
@@ -712,26 +507,28 @@ void AlgoChooser<Opr>::ExeContext::construct_execution_policy( |
|
|
|
megdnn_opr->param() = |
|
|
|
Algorithm::deserialize_read_pod<typename _Opr::Param>( |
|
|
|
_item.param); |
|
|
|
typename AlgoChooser<_Opr>::ExeContext sub_ctx( |
|
|
|
typename AlgoChooser<_Opr>::AlgoChooserHelper sub_helper( |
|
|
|
to_fixed_layouts<_Opr>(_item.layouts), megdnn_opr.get(), |
|
|
|
_item.param, m_base_mgb_opr, m_cn, m_execution_policy, |
|
|
|
m_allow_weight_preprocess); |
|
|
|
policy.sub_policy.push_back({}); |
|
|
|
sub_ctx.construct_execution_policy(selected_strategy, |
|
|
|
policy.sub_policy.back(), |
|
|
|
retrive_from_cache); |
|
|
|
sub_helper.construct_execution_policy(selected_strategy, |
|
|
|
retrive_from_cache, |
|
|
|
policy.sub_policy.back()); |
|
|
|
if (!policy.sub_policy.back().algo.valid()) { |
|
|
|
// means sub_ctx.construct_execution_policy fails. clean up |
|
|
|
// means sub_helper.construct_execution_policy fails. clean up |
|
|
|
// policy.algo and return |
|
|
|
policy = {}; |
|
|
|
return; |
|
|
|
} |
|
|
|
}); |
|
|
|
MIDOUT_E |
|
|
|
} |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
size_t AlgoChooser<Opr>::ExeContext::get_workspace_size_bytes( |
|
|
|
size_t AlgoChooser<Opr>::AlgoChooserHelper::get_workspace_size_bytes( |
|
|
|
const ImplExecutionPolicy& policy) const { |
|
|
|
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_workspace_size_bytes"))) |
|
|
|
m_megdnn_opr->execution_policy() = policy; |
|
|
|
size_t result; |
|
|
|
if_constexpr<opr_supports_preprocess<Opr>()>( |
|
|
@@ -752,12 +549,40 @@ size_t AlgoChooser<Opr>::ExeContext::get_workspace_size_bytes( |
|
|
|
m_layouts); |
|
|
|
}); |
|
|
|
return result; |
|
|
|
MIDOUT_E |
|
|
|
} |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
std::vector<typename AlgoChooser<Opr>::ImplAlgo> |
|
|
|
AlgoChooser<Opr>::AlgoChooserHelper::get_all_candidates() const { |
|
|
|
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_all_candidates"))) |
|
|
|
auto heu = choose_by_heuristic(m_execution_policy.strategy); |
|
|
|
auto&& ret = |
|
|
|
APPLY(m_megdnn_opr->get_all_algorithms_info(args...), m_layouts); |
|
|
|
bool found = false; |
|
|
|
for (size_t i = 0; i < ret.size(); ++i) { |
|
|
|
if (ret[i].desc == heu.algo) { |
|
|
|
found = true; |
|
|
|
std::swap(ret[i], ret[0]); |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
Algorithm* palgo = m_megdnn_opr->get_algorithm_from_desc(heu.algo); |
|
|
|
mgb_assert(palgo, "Unknown algo description"); |
|
|
|
mgb_assert(found, |
|
|
|
"algo %s got by heuristic not found in " |
|
|
|
"candidate list", |
|
|
|
palgo->name()); |
|
|
|
return std::move(ret); |
|
|
|
MIDOUT_E |
|
|
|
} |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
Maybe<AlgoChooserProfileCache::ResultEntry> |
|
|
|
AlgoChooser<Opr>::ExeContext::profile_single_algo( |
|
|
|
AlgoChooser<Opr>::AlgoChooserHelper::profile_single_algo( |
|
|
|
const ImplExecutionPolicy& policy, double& timeout) const { |
|
|
|
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("profile_single_algo"))) |
|
|
|
typename TimedProfiler<Opr>::Param param; |
|
|
|
// force check copy size <= dest len-1 from gcc8 for safe |
|
|
|
param.execution_policy = |
|
|
@@ -791,14 +616,103 @@ AlgoChooser<Opr>::ExeContext::profile_single_algo( |
|
|
|
if (!rst.valid()) |
|
|
|
return None; |
|
|
|
return AlgoChooserProfileCache::ResultEntry{ |
|
|
|
palgo->name(), |
|
|
|
static_cast<uint32_t>(palgo->attribute()), |
|
|
|
palgo->name(), static_cast<uint32_t>(palgo->attribute()), |
|
|
|
rst.val().time, param.workspace}; |
|
|
|
MIDOUT_E |
|
|
|
} |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
void AlgoChooser<Opr>::AlgoChooserHelper::profile( |
|
|
|
const ExecutionStrategy& selected_strategy) const { |
|
|
|
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("profile"))) |
|
|
|
if (get_profile_result_from_cache(selected_strategy).valid()) |
|
|
|
return; |
|
|
|
AlgoChooserProfileCache::Result prof_rst; |
|
|
|
|
|
|
|
auto target_attr = extract_algo_attribute(selected_strategy); |
|
|
|
std::string layouts_str = |
|
|
|
format_fixlayouts<Opr>(m_layouts, arity_in, arity_out); |
|
|
|
double cur_timeout = 0; |
|
|
|
|
|
|
|
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( |
|
|
|
owner_graph(), m_cn, m_execution_policy.workspace_limit); |
|
|
|
RealTimer timer; |
|
|
|
for (auto algo : get_all_candidates()) { |
|
|
|
Maybe<AlgoChooserProfileCache::ResultEntry> cur_rst; |
|
|
|
|
|
|
|
ImplExecutionPolicy policy; |
|
|
|
policy.algo = algo.desc; |
|
|
|
|
|
|
|
//! check negative attribute : skip negative attribute |
|
|
|
auto palgo = m_megdnn_opr->get_algorithm_from_desc(policy.algo); |
|
|
|
if (palgo->contain_attribute_any(target_attr.second)) { |
|
|
|
mgb_log_debug( |
|
|
|
"skip algo %s, which matches the profile strategy required " |
|
|
|
"'not contain attribute(%s).'", |
|
|
|
algo.desc.name.c_str(), |
|
|
|
Algorithm::attribute_str(target_attr.second).c_str()); |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
//! check workspace limit |
|
|
|
construct_execution_policy(selected_strategy, true, policy); |
|
|
|
if (get_workspace_size_bytes(policy) >= workspace_limit) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
std::string msg = ssprintf("profiling %s algorithm %s %s", |
|
|
|
m_base_mgb_opr->dyn_typeinfo()->name, |
|
|
|
algo.desc.name.c_str(), layouts_str.c_str()); |
|
|
|
timer.reset(); |
|
|
|
MGB_TRY { cur_rst = profile_single_algo(policy, cur_timeout); } |
|
|
|
MGB_CATCH(std::exception & exc, { |
|
|
|
mgb_log_warn("caught exception during %s: %s", msg.c_str(), |
|
|
|
exc.what()); |
|
|
|
continue; |
|
|
|
}) |
|
|
|
MGB_CATCH(..., { |
|
|
|
mgb_log_warn("caught exception during %s", msg.c_str()); |
|
|
|
continue; |
|
|
|
}) |
|
|
|
if (!cur_rst.valid()) { |
|
|
|
mgb_log_warn("timeout when %s; timeout setting: %.3fsec", |
|
|
|
msg.c_str(), cur_timeout); |
|
|
|
continue; |
|
|
|
} |
|
|
|
if (!cur_timeout) { |
|
|
|
cur_timeout = timer.get_secs() + TIMEOUT_TOLERANCE; |
|
|
|
} else { |
|
|
|
cur_timeout = |
|
|
|
std::min(cur_timeout, timer.get_secs() + TIMEOUT_TOLERANCE); |
|
|
|
} |
|
|
|
auto&& rst = cur_rst.val(); |
|
|
|
mgb_log_debug("%s: workspace: %zu; time: %.3gsec", msg.c_str(), |
|
|
|
rst.workspace, rst.time); |
|
|
|
prof_rst.push_back(rst); |
|
|
|
} |
|
|
|
std::string msg = ssprintf( |
|
|
|
"no usable %s algorithm %s without attribute(%s) or could not meet " |
|
|
|
"workspace limite requirement(%zu)", |
|
|
|
m_base_mgb_opr->dyn_typeinfo()->name, layouts_str.c_str(), |
|
|
|
Algorithm::attribute_str(target_attr.second).c_str(), |
|
|
|
workspace_limit); |
|
|
|
mgb_assert(!prof_rst.empty(), "%s", msg.c_str()); |
|
|
|
|
|
|
|
FixedTensorLayouts origin_layouts = m_layouts; |
|
|
|
typename Opr::Param origin_param = m_megdnn_opr->param(); |
|
|
|
AlgoChooserProfileCache::Key cache_key{origin_layouts.data(), |
|
|
|
origin_layouts.size(), &origin_param, |
|
|
|
sizeof(origin_param)}; |
|
|
|
|
|
|
|
AlgoChooserProfileCache cache(m_cn, profile_name(m_megdnn_opr).c_str()); |
|
|
|
cache.put(cache_key, prof_rst); |
|
|
|
MIDOUT_E |
|
|
|
} |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
Maybe<PreprocessFilter<Opr>> |
|
|
|
AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const { |
|
|
|
AlgoChooser<Opr>::AlgoChooserHelper::construct_fake_preprocess_filter() const { |
|
|
|
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("construct_fake_preprocess_filter"))) |
|
|
|
Maybe<PreprocessFilter<Opr>> result = None; |
|
|
|
if_constexpr<opr_supports_preprocess<Opr>()>([&](auto _) { |
|
|
|
if (!m_allow_weight_preprocess) |
|
|
@@ -830,11 +744,12 @@ AlgoChooser<Opr>::ExeContext::construct_fake_preprocess_filter() const { |
|
|
|
} |
|
|
|
}); |
|
|
|
return result; |
|
|
|
MIDOUT_E |
|
|
|
} |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
std::pair<AlgoAttribute, AlgoAttribute> |
|
|
|
AlgoChooser<Opr>::ExeContext::extract_algo_attribute( |
|
|
|
AlgoChooser<Opr>::AlgoChooserHelper::extract_algo_attribute( |
|
|
|
const ExecutionStrategy& strategy) const { |
|
|
|
std::pair<AlgoAttribute, AlgoAttribute> ret = |
|
|
|
std::make_pair(AlgoAttribute::DEFAULT, AlgoAttribute::DEFAULT); |
|
|
@@ -851,41 +766,128 @@ AlgoChooser<Opr>::ExeContext::extract_algo_attribute( |
|
|
|
} |
|
|
|
|
|
|
|
#define INST(Opr) \ |
|
|
|
template AlgoChooser<megdnn::Opr>::ExeContext::ExeContext( \ |
|
|
|
template AlgoChooser<megdnn::Opr>::AlgoChooserHelper::AlgoChooserHelper( \ |
|
|
|
const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \ |
|
|
|
const std::string& param_str, const cg::OperatorNodeBase* mgb_opr, \ |
|
|
|
const CompNode& cn, \ |
|
|
|
const megdnn::param::ExecutionPolicy& execution_policy, \ |
|
|
|
bool allow_weight_preprocess); \ |
|
|
|
template typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \ |
|
|
|
AlgoChooser<megdnn::Opr>::ExeContext::choose_by_heuristic( \ |
|
|
|
ExecutionStrategy select_strategy) const; \ |
|
|
|
AlgoChooser<megdnn::Opr>::AlgoChooserHelper::choose_by_heuristic( \ |
|
|
|
const ExecutionStrategy& select_strategy) const; \ |
|
|
|
template typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \ |
|
|
|
AlgoChooser<megdnn::Opr>::AlgoChooserHelper::choose_by_profile( \ |
|
|
|
const ExecutionStrategy& select_strategy, bool enable_update) \ |
|
|
|
const; \ |
|
|
|
template typename AlgoChooser<megdnn::Opr>::ImplAlgo \ |
|
|
|
AlgoChooser<megdnn::Opr>::ExeContext::get_profile_result_from_cache( \ |
|
|
|
ExecutionStrategy select_strategy) const; \ |
|
|
|
template std::vector<typename AlgoChooser<megdnn::Opr>::ImplAlgo> \ |
|
|
|
AlgoChooser<megdnn::Opr>::ExeContext::get_all_candidates() const; \ |
|
|
|
AlgoChooser<megdnn::Opr>::AlgoChooserHelper:: \ |
|
|
|
get_profile_result_from_cache( \ |
|
|
|
const ExecutionStrategy& select_strategy) const; \ |
|
|
|
template void \ |
|
|
|
AlgoChooser<megdnn::Opr>::AlgoChooserHelper::construct_execution_policy( \ |
|
|
|
const ExecutionStrategy& select_strategy, bool retrive_from_cache, \ |
|
|
|
typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& policy) \ |
|
|
|
const; \ |
|
|
|
template size_t \ |
|
|
|
AlgoChooser<megdnn::Opr>::ExeContext::get_workspace_size_bytes( \ |
|
|
|
AlgoChooser<megdnn::Opr>::AlgoChooserHelper::get_workspace_size_bytes( \ |
|
|
|
const typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& \ |
|
|
|
policy) const; \ |
|
|
|
template void \ |
|
|
|
AlgoChooser<megdnn::Opr>::ExeContext::construct_execution_policy( \ |
|
|
|
ExecutionStrategy select_strategy, \ |
|
|
|
typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& policy, \ |
|
|
|
bool retrive_from_cache) const; \ |
|
|
|
template std::vector<typename AlgoChooser<megdnn::Opr>::ImplAlgo> \ |
|
|
|
AlgoChooser<megdnn::Opr>::AlgoChooserHelper::get_all_candidates() const; \ |
|
|
|
template Maybe<AlgoChooserProfileCache::ResultEntry> \ |
|
|
|
AlgoChooser<megdnn::Opr>::ExeContext::profile_single_algo( \ |
|
|
|
AlgoChooser<megdnn::Opr>::AlgoChooserHelper::profile_single_algo( \ |
|
|
|
const typename AlgoChooser<megdnn::Opr>::ImplExecutionPolicy& \ |
|
|
|
policy, \ |
|
|
|
double& timeout) const; \ |
|
|
|
template std::pair<AlgoAttribute, AlgoAttribute> \ |
|
|
|
AlgoChooser<megdnn::Opr>::ExeContext::extract_algo_attribute( \ |
|
|
|
const ExecutionStrategy& strategy) const; |
|
|
|
AlgoChooser<megdnn::Opr>::AlgoChooserHelper::extract_algo_attribute( \ |
|
|
|
const ExecutionStrategy& strategy) const; \ |
|
|
|
template void AlgoChooser<megdnn::Opr>::AlgoChooserHelper::profile( \ |
|
|
|
const ExecutionStrategy& selected_strategy) const; |
|
|
|
|
|
|
|
MGB_FOREACH_FASTRUN_OPR(INST) |
|
|
|
#undef INST |
|
|
|
|
|
|
|
//////////////////////////////// AlgoChooser /////////////////////////////
|
|
|
template <typename Opr> |
|
|
|
typename AlgoChooser<Opr>::ImplExecutionPolicy AlgoChooser<Opr>::get_policy( |
|
|
|
const AlgoChooserHelper& helper) { |
|
|
|
auto opr_strategy = helper.execution_policy().strategy; |
|
|
|
if (opr_strategy & ExecutionStrategy::HEURISTIC) { |
|
|
|
if (opr_strategy & ExecutionStrategy::PROFILE) { |
|
|
|
//! this strategy will choose from cache first, then choost by |
|
|
|
//! heuristic if fail. |
|
|
|
ImplExecutionPolicy policy = |
|
|
|
helper.choose_by_profile(opr_strategy, false); |
|
|
|
if (!policy.algo.valid()) { |
|
|
|
policy = helper.choose_by_heuristic(opr_strategy); |
|
|
|
} |
|
|
|
return policy; |
|
|
|
} else { |
|
|
|
return helper.choose_by_heuristic(opr_strategy); |
|
|
|
} |
|
|
|
} |
|
|
|
#if MGB_ENABLE_FASTRUN |
|
|
|
else if (opr_strategy & ExecutionStrategy::PROFILE) { |
|
|
|
return helper.choose_by_profile(opr_strategy, true); |
|
|
|
} |
|
|
|
#endif |
|
|
|
else { |
|
|
|
mgb_throw(GraphError, "bad ExecutionPolicy strategy"); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
size_t AlgoChooser<Opr>::setup_algo(const FixedTensorLayouts& layouts, |
|
|
|
Opr* megdnn_opr, const MGBOpr* mgb_opr, |
|
|
|
bool allow_weight_preprocess) { |
|
|
|
if (WorkspaceLimitGetter::is_prealloc_run(mgb_opr->owner_graph())) { |
|
|
|
return 0; |
|
|
|
} |
|
|
|
|
|
|
|
std::string param_str; |
|
|
|
Algorithm::serialize_write_pod(megdnn_opr->param(), param_str); |
|
|
|
AlgoChooserHelper helper(layouts, megdnn_opr, param_str, mgb_opr, |
|
|
|
mgb_opr->comp_node(), mgb_opr->execution_policy(), |
|
|
|
allow_weight_preprocess); |
|
|
|
|
|
|
|
ImplExecutionPolicy policy; |
|
|
|
if (auto algo_choose_hook = mgb_opr->algo_chooser()) { |
|
|
|
policy = algo_choose_hook(mgb_opr); |
|
|
|
auto strategy = |
|
|
|
ExecutionStrategy::HEURISTIC | ExecutionStrategy::REPRODUCIBLE; |
|
|
|
helper.construct_execution_policy(strategy, false, policy); |
|
|
|
} |
|
|
|
if (!policy.algo.valid()) { |
|
|
|
policy = get_policy(helper); |
|
|
|
} |
|
|
|
size_t workspace = helper.get_workspace_size_bytes(policy); |
|
|
|
|
|
|
|
std::string ret; |
|
|
|
ret.append(mgb_opr->dyn_typeinfo()->name); |
|
|
|
ret += format_fixlayouts<Opr>(layouts, arity_in, arity_out); |
|
|
|
Algorithm* palgo = megdnn_opr->get_algorithm_from_desc(policy.algo); |
|
|
|
mgb_assert(palgo, "Unknown algo description"); |
|
|
|
ret.append("): algo=" + std::string(palgo->name())); |
|
|
|
ret.append(ssprintf(" workspace=%.2fMiB attirbute=%d", |
|
|
|
workspace / (1024 * 1024.0), |
|
|
|
static_cast<uint32_t>(palgo->attribute()))); |
|
|
|
mgb_log_debug("%s", ret.c_str()); |
|
|
|
|
|
|
|
megdnn_opr->execution_policy() = policy; |
|
|
|
return workspace; |
|
|
|
} |
|
|
|
|
|
|
|
#define INST(Opr) \ |
|
|
|
template AlgoChooser<megdnn::Opr>::ImplExecutionPolicy \ |
|
|
|
AlgoChooser<megdnn::Opr>::get_policy(const AlgoChooserHelper& proxy); \ |
|
|
|
template size_t AlgoChooser<megdnn::Opr>::setup_algo( \ |
|
|
|
const FixedTensorLayouts& layouts, megdnn::Opr* megdnn_opr, \ |
|
|
|
const MGBOpr* mgb_opr, bool allow_weight_preprocess); |
|
|
|
|
|
|
|
MGB_FOREACH_FASTRUN_OPR(INST) |
|
|
|
#undef INST |
|
|
|
|
|
|
|
} // namespace opr |
|
|
|
} // namespace mgb |
|
|
|
|
|
|
|