From f902ba243324e4868d2d39caafe76eb0821498dc Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 21 Apr 2022 19:19:32 +0800 Subject: [PATCH] docs(megbrain): add notes for fastrun GitOrigin-RevId: b59f7f205d98e127c6dcaaaedfab556cdf2dba21 --- src/rdnn/impl/algo_chooser.cpp | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/rdnn/impl/algo_chooser.cpp b/src/rdnn/impl/algo_chooser.cpp index 00bcf8b7..c6cd264a 100644 --- a/src/rdnn/impl/algo_chooser.cpp +++ b/src/rdnn/impl/algo_chooser.cpp @@ -565,6 +565,7 @@ typename AlgoChooser::ImplExecutionPolicy AlgoChooser::AlgoChooserHelp choose_by_profile( const ExecutionStrategy& selected_strategy, bool enable_update) const { MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("choose_by_profile"))) + // no_profiling_on_shape_change is usually false, no interface to change it easily if (m_desc.no_profiling_on_shape_change) { auto policy = m_dnn_opr->execution_policy(); if (policy.algo.valid()) { @@ -579,6 +580,8 @@ typename AlgoChooser::ImplExecutionPolicy AlgoChooser::AlgoChooserHelp } } + // if update enabled, do profiling and update cache + // enable_update = false only when using HEURISRIC_PROFILE strategy typename AlgoChooser::ImplExecutionPolicy tmp_policy; bool retrive_from_cache = true; bool allow_log = false; @@ -604,6 +607,8 @@ typename AlgoChooser::ImplExecutionPolicy AlgoChooser::AlgoChooserHelp }); } + // try to retrive algorithm from fastrun cache, this time it's guaranteed to get + // result, retrive_from_cache = true, allow_log = true typename AlgoChooser::ImplExecutionPolicy policy; construct_execution_policy(selected_strategy, policy); return policy; @@ -623,13 +628,16 @@ AlgoChooser::AlgoChooserHelper::get_profile_result_from_cache( m_incache_layouts.data(), m_incache_layouts.size(), &origin_param, sizeof(origin_param)}; auto&& rst = cache.get(cache_key); + // failed to find a cache entry, return if (!rst.valid()) return {{}, rst}; + // found a cache entry(it's a vector of Result), but it's empty auto&& prof = rst.val(); if (prof.empty()) return {{}, rst}; + // found non-empty cache result, filter it by workspace limit and attribute size_t workspace_limit = m_desc.get_workspace_limit(m_cn, m_execution_policy.workspace_limit); auto target_attr = extract_algo_attribute(selected_strategy); @@ -644,6 +652,8 @@ AlgoChooser::AlgoChooserHelper::get_profile_result_from_cache( if (contain_attr_all_positive) { if (!contain_attr_any_negative) { if (i.workspace <= workspace_limit) { + // found a well-suited algothrim with good workspace limit and + // correct attribute Algorithm::Info::Desc algo_desc = deserialize_read_pod(i.algo); return {algo_desc, rst}; } @@ -654,9 +664,11 @@ AlgoChooser::AlgoChooserHelper::get_profile_result_from_cache( } } + // failed to find an algorithm that satisfies the actual workspace limit if (skip_by_workspace) return {}; + // failed to find an algorithm that satisfies the actual attribute std::string layouts_str = AlgoChooser::format_fixlayouts(m_fastrun_layouts); if (skip_by_negative) { mgb_log_error( @@ -685,9 +697,12 @@ void AlgoChooser::AlgoChooserHelper::construct_execution_policy( typename AlgoChooser::ImplExecutionPolicy& policy, bool retrive_from_cache, bool allow_log) const { MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("construct_execution_policy"))) + // policy.algo is always invalid when called from choose_by_profile + // policy.algo will be valid when called from profile if (!policy.algo.valid()) { if (retrive_from_cache) { policy.algo = get_profile_result_from_cache(selected_strategy).first; + // nothing is found even with profiling if (!policy.algo.valid()) { if (allow_log) { auto target_attr = extract_algo_attribute(selected_strategy); @@ -710,6 +725,8 @@ void AlgoChooser::AlgoChooserHelper::construct_execution_policy( return; } } else { + // retrive_from_cache = false happens when using algo choose hook in + // megbrain graph return heuristic algorithm in this case auto workspace_limit = m_desc.get_workspace_limit( m_cn, m_execution_policy.workspace_limit); @@ -727,11 +744,13 @@ void AlgoChooser::AlgoChooserHelper::construct_execution_policy( } } + // construct current algorithm Algorithm* algo = m_dnn_opr->get_algorithm_from_desc(policy.algo); mgb_assert(algo, "Unknown algo description"); std::vector&& sub_items = algo->get_subopr_list(to_layout_array(m_fastrun_layouts), m_dnn_opr); + // construct sub oprs' algorithm FOREACH_OPR_TYPE_DISPATCH(sub_items, { auto&& megdnn_opr = opr::intl::create_megdnn_opr<_Opr>(m_cn); megdnn_opr->param() = @@ -790,6 +809,8 @@ std::vector::ImplAlgo> AlgoChooser< auto heu = choose_by_heuristic(m_execution_policy.strategy); auto&& ret = APPLY(m_dnn_opr->get_all_algorithms_info(args...), m_fastrun_layouts); bool found = false; + // make heuristic algorithm always the first in all candidate alrogrithms + // so profiling step will always run heuristic algorithm first for (size_t i = 0; i < ret.size(); ++i) { if (ret[i].desc == heu.algo) { found = true; @@ -798,6 +819,7 @@ std::vector::ImplAlgo> AlgoChooser< } } + // make sure heuristic algorithm is valid Algorithm* palgo = m_dnn_opr->get_algorithm_from_desc(heu.algo); mgb_assert(palgo, "Unknown algo description"); mgb_assert( @@ -813,6 +835,7 @@ template Maybe AlgoChooser::AlgoChooserHelper:: profile_single_algo(const ImplExecutionPolicy& policy, double& timeout) const { MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("profile_single_algo"))) + // fill TimedProfiler::param and run actual timed profiler typename TimedProfiler::Param param; // force check copy size <= dest len-1 from gcc8 for safe param.execution_policy = @@ -867,7 +890,11 @@ template void AlgoChooser::AlgoChooserHelper::profile( const ExecutionStrategy& selected_strategy) const { MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("profile"))) + // some sub oprs have beed profiled before + // sub oprs won't be checked at the beginning of choose_by_profile auto&& rst = get_profile_result_from_cache(selected_strategy); + // rst.first.valid means there exists valid algorithms for current opr, just return + // otherwise need to profile if (rst.first.valid()) return; AlgoChooserProfileCache::Result prof_rst; @@ -957,6 +984,7 @@ void AlgoChooser::AlgoChooserHelper::profile( Algorithm::attribute_str(target_attr.second).c_str(), workspace_limit); mgb_assert(!prof_rst.empty(), "%s", msg.c_str()); + // append some previous profiled results if (rst.second.valid()) prof_rst.insert( prof_rst.end(), rst.second.val().begin(), rst.second.val().end());