From 81f33820733b1424d6959c9b45e7d8141daee17c Mon Sep 17 00:00:00 2001 From: zhangxiaokun Date: Wed, 26 May 2021 11:16:08 +0800 Subject: [PATCH] Add unit test for NextIteration --- tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc b/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc index fcbddbb1..d97629cf 100644 --- a/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc +++ b/tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc @@ -127,7 +127,7 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); } - const auto less1 = CreateNode(graph, "less", EXIT, 2, 1); + const auto less1 = CreateNode(graph, "less", EXIT, 2, 1); // Mock for less, just pass input0. const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0); switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); @@ -135,13 +135,14 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_EQUAL); // 101 for true. AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_NOT_EQUAL); - const auto add1 = CreateNode(graph, "add", EXIT, 2, 1); - const auto sub1 = CreateNode(graph, "sub", EXIT, 2, 1); + const auto add1 = CreateNode(graph, "add", EXIT, 2, 1); // Mock for add, just pass input0. + const auto sub1 = CreateNode(graph, "sub", EXIT, 2, 1); // Mock for sub, just pass input0. const auto merge1 = CreateNode(graph, "merge", STREAMMERGE, 2, 2); const auto active2 = CreateNode(graph, "active2", STREAMACTIVE, 0, 0); const auto active3 = CreateNode(graph, "active3", STREAMACTIVE, 0, 0); + const auto iteration1 = CreateNode(graph, "iteration1", NEXTITERATION, 1, 1); const auto output1 = CreateNode(graph, "net_output", NETOUTPUT, 1, 1); output1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); @@ -170,7 +171,8 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt GraphUtils::AddEdge(sub1->GetOutControlAnchor(), active3->GetInControlAnchor()); GraphUtils::AddEdge(active3->GetOutControlAnchor(), merge1->GetInControlAnchor()); - GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); + GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), iteration1->GetInDataAnchor(0)); + GraphUtils::AddEdge(iteration1->GetOutDataAnchor(0), output1->GetInDataAnchor(0)); } TEST_F(UtestSubgraphExecutor, simple_schedule_tasks) {