Browse Source

fix(mgb): fix attribute uncomplete filter when get_profile_result_from_cache in fast run

GitOrigin-RevId: 599437ad72
release-1.4
Megvii Engine Team 4 years ago
parent
commit
00fa02a9b8
2 changed files with 17 additions and 11 deletions
  1. +16
    -9
      src/opr/impl/search_policy/algo_chooser.cpp
  2. +1
    -2
      src/opr/test/dnn/convolution.cpp

+ 16
- 9
src/opr/impl/search_policy/algo_chooser.cpp View File

@@ -557,10 +557,18 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(

if (prof.empty())
return {};

auto attr_from_strategy =
extract_algo_attribute_from_execution_strategy(selected_strategy);
for (auto&& i : prof) {
if (!(selected_strategy & ExecutionStrategy::REPRODUCIBLE) ||
static_cast<AlgoAttribute>(i.attribute) &
AlgoAttribute::REPRODUCIBLE) {
auto attr_of_algo =
static_cast<megdnn::Algorithm::Attribute>(i.attribute);
bool contain_attr_all_positive =
(attr_from_strategy.first ==
(attr_of_algo & attr_from_strategy.first));
bool contain_attr_any_negative =
static_cast<bool>(attr_of_algo & attr_from_strategy.second);
if (contain_attr_all_positive && !contain_attr_any_negative) {
auto iter = algo_map.find(i.algo);
mgb_assert(iter != algo_map.end(),
"algorithm %s exists in "
@@ -576,12 +584,11 @@ AlgoChooser<Opr>::ExeContext::get_profile_result_from_cache(
}

mgb_log_error(
"Workspace requirement (%zu) could not be satisfied. Abort now "
"to "
"avoid further problems",
WorkspaceLimitGetter::get_workspace_limit(
m_base_mgb_opr->owner_graph(), m_cn,
m_execution_policy.workspace_limit));
"algos read from cache could not satisfy attribute with %s and "
"without %s",
Algorithm::attribute_str(attr_from_strategy.first).c_str(),
Algorithm::attribute_str(attr_from_strategy.second).c_str());

mgb_trap();
MIDOUT_E
}


+ 1
- 2
src/opr/test/dnn/convolution.cpp View File

@@ -364,7 +364,6 @@ TEST(TestOprDNN, ConvBiasExePolicy) {
for (auto strategy :
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) {
#endif

auto graph = ComputingGraph::make();
HostTensorGenerator<> gen;

@@ -454,7 +453,7 @@ TEST(TestOprDNN, ConvolutionExePolicy) {
#if MGB_ENABLE_FASTRUN
for (auto strategy :
SmallVector<S>{S::PROFILE, S::HEURISTIC, S::PROFILE | S::REPRODUCIBLE,
S::PROFILE | S::HEURISTIC, S::PROFILE | S::OPTIMIZED}) {
S::PROFILE | S::HEURISTIC}) {
#else
for (auto strategy :
SmallVector<S>{S : HEURISTIC, S::PROFILE | S::HEURISTIC}) {


Loading…
Cancel
Save