Browse Source

Add unit test for NextIteration

pull/1691/head
zhangxiaokun 4 years ago
parent
commit
81f3382073
1 changed files with 6 additions and 4 deletions
  1. +6
    -4
      tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc

+ 6
- 4
tests/ut/ge/hybrid/executor/subgraph_executor_unittest.cc View File

@@ -127,7 +127,7 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt
AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight); AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, weight);
} }


const auto less1 = CreateNode(graph, "less", EXIT, 2, 1);
const auto less1 = CreateNode(graph, "less", EXIT, 2, 1); // Mock for less, just pass input0.


const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0); const auto active1 = CreateNode(graph, "active1", STREAMACTIVE, 0, 0);
switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0); switch_t = CreateNode(graph, "switch_t", STREAMSWITCH, 2, 0);
@@ -135,13 +135,14 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt
AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_EQUAL); // 101 for true. AttrUtils::SetInt(switch_t->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_EQUAL); // 101 for true.
AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_NOT_EQUAL); AttrUtils::SetInt(switch_f->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_COND, RT_NOT_EQUAL);


const auto add1 = CreateNode(graph, "add", EXIT, 2, 1);
const auto sub1 = CreateNode(graph, "sub", EXIT, 2, 1);
const auto add1 = CreateNode(graph, "add", EXIT, 2, 1); // Mock for add, just pass input0.
const auto sub1 = CreateNode(graph, "sub", EXIT, 2, 1); // Mock for sub, just pass input0.


const auto merge1 = CreateNode(graph, "merge", STREAMMERGE, 2, 2); const auto merge1 = CreateNode(graph, "merge", STREAMMERGE, 2, 2);
const auto active2 = CreateNode(graph, "active2", STREAMACTIVE, 0, 0); const auto active2 = CreateNode(graph, "active2", STREAMACTIVE, 0, 0);
const auto active3 = CreateNode(graph, "active3", STREAMACTIVE, 0, 0); const auto active3 = CreateNode(graph, "active3", STREAMACTIVE, 0, 0);


const auto iteration1 = CreateNode(graph, "iteration1", NEXTITERATION, 1, 1);
const auto output1 = CreateNode(graph, "net_output", NETOUTPUT, 1, 1); const auto output1 = CreateNode(graph, "net_output", NETOUTPUT, 1, 1);
output1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE"); output1->GetOpDesc()->SetOpKernelLibName("DNN_VM_GE_LOCAL_OP_STORE");


@@ -170,7 +171,8 @@ static void CreateSimpleCondGraph(ComputeGraph &graph, NodePtr &switch_t, NodePt
GraphUtils::AddEdge(sub1->GetOutControlAnchor(), active3->GetInControlAnchor()); GraphUtils::AddEdge(sub1->GetOutControlAnchor(), active3->GetInControlAnchor());
GraphUtils::AddEdge(active3->GetOutControlAnchor(), merge1->GetInControlAnchor()); GraphUtils::AddEdge(active3->GetOutControlAnchor(), merge1->GetInControlAnchor());


GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), output1->GetInDataAnchor(0));
GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), iteration1->GetInDataAnchor(0));
GraphUtils::AddEdge(iteration1->GetOutDataAnchor(0), output1->GetInDataAnchor(0));
} }


TEST_F(UtestSubgraphExecutor, simple_schedule_tasks) { TEST_F(UtestSubgraphExecutor, simple_schedule_tasks) {


Loading…
Cancel
Save