diff --git a/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc b/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc index fbda3776..bd7df2b7 100644 --- a/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc +++ b/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc @@ -86,7 +86,7 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt * | * Merge * / \. - * / \. + * Active / \ Active * / \. * Add Sub * | \ / | @@ -96,8 +96,8 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt * Switch Switch * | \ / | * | \ / | - * | \ / | - * | \ / | + * | Active | + * | \ / | * | Less | * | / \ | * | / \ | @@ -127,7 +127,7 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); } - const auto less1 = CreateNode(graph, "less", ENTER, 2, 1); + const auto less1 = CreateNode(graph, "less", EXIT, 2, 1); const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0); switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); @@ -135,8 +135,8 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_EQUAL); // 101 for true. AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_NOT_EQUAL); - const auto add1 = CreateNode(graph, "add", ENTER, 2, 1); - const auto sub1 = CreateNode(graph, "sub", ENTER, 2, 1); + const auto add1 = CreateNode(graph, "add", EXIT, 2, 1); + const auto sub1 = CreateNode(graph, "sub", EXIT, 2, 1); const auto merge1 = CreateNode(graph, "merge", STREAMMERGE, 2, 2); const auto active2 = CreateNode(graph, "active2", STREAMACTIVE, 0, 0); diff --git a/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc b/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc index 9630b193..f073f519 100644 --- a/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc +++ b/tests/ut/ge/hybrid/model/hybrid_model_builder_unittest.cc @@ -27,6 +27,7 @@ #include "graph/utils/tensor_utils.h" #include "graph/utils/graph_utils.h" #include "graph/debug/ge_attr_define.h" +#include "graph/common/omg_util.h" using namespace std; using namespace testing; @@ -147,7 +148,7 @@ TEST_F(UtestHybridModelBuilder, normal_hybrid_model_build) { GraphUtils::AddEdge(next1->GetOutControlAnchor(), active3->GetInControlAnchor()); GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); - AttrUtils::SetStr(merge1->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, next1->GetName()); + SetNextIteration(merge1, next1); AttrUtils::SetBool(enter1->GetOpDesc(), ATTR_NAME_INSERT_FP_PROFILILNG_TASK, true); AttrUtils::SetBool(output1->GetOpDesc(), ATTR_NAME_INSERT_BP_PROFILILNG_TASK, true);