diff --git a/ge/graph/passes/mark_agnostic_pass.cc b/ge/graph/passes/mark_agnostic_pass.cc index 77fa64fb..1ddd571f 100644 --- a/ge/graph/passes/mark_agnostic_pass.cc +++ b/ge/graph/passes/mark_agnostic_pass.cc @@ -21,7 +21,7 @@ namespace ge { Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { for (const auto &node : graph->GetDirectNode()) { auto node_type = NodeUtils::GetNodeType(*node); - if (node_type == SWITCH || node_type == REFSWITCH || node_type == SWITCHN) { + if (node_type == SWITCH || node_type == SWITCHN) { GELOGD("Mark format agnostic and continuous for switch node %s", node->GetName().c_str()); const OpDescPtr op_desc = node->GetOpDesc(); const GeTensorDescPtr op_tensor = op_desc->MutableInputDesc(0); @@ -34,13 +34,13 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector({1})); continue; } - if (node_type == IDENTITY) { + if (node_type == IDENTITY || node_type == REFMERGE || node_type == REFSWITCH) { GELOGD("Mark format agnostic for identity node %s", node->GetName().c_str()); AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector({1})); continue; } - if (node_type == MERGE || node_type == REFMERGE) { + if (node_type == MERGE) { GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str()); const OpDescPtr op_desc = node->GetOpDesc(); const GeTensorDescPtr op_tensor = op_desc->MutableOutputDesc(0); diff --git a/ge/graph/preprocess/graph_preprocess.cc b/ge/graph/preprocess/graph_preprocess.cc index f90c0d80..98371426 100644 --- a/ge/graph/preprocess/graph_preprocess.cc +++ b/ge/graph/preprocess/graph_preprocess.cc @@ -117,6 +117,7 @@ #include "graph/passes/variable_op_pass.h" #include "graph/passes/variable_prepare_op_pass.h" #include "graph/passes/variable_ref_delete_op_pass.h" +#include "graph/passes/mark_agnostic_pass.h" namespace ge { @@ -1626,6 +1627,7 @@ Status GraphPrepare::PrepareOptimize() { try { (void)original_graph_passes.AddPass("PrepareOptimize::ShapeOperateOpRemovePass", new ShapeOperateOpRemovePass); (void)original_graph_passes.AddPass("PrepareOptimize::ReplaceTransShapePass", new ReplaceTransShapePass); + (void)original_graph_passes.AddPass("PrepareOptimize::MarkAgnosticPass", new MarkAgnosticPass); } catch (std::bad_alloc &e) { GELOGE(INTERNAL_ERROR, "Add pass failed, bad memory allocation occurs."); return INTERNAL_ERROR;