Browse Source

feat(mgb): fastrun algo profile deduplication

GitOrigin-RevId: 0d1bed781d
release-1.6
Megvii Engine Team 3 years ago
parent
commit
567586a037
2 changed files with 27 additions and 8 deletions
  1. +25
    -7
      src/opr/impl/search_policy/algo_chooser.cpp
  2. +2
    -1
      src/opr/include/megbrain/opr/search_policy/algo_chooser.h

+ 25
- 7
src/opr/impl/search_policy/algo_chooser.cpp View File

@@ -627,7 +627,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::choose_by_profile(
}

template <typename Opr>
typename AlgoChooser<Opr>::ImplAlgoDesc
std::pair<typename AlgoChooser<Opr>::ImplAlgoDesc, Maybe<AlgoChooserProfileCache::Result>>
AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache(
const ExecutionStrategy& selected_strategy) const {
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("get_profile_result_from_cache")))
@@ -639,11 +639,11 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache(
&origin_param, sizeof(origin_param)};
auto&& rst = cache.get(cache_key);
if (!rst.valid())
return {};
return {{}, rst};

auto&& prof = rst.val();
if (prof.empty())
return {};
return {{}, rst};

auto target_attr = extract_algo_attribute(selected_strategy);
bool skip_by_negative = false;
@@ -657,7 +657,7 @@ AlgoChooser<Opr>::AlgoChooserHelper::get_profile_result_from_cache(
if (contain_attr_all_positive) {
if (!contain_attr_any_negative) {
Algorithm::Info::Desc algo_desc = deserialize_read_pod(i.algo);
return algo_desc;
return {algo_desc, rst};
} else {
skip_by_negative = true;
}
@@ -695,7 +695,7 @@ void AlgoChooser<Opr>::AlgoChooserHelper::construct_execution_policy(
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("construct_execution_policy")))
if (!policy.algo.valid()) {
if (retrive_from_cache) {
policy.algo = get_profile_result_from_cache(selected_strategy);
policy.algo = get_profile_result_from_cache(selected_strategy).first;
if (!policy.algo.valid()) {
if (allow_log) {
auto target_attr =
@@ -886,7 +886,8 @@ template <typename Opr>
void AlgoChooser<Opr>::AlgoChooserHelper::profile(
const ExecutionStrategy& selected_strategy) const {
MIDOUT_B(Opr, midout_iv(MGB_HASH_STR("profile")))
if (get_profile_result_from_cache(selected_strategy).valid())
auto&& rst = get_profile_result_from_cache(selected_strategy);
if (rst.first.valid())
return;
AlgoChooserProfileCache::Result prof_rst;

@@ -898,7 +899,20 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile(
auto workspace_limit = WorkspaceLimitGetter::get_workspace_limit(
owner_graph(), m_cn, m_execution_policy.workspace_limit);
RealTimer timer;
std::unordered_set<std::string> rst_algos;
if (rst.second.valid()) {
std::transform(rst.second.val().begin(), rst.second.val().end(),
std::inserter(rst_algos, rst_algos.end()),
[](const AlgoChooserProfileCache::ResultEntry& result) {
return result.algo;
});
}
for (auto algo : get_all_candidates()) {
std::string desc;
serialize_write_pod(algo.desc, desc);
if (rst_algos.find(desc) != rst_algos.end()) {
continue;
}
Maybe<AlgoChooserProfileCache::ResultEntry> cur_rst;

ImplExecutionPolicy policy;
@@ -960,6 +974,9 @@ void AlgoChooser<Opr>::AlgoChooserHelper::profile(
Algorithm::attribute_str(target_attr.second).c_str(),
workspace_limit);
mgb_assert(!prof_rst.empty(), "%s", msg.c_str());
if (rst.second.valid())
prof_rst.insert(prof_rst.end(), rst.second.val().begin(),
rst.second.val().end());

FixedTensorLayouts incache_layouts = m_incache_layouts;
typename Opr::Param origin_param = m_dnn_opr->param();
@@ -1058,7 +1075,8 @@ AlgoChooser<Opr>::AlgoChooserHelper::extract_algo_attribute(
AlgoChooser<megdnn::Opr>::AlgoChooserHelper::choose_by_profile( \
const ExecutionStrategy& select_strategy, bool enable_update) \
const; \
template typename AlgoChooser<megdnn::Opr>::ImplAlgoDesc \
template std::pair<typename AlgoChooser<megdnn::Opr>::ImplAlgoDesc, \
Maybe<AlgoChooserProfileCache::Result>> \
AlgoChooser<megdnn::Opr>::AlgoChooserHelper:: \
get_profile_result_from_cache( \
const ExecutionStrategy& select_strategy) const; \


+ 2
- 1
src/opr/include/megbrain/opr/search_policy/algo_chooser.h View File

@@ -131,7 +131,8 @@ public:
bool enable_update) const;

//! get all profile algorithm from cache, return invalid if not exists
ImplAlgoDesc get_profile_result_from_cache(
std::pair<ImplAlgoDesc, Maybe<AlgoChooserProfileCache::Result>>
get_profile_result_from_cache(
const ExecutionStrategy& selected_strategy) const;

/**


Loading…
Cancel
Save