From 36a90e0313922185b27da9eb792ab091e422540b Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Sat, 3 Jul 2021 19:10:51 +0800 Subject: [PATCH 1/2] Fix cross merge for switch --- .../passes/mark_force_unknown_for_cond_pass.cc | 64 +++++++++++++++++++--- 1 file changed, 55 insertions(+), 9 deletions(-) diff --git a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc index e024217f..eeb0747f 100644 --- a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc +++ b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc @@ -145,17 +145,63 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: /// @return /// void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map> &switch_groups) { - for (auto it = switch_groups.begin(); it != switch_groups.end(); ++it) { - const auto &op_node = it->first; - const auto &op_desc = op_node->GetOpDesc(); - if (op_desc->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { - continue; + // Step 0: no group assigned. such as: + // Merge1{id=0, group=} => {Switch1{id=1, group=}, Switch2{id=2, group=}} + // Merge2{id=3, group=} => {Switch1{id=1, group=}, Switch3{id=4, group=}} + // Merge3{id=5, group=} => {Switch4{id=6, group=}, Switch5{id=7, group=}} + // Merge4{id=8, group=} => {Switch1{id=1, group=}, Switch5{id=7, group=}} + std::map unique_groups; + const auto GetGroupIndex = [&unique_groups](const NodePtr &merge, const std::vector &switch_group) { + int64_t group_index = merge->GetOpDesc()->GetId(); + std::set group_ids{group_index}; + for (const auto &node : switch_group) { + if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { + GELOGI("[%s] Get group from [%s], index[%ld]", merge->GetName().c_str(), node->GetName().c_str(), group_index); + group_ids.insert(group_index); + } + } + + const auto it = unique_groups.find(group_index); + if (it != unique_groups.end()) { + group_index = it->second; } - int64_t group_index = op_desc->GetId(); - SetControlFlowGroup(op_node, group_index); - for (const auto &n : it->second) { - SetControlFlowGroup(n, group_index); + for (auto id : group_ids) { + unique_groups[id] = group_index; + } + + return group_index; + }; + + const auto SetGroupIndex = [](const NodePtr &merge, const std::vector &switch_group, int64_t group_index) { + SetControlFlowGroup(merge, group_index); + for (const auto &node : switch_group) { + SetControlFlowGroup(node, group_index); + } + }; + + // Step 1: Set group index to merge, if switch already has group, use assigned group. + // Merge1{id=0, group=0} => {Switch1{id=1, group=0}, Switch2{id=2, group=0}} + // Merge2{id=3, group=0} => {Switch1{id=1, group=0}, Switch3{id=4, group=0}} + // Merge3{id=5, group=5} => {Switch4{id=6, group=5}, Switch5{id=7, group=5}} + // Merge4{id=8, group=0} => {Switch1{id=1, group=0}, Switch5{id=7, group=0}} + for (const auto group : switch_groups) { + int64_t group_index = GetGroupIndex(group.first, group.second); + SetGroupIndex(group.first, group.second, group_index); + } + + // Step 2: Adjust crossed merge group for unique group. + // Merge1{id=0, group=0} => {Switch1{id=1, group=0}, Switch2{id=2, group=0}} + // Merge2{id=3, group=0} => {Switch1{id=1, group=0}, Switch3{id=4, group=0}} + // Merge3{id=5, group=0} => {Switch4{id=6, group=0}, Switch5{id=7, group=0}} + // Merge4{id=8, group=0} => {Switch1{id=1, group=0}, Switch5{id=7, group=0}} + for (const auto group : switch_groups) { + int64_t group_index = -1; + (void)AttrUtils::GetInt(group.first->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); + + const auto it = unique_groups.find(group_index); + if (it != unique_groups.end() && it->first != it->second) { + SetGroupIndex(group.first, group.second, it->second); } } } From 6b0dacc00f1647aa85427f8522e7ca62ea35a7bc Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Mon, 5 Jul 2021 08:41:54 +0800 Subject: [PATCH 2/2] Fix lambda expression name --- ge/graph/passes/mark_force_unknown_for_cond_pass.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc index eeb0747f..f864d7d2 100644 --- a/ge/graph/passes/mark_force_unknown_for_cond_pass.cc +++ b/ge/graph/passes/mark_force_unknown_for_cond_pass.cc @@ -151,7 +151,7 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map {Switch4{id=6, group=}, Switch5{id=7, group=}} // Merge4{id=8, group=} => {Switch1{id=1, group=}, Switch5{id=7, group=}} std::map unique_groups; - const auto GetGroupIndex = [&unique_groups](const NodePtr &merge, const std::vector &switch_group) { + const auto get_group_index = [&unique_groups](const NodePtr &merge, const std::vector &switch_group) { int64_t group_index = merge->GetOpDesc()->GetId(); std::set group_ids{group_index}; for (const auto &node : switch_group) { @@ -173,7 +173,7 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map &switch_group, int64_t group_index) { + const auto set_group_index = [](const NodePtr &merge, const std::vector &switch_group, int64_t group_index) { SetControlFlowGroup(merge, group_index); for (const auto &node : switch_group) { SetControlFlowGroup(node, group_index); @@ -186,8 +186,8 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map {Switch4{id=6, group=5}, Switch5{id=7, group=5}} // Merge4{id=8, group=0} => {Switch1{id=1, group=0}, Switch5{id=7, group=0}} for (const auto group : switch_groups) { - int64_t group_index = GetGroupIndex(group.first, group.second); - SetGroupIndex(group.first, group.second, group_index); + int64_t group_index = get_group_index(group.first, group.second); + set_group_index(group.first, group.second, group_index); } // Step 2: Adjust crossed merge group for unique group. @@ -201,7 +201,7 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::mapfirst != it->second) { - SetGroupIndex(group.first, group.second, it->second); + set_group_index(group.first, group.second, it->second); } } }