From 93152dfa1447f4ba1de4a58202aed1fedf6afd9d Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 18 Oct 2021 14:17:52 +0800 Subject: [PATCH] fix(mgb/gopt): fix global layout transform deduplicate the states of the DP problem due to different layout config(NCHW44 & NCHW44_HYBRID) will produce tensors with same layout GitOrigin-RevId: 7f77efd21b5014b4bbddce834a290efebdbdacb2 --- .../dynamic_programming_solver.cpp | 31 +++++++++++++++------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/src/gopt/impl/global_layout_transform/dynamic_programming_solver.cpp b/src/gopt/impl/global_layout_transform/dynamic_programming_solver.cpp index 7f480d7d..5c0a17ea 100644 --- a/src/gopt/impl/global_layout_transform/dynamic_programming_solver.cpp +++ b/src/gopt/impl/global_layout_transform/dynamic_programming_solver.cpp @@ -337,6 +337,24 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve( SmallVector cuts; size_t cur = 0; + /* \notes: In the layout selection problem, different operator layout configurations + * will produce tensors with same layout (i.e. same state in the DP problem). This + * means we should only keep the smallest time of all the same states but produced + * by different layout configs. Here we use the following helper function to + * deduplicate the states + */ + auto add_state = [](StateTable& states, const State& state, const Value& value) { + auto iter = states.find(state); + if (iter == states.end()) { + states[state] = value; + } else { + float time = iter->second.time; + if (value.time < time) { + iter->second = value; + } + } + }; + /// initialize states auto init = [&, this](OperatorNodeBase* opr) { auto it = rst.opr_record.find(opr); @@ -388,7 +406,7 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve( ivar_time += min_time; } value.time = opr_time + ivar_time + ovar_time; - states[state] = value; + add_state(states, state, value); } }; @@ -455,15 +473,7 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve( ivar_time += min_time; } value.time = prev_time + opr_time + ivar_time + ovar_time; - auto iter = states.find(state); - if (iter == states.end()) { - states[state] = value; - } else { - float time = iter->second.time; - if (value.time < time) { - iter->second = value; - } - } + add_state(states, state, value); } } cuts.emplace_back(Cut{}); @@ -481,6 +491,7 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve( auto& states = cuts.back().states; prune(states, edges[cur], ctx); force_prune(states); + } cur++; }