|
|
@@ -627,7 +627,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::choose_by_profile( |
|
|
|
} |
|
|
|
|
|
|
|
template <typename Opr> |
|
|
|
typename AlgoChooser<Opr>::ImplAlgoDesc |
|
|
|
std::pair<typename AlgoChooser<Opr>::ImplAlgoDesc, Maybe<AlgoChooserProfileCache::Result>> |
|
|
|
AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache( |
|
|
|
const ExecutionStrategy& selected_strategy) const { |
|
|
|
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_profile_result_from_cache"))) |
|
|
@@ -639,11 +639,11 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache( |
|
|
|
&origin_param, sizeof(origin_param)}; |
|
|
|
auto&& rst = cache.get(cache_key); |
|
|
|
if (!rst.valid()) |
|
|
|
return {}; |
|
|
|
return {{}, rst}; |
|
|
|
|
|
|
|
auto&& prof = rst.val(); |
|
|
|
if (prof.empty()) |
|
|
|
return {}; |
|
|
|
return {{}, rst}; |
|
|
|
|
|
|
|
auto target_attr = extract_algo_attribute(selected_strategy); |
|
|
|
bool skip_by_negative = false; |
|
|
@@ -657,7 +657,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache( |
|
|
|
if (contain_attr_all_positive) { |
|
|
|
if (!contain_attr_any_negative) { |
|
|
|
Algorithm::Info::Desc algo_desc = deserialize_read_pod(i.algo); |
|
|
|
return algo_desc; |
|
|
|
return {algo_desc, rst}; |
|
|
|
} else { |
|
|
|
skip_by_negative = true; |
|
|
|
} |
|
|
@@ -695,7 +695,7 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy( |
|
|
|
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("construct_execution_policy"))) |
|
|
|
if (!policy.algo.valid()) { |
|
|
|
if (retrive_from_cache) { |
|
|
|
policy.algo = get_profile_result_from_cache(selected_strategy); |
|
|
|
policy.algo = get_profile_result_from_cache(selected_strategy).first; |
|
|
|
if (!policy.algo.valid()) { |
|
|
|
if (allow_log) { |
|
|
|
auto target_attr = |
|
|
@@ -886,7 +886,8 @@ template <typename Opr> |
|
|
|
void AlgoChooser<Opr>::AlgoChooserHelper::profile( |
|
|
|
const ExecutionStrategy& selected_strategy) const { |
|
|
|
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("profile"))) |
|
|
|
if (get_profile_result_from_cache(selected_strategy).valid()) |
|
|
|
auto&& rst = get_profile_result_from_cache(selected_strategy); |
|
|
|
if (rst.first.valid()) |
|
|
|
return; |
|
|
|
AlgoChooserProfileCache::Result prof_rst; |
|
|
|
|
|
|
@@ -898,7 +899,20 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile( |
|
|
|
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( |
|
|
|
owner_graph(), m_cn, m_execution_policy.workspace_limit); |
|
|
|
RealTimer timer; |
|
|
|
std::unordered_set<std::string> rst_algos; |
|
|
|
if (rst.second.valid()) { |
|
|
|
std::transform(rst.second.val().begin(), rst.second.val().end(), |
|
|
|
std::inserter(rst_algos, rst_algos.end()), |
|
|
|
[](const AlgoChooserProfileCache::ResultEntry& result) { |
|
|
|
return result.algo; |
|
|
|
}); |
|
|
|
} |
|
|
|
for (auto algo : get_all_candidates()) { |
|
|
|
std::string desc; |
|
|
|
serialize_write_pod(algo.desc, desc); |
|
|
|
if (rst_algos.find(desc) != rst_algos.end()) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
Maybe<AlgoChooserProfileCache::ResultEntry> cur_rst; |
|
|
|
|
|
|
|
ImplExecutionPolicy policy; |
|
|
@@ -960,6 +974,9 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile( |
|
|
|
Algorithm::attribute_str(target_attr.second).c_str(), |
|
|
|
workspace_limit); |
|
|
|
mgb_assert(!prof_rst.empty(), "%s", msg.c_str()); |
|
|
|
if (rst.second.valid()) |
|
|
|
prof_rst.insert(prof_rst.end(), rst.second.val().begin(), |
|
|
|
rst.second.val().end()); |
|
|
|
|
|
|
|
FixedTensorLayouts incache_layouts = m_incache_layouts; |
|
|
|
typename Opr::Param origin_param = m_dnn_opr->param(); |
|
|
@@ -1058,7 +1075,8 @@ AlgoChooser<Opr>::AlgoChooserHelper::extract_algo_attribute( |
|
|
|
AlgoChooser<megdnn::Opr>::AlgoChooserHelper::choose_by_profile( \ |
|
|
|
const ExecutionStrategy& select_strategy, bool enable_update) \ |
|
|
|
const; \ |
|
|
|
template typename AlgoChooser<megdnn::Opr>::ImplAlgoDesc \ |
|
|
|
template std::pair<typename AlgoChooser<megdnn::Opr>::ImplAlgoDesc, \ |
|
|
|
Maybe<AlgoChooserProfileCache::Result>> \ |
|
|
|
AlgoChooser<megdnn::Opr>::AlgoChooserHelper:: \ |
|
|
|
get_profile_result_from_cache( \ |
|
|
|
const ExecutionStrategy& select_strategy) const; \ |
|
|
|