diff --git a/src/opr/impl/search_policy/algo_chooser.cpp b/src/opr/impl/search_policy/algo_chooser.cpp index f3695c82..4be461f7 100644 --- a/src/opr/impl/search_policy/algo_chooser.cpp +++ b/src/opr/impl/search_policy/algo_chooser.cpp @@ -627,7 +627,7 @@ AlgoChooser::AlgoChooserHelper::choose_by_profile( } template -typename AlgoChooser::ImplAlgoDesc +std::pair::ImplAlgoDesc, Maybe> AlgoChooser::AlgoChooserHelper::get_profile_result_from_cache( const ExecutionStrategy& selected_strategy) const { MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_profile_result_from_cache"))) @@ -639,11 +639,11 @@ AlgoChooser::AlgoChooserHelper::get_profile_result_from_cache( &origin_param, sizeof(origin_param)}; auto&& rst = cache.get(cache_key); if (!rst.valid()) - return {}; + return {{}, rst}; auto&& prof = rst.val(); if (prof.empty()) - return {}; + return {{}, rst}; auto target_attr = extract_algo_attribute(selected_strategy); bool skip_by_negative = false; @@ -657,7 +657,7 @@ AlgoChooser::AlgoChooserHelper::get_profile_result_from_cache( if (contain_attr_all_positive) { if (!contain_attr_any_negative) { Algorithm::Info::Desc algo_desc = deserialize_read_pod(i.algo); - return algo_desc; + return {algo_desc, rst}; } else { skip_by_negative = true; } @@ -695,7 +695,7 @@ void AlgoChooser::AlgoChooserHelper::construct_execution_policy( MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("construct_execution_policy"))) if (!policy.algo.valid()) { if (retrive_from_cache) { - policy.algo = get_profile_result_from_cache(selected_strategy); + policy.algo = get_profile_result_from_cache(selected_strategy).first; if (!policy.algo.valid()) { if (allow_log) { auto target_attr = @@ -886,7 +886,8 @@ template void AlgoChooser::AlgoChooserHelper::profile( const ExecutionStrategy& selected_strategy) const { MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("profile"))) - if (get_profile_result_from_cache(selected_strategy).valid()) + auto&& rst = get_profile_result_from_cache(selected_strategy); + if (rst.first.valid()) return; AlgoChooserProfileCache::Result prof_rst; @@ -898,7 +899,20 @@ void AlgoChooser::AlgoChooserHelper::profile( auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit( owner_graph(), m_cn, m_execution_policy.workspace_limit); RealTimer timer; + std::unordered_set rst_algos; + if (rst.second.valid()) { + std::transform(rst.second.val().begin(), rst.second.val().end(), + std::inserter(rst_algos, rst_algos.end()), + [](const AlgoChooserProfileCache::ResultEntry& result) { + return result.algo; + }); + } for (auto algo : get_all_candidates()) { + std::string desc; + serialize_write_pod(algo.desc, desc); + if (rst_algos.find(desc) != rst_algos.end()) { + continue; + } Maybe cur_rst; ImplExecutionPolicy policy; @@ -960,6 +974,9 @@ void AlgoChooser::AlgoChooserHelper::profile( Algorithm::attribute_str(target_attr.second).c_str(), workspace_limit); mgb_assert(!prof_rst.empty(), "%s", msg.c_str()); + if (rst.second.valid()) + prof_rst.insert(prof_rst.end(), rst.second.val().begin(), + rst.second.val().end()); FixedTensorLayouts incache_layouts = m_incache_layouts; typename Opr::Param origin_param = m_dnn_opr->param(); @@ -1058,7 +1075,8 @@ AlgoChooser::AlgoChooserHelper::extract_algo_attribute( AlgoChooser::AlgoChooserHelper::choose_by_profile( \ const ExecutionStrategy& select_strategy, bool enable_update) \ const; \ - template typename AlgoChooser::ImplAlgoDesc \ + template std::pair::ImplAlgoDesc, \ + Maybe> \ AlgoChooser::AlgoChooserHelper:: \ get_profile_result_from_cache( \ const ExecutionStrategy& select_strategy) const; \ diff --git a/src/opr/include/megbrain/opr/search_policy/algo_chooser.h b/src/opr/include/megbrain/opr/search_policy/algo_chooser.h index fdde56cb..8f56cbd4 100644 --- a/src/opr/include/megbrain/opr/search_policy/algo_chooser.h +++ b/src/opr/include/megbrain/opr/search_policy/algo_chooser.h @@ -131,7 +131,8 @@ public: bool enable_update) const; //! get all profile algorithm from cache, return invalid if not exists - ImplAlgoDesc get_profile_result_from_cache( + std::pair> + get_profile_result_from_cache( const ExecutionStrategy& selected_strategy) const; /**