|
|
@@ -299,18 +299,13 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, |
|
|
|
RealTimer timer; |
|
|
|
for (auto algo : ctx.get_all_candidates()) { |
|
|
|
Maybe<AlgoChooserProfileCache::ResultEntry> cur_rst; |
|
|
|
std::string msg = ssprintf("profiling %s algorithm %s %s", |
|
|
|
ctx.mgb_opr()->dyn_typeinfo()->name, |
|
|
|
algo.desc.name.c_str(), layouts_str.c_str()); |
|
|
|
|
|
|
|
ImplExecutionPolicy policy; |
|
|
|
policy.algo = algo.desc; |
|
|
|
ctx.construct_execution_policy(selected_strategy, policy); |
|
|
|
if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
//! check negative attribute : skip negative attribute |
|
|
|
auto palgo = ctx.megdnn_opr()->get_algorithm_from_desc(policy.algo); |
|
|
|
if (!(palgo->contain_attribute_all(target_attr.first) && |
|
|
|
!palgo->contain_attribute_any(target_attr.second))) { |
|
|
|
if (palgo->contain_attribute_any(target_attr.second)) { |
|
|
|
mgb_log_debug( |
|
|
|
"skip algo %s with attribute(%s), which is not match the " |
|
|
|
"profile strategy required contain attribute(%s) and not " |
|
|
@@ -322,6 +317,15 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
//! check workspace limit |
|
|
|
ctx.construct_execution_policy(selected_strategy, policy); |
|
|
|
if (ctx.get_workspace_size_bytes(policy) >= workspace_limit) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
std::string msg = ssprintf("profiling %s algorithm %s %s", |
|
|
|
ctx.mgb_opr()->dyn_typeinfo()->name, |
|
|
|
algo.desc.name.c_str(), layouts_str.c_str()); |
|
|
|
timer.reset(); |
|
|
|
MGB_TRY { cur_rst = ctx.profile_single_algo(policy, cur_timeout); } |
|
|
|
MGB_CATCH(std::exception & exc, { |
|
|
@@ -350,11 +354,11 @@ void AlgoChooser<Opr>::profile(ExeContext& ctx, |
|
|
|
prof_rst.push_back(rst); |
|
|
|
} |
|
|
|
std::string msg = ssprintf( |
|
|
|
"no usable %s algorithm %s with attribute(%s) and without " |
|
|
|
"attribute(%s)", |
|
|
|
"no usable %s algorithm %s without attribute(%s) or could not meet " |
|
|
|
"workspace limite requirement(%zu)", |
|
|
|
ctx.mgb_opr()->dyn_typeinfo()->name, layouts_str.c_str(), |
|
|
|
Algorithm::attribute_str(target_attr.first).c_str(), |
|
|
|
Algorithm::attribute_str(target_attr.second).c_str()); |
|
|
|
Algorithm::attribute_str(target_attr.second).c_str(), |
|
|
|
workspace_limit); |
|
|
|
mgb_assert(!prof_rst.empty(), "%s", msg.c_str()); |
|
|
|
|
|
|
|
FixedTensorLayouts origin_layouts = ctx.layouts(); |
|
|
@@ -544,6 +548,7 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( |
|
|
|
return {}; |
|
|
|
|
|
|
|
auto target_attr = extract_algo_attribute(selected_strategy); |
|
|
|
bool skip_by_negative = false; |
|
|
|
for (auto&& i : prof) { |
|
|
|
auto attr_of_algo = |
|
|
|
static_cast<megdnn::Algorithm::Attribute>(i.attribute); |
|
|
@@ -552,26 +557,36 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache( |
|
|
|
(attr_of_algo & target_attr.first)); |
|
|
|
bool contain_attr_any_negative = |
|
|
|
static_cast<bool>(attr_of_algo & target_attr.second); |
|
|
|
if (contain_attr_all_positive && !contain_attr_any_negative) { |
|
|
|
auto iter = algo_map.find(i.algo); |
|
|
|
mgb_assert(iter != algo_map.end(), |
|
|
|
"algorithm %s exists in " |
|
|
|
"profiling result but not in algo_map; please " |
|
|
|
"report this " |
|
|
|
"bug; opr: %s{%s}, layouts: %s ", |
|
|
|
i.algo.c_str(), m_base_mgb_opr->cname(), |
|
|
|
m_base_mgb_opr->dyn_typeinfo()->name, |
|
|
|
format_fixlayouts<Opr>(m_layouts, arity_in, arity_out) |
|
|
|
.c_str()); |
|
|
|
return iter->second; |
|
|
|
if (contain_attr_all_positive) { |
|
|
|
if (!contain_attr_any_negative) { |
|
|
|
auto iter = algo_map.find(i.algo); |
|
|
|
mgb_assert( |
|
|
|
iter != algo_map.end(), |
|
|
|
"algorithm %s exists in profiling result but not in " |
|
|
|
"algo_map; please report this bug; opr: %s{%s}, " |
|
|
|
"layouts: %s ", |
|
|
|
i.algo.c_str(), m_base_mgb_opr->cname(), |
|
|
|
m_base_mgb_opr->dyn_typeinfo()->name, |
|
|
|
format_fixlayouts<Opr>(m_layouts, arity_in, arity_out) |
|
|
|
.c_str()); |
|
|
|
return iter->second; |
|
|
|
} else { |
|
|
|
skip_by_negative = true; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
mgb_log_error( |
|
|
|
"algos read from cache could not satisfy attribute with %s and " |
|
|
|
"without %s", |
|
|
|
Algorithm::attribute_str(target_attr.first).c_str(), |
|
|
|
Algorithm::attribute_str(target_attr.second).c_str()); |
|
|
|
if (skip_by_negative) { |
|
|
|
mgb_log_error( |
|
|
|
"No usable algo. Only navie algos are available, but negative " |
|
|
|
"stategy is %s.", |
|
|
|
Algorithm::attribute_str(target_attr.second).c_str()); |
|
|
|
} else { |
|
|
|
mgb_log_error( |
|
|
|
"No usable algo. algos read from cache could not satisfy " |
|
|
|
"attribute with %s", |
|
|
|
Algorithm::attribute_str(target_attr.first).c_str()); |
|
|
|
} |
|
|
|
|
|
|
|
mgb_trap(); |
|
|
|
MIDOUT_E |
|
|
|