From 80cf7df10c234cf6d32699c680f09f0d4734d049 Mon Sep 17 00:00:00 2001 From: zhou_chao1993 Date: Sat, 28 Nov 2020 16:00:07 +0800 Subject: [PATCH] modify agnostic --- ge/graph/passes/mark_agnostic_pass.cc | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/ge/graph/passes/mark_agnostic_pass.cc b/ge/graph/passes/mark_agnostic_pass.cc index 4fdc8e1b..be8d132a 100644 --- a/ge/graph/passes/mark_agnostic_pass.cc +++ b/ge/graph/passes/mark_agnostic_pass.cc @@ -43,18 +43,17 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { if (node_type == REFMERGE || node_type == REFSWITCH) { GELOGD("Mark format agnostic for regmerge and refswitch node %s", node->GetName().c_str()); AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); - AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector({1})); + AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_output", std::vector({1})); continue; } if (node_type == MERGE) { GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str()); - auto in_nodes = node->GetInAllNodes(); - vector input_nodes(in_nodes.begin(), in_nodes.end()); + const auto &input_nodes = node->GetInAllNodes(); /// Enter-----------+ /// +-> Merge /// NextIteration---+ if (input_nodes.size() == 2) { - if (input_nodes[0]->GetType() == ENTER && input_nodes[1]->GetType() == NEXTITERATION) { + if (input_nodes.at(0)->GetType() == ENTER && input_nodes.at(1)->GetType() == NEXTITERATION) { continue; } } @@ -65,6 +64,14 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { continue; } AttrUtils::SetInt(op_tensor, "_format_continuous", 1); + + // Merge----------->NetOutput only set format_cofntinuous attr + const auto &output_nodes = node->GetOutAllNodes(); + if (output_nodes.size() > 0) { + if (output_nodes.at(0)->GetType() == NETOUTPUT) { + continue; + } + } AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_output", std::vector({1})); continue;