diff --git a/ge/graph/passes/mark_agnostic_pass.cc b/ge/graph/passes/mark_agnostic_pass.cc index 4fdc8e1b..be8d132a 100644 --- a/ge/graph/passes/mark_agnostic_pass.cc +++ b/ge/graph/passes/mark_agnostic_pass.cc @@ -43,18 +43,17 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { if (node_type == REFMERGE || node_type == REFSWITCH) { GELOGD("Mark format agnostic for regmerge and refswitch node %s", node->GetName().c_str()); AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); - AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector({1})); + AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_output", std::vector({1})); continue; } if (node_type == MERGE) { GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str()); - auto in_nodes = node->GetInAllNodes(); - vector input_nodes(in_nodes.begin(), in_nodes.end()); + const auto &input_nodes = node->GetInAllNodes(); /// Enter-----------+ /// +-> Merge /// NextIteration---+ if (input_nodes.size() == 2) { - if (input_nodes[0]->GetType() == ENTER && input_nodes[1]->GetType() == NEXTITERATION) { + if (input_nodes.at(0)->GetType() == ENTER && input_nodes.at(1)->GetType() == NEXTITERATION) { continue; } } @@ -65,6 +64,14 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { continue; } AttrUtils::SetInt(op_tensor, "_format_continuous", 1); + + // Merge----------->NetOutput only set format_cofntinuous attr + const auto &output_nodes = node->GetOutAllNodes(); + if (output_nodes.size() > 0) { + if (output_nodes.at(0)->GetType() == NETOUTPUT) { + continue; + } + } AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_output", std::vector({1})); continue;