diff --git a/ge/graph/passes/mark_agnostic_pass.cc b/ge/graph/passes/mark_agnostic_pass.cc index 00a3dad9..4fdc8e1b 100644 --- a/ge/graph/passes/mark_agnostic_pass.cc +++ b/ge/graph/passes/mark_agnostic_pass.cc @@ -16,6 +16,7 @@ #include "graph/passes/mark_agnostic_pass.h" #include "graph/utils/node_utils.h" +#include "graph/utils/tensor_utils.h" namespace ge { Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { @@ -47,6 +48,16 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { } if (node_type == MERGE) { GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str()); + auto in_nodes = node->GetInAllNodes(); + vector input_nodes(in_nodes.begin(), in_nodes.end()); + /// Enter-----------+ + /// +-> Merge + /// NextIteration---+ + if (input_nodes.size() == 2) { + if (input_nodes[0]->GetType() == ENTER && input_nodes[1]->GetType() == NEXTITERATION) { + continue; + } + } const OpDescPtr op_desc = node->GetOpDesc(); const GeTensorDescPtr op_tensor = op_desc->MutableOutputDesc(0); if (op_tensor == nullptr) { diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index 58718c9f..b899ee83 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -117,6 +117,7 @@ #include "graph/passes/variable_op_pass.h" #include "graph/passes/variable_prepare_op_pass.h" #include "graph/passes/variable_ref_delete_op_pass.h" +#include "graph/passes/mark_agnostic_pass.h" namespace ge { @@ -1700,6 +1701,7 @@ Status GraphPrepare::PrepareOptimize() { try { (void)original_graph_passes.AddPass("PrepareOptimize::ShapeOperateOpRemovePass", new ShapeOperateOpRemovePass); (void)original_graph_passes.AddPass("PrepareOptimize::ReplaceTransShapePass", new ReplaceTransShapePass); + (void)original_graph_passes.AddPass("PrepareOptimize::MarkAgnosticPass" , new MarkAgnosticPass); } catch (std::bad_alloc &e) { GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); return INTERNAL_ERROR;