Browse Source

!1942 Fix cross merge for switch

Merge pull request !1942 from 张晓昆/r1.5.0
pull/1942/MERGE
i-robot Gitee 3 years ago
parent
commit
641a5f9cd3
1 changed files with 55 additions and 9 deletions
  1. +55
    -9
      ge/graph/passes/mark_force_unknown_for_cond_pass.cc

+ 55
- 9
ge/graph/passes/mark_force_unknown_for_cond_pass.cc View File

@@ -145,17 +145,63 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std:
/// @return
///
void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) {
for (auto it = switch_groups.begin(); it != switch_groups.end(); ++it) {
const auto &op_node = it->first;
const auto &op_desc = op_node->GetOpDesc();
if (op_desc->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) {
continue;
// Step 0: no group assigned. such as:
// Merge1{id=0, group=} => {Switch1{id=1, group=}, Switch2{id=2, group=}}
// Merge2{id=3, group=} => {Switch1{id=1, group=}, Switch3{id=4, group=}}
// Merge3{id=5, group=} => {Switch4{id=6, group=}, Switch5{id=7, group=}}
// Merge4{id=8, group=} => {Switch1{id=1, group=}, Switch5{id=7, group=}}
std::map<int64_t, int64_t> unique_groups;
const auto get_group_index = [&unique_groups](const NodePtr &merge, const std::vector<NodePtr> &switch_group) {
int64_t group_index = merge->GetOpDesc()->GetId();
std::set<int64_t> group_ids{group_index};
for (const auto &node : switch_group) {
if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) {
GELOGI("[%s] Get group from [%s], index[%ld]", merge->GetName().c_str(), node->GetName().c_str(), group_index);
group_ids.insert(group_index);
}
}

const auto it = unique_groups.find(group_index);
if (it != unique_groups.end()) {
group_index = it->second;
}

int64_t group_index = op_desc->GetId();
SetControlFlowGroup(op_node, group_index);
for (const auto &n : it->second) {
SetControlFlowGroup(n, group_index);
for (auto id : group_ids) {
unique_groups[id] = group_index;
}

return group_index;
};

const auto set_group_index = [](const NodePtr &merge, const std::vector<NodePtr> &switch_group, int64_t group_index) {
SetControlFlowGroup(merge, group_index);
for (const auto &node : switch_group) {
SetControlFlowGroup(node, group_index);
}
};

// Step 1: Set group index to merge, if switch already has group, use assigned group.
// Merge1{id=0, group=0} => {Switch1{id=1, group=0}, Switch2{id=2, group=0}}
// Merge2{id=3, group=0} => {Switch1{id=1, group=0}, Switch3{id=4, group=0}}
// Merge3{id=5, group=5} => {Switch4{id=6, group=5}, Switch5{id=7, group=5}}
// Merge4{id=8, group=0} => {Switch1{id=1, group=0}, Switch5{id=7, group=0}}
for (const auto group : switch_groups) {
int64_t group_index = get_group_index(group.first, group.second);
set_group_index(group.first, group.second, group_index);
}

// Step 2: Adjust crossed merge group for unique group.
// Merge1{id=0, group=0} => {Switch1{id=1, group=0}, Switch2{id=2, group=0}}
// Merge2{id=3, group=0} => {Switch1{id=1, group=0}, Switch3{id=4, group=0}}
// Merge3{id=5, group=0} => {Switch4{id=6, group=0}, Switch5{id=7, group=0}}
// Merge4{id=8, group=0} => {Switch1{id=1, group=0}, Switch5{id=7, group=0}}
for (const auto group : switch_groups) {
int64_t group_index = -1;
(void)AttrUtils::GetInt(group.first->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);

const auto it = unique_groups.find(group_index);
if (it != unique_groups.end() && it->first != it->second) {
set_group_index(group.first, group.second, it->second);
}
}
}


Loading…
Cancel
Save