You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

infershape_pass_unittest.cc 18 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <gtest/gtest.h>
  17. #include <operator_factory_impl.h>
  18. #define protected public
  19. #define private public
  20. #include "graph/passes/infershape_pass.h"
  21. #include "graph/utils/tensor_utils.h"
  22. #include "graph/utils/graph_utils.h"
  23. #include "graph_builder_utils.h"
  24. #include "inc/external/graph/operator_reg.h"
  25. #include "inc/external/graph/operator.h"
  26. #include "inc/external/graph/operator_factory.h"
  27. #include "inc/graph/operator_factory_impl.h"
  28. using namespace std;
  29. using namespace testing;
  30. namespace ge {
  31. class UtestGraphInfershapePass : public testing::Test {
  32. protected:
  33. void SetUp() {}
  34. void TearDown() {}
  35. };
  36. /*
  37. * data1 const1
  38. * \ /
  39. * case1
  40. * |
  41. * relu10
  42. * |
  43. * netoutput
  44. */
  45. ut::GraphBuilder ParentGraphBuilder() {
  46. ut::GraphBuilder builder = ut::GraphBuilder("g1");
  47. auto data1 = builder.AddNode("data1", "Data", 0, 1);
  48. std::vector<int64_t> const_shape = {1};
  49. auto const1 = builder.AddNode("const1", "Const", 0, 1, FORMAT_NCHW, DT_INT32, const_shape);
  50. auto case1 = builder.AddNode("case1", CASE, 2, 1);
  51. auto relu1 = builder.AddNode("relu10", "Relu", 1, 1);
  52. auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  53. int32_t weight[1] = {1};
  54. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32);
  55. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  56. OpDescUtils::SetWeights(const1, {tensor});
  57. builder.AddDataEdge(data1, 0, case1, 0);
  58. builder.AddDataEdge(const1, 0, case1, 1);
  59. builder.AddDataEdge(case1, 0, relu1, 0);
  60. builder.AddDataEdge(relu1, 0, netoutput, 0);
  61. return builder;
  62. }
  63. /*
  64. * data1 data2
  65. * \ /
  66. * switch
  67. * / \
  68. * relu1 relu2
  69. * \ /
  70. * merge
  71. * |
  72. * netoutput
  73. */
  74. ut::GraphBuilder SwitchSubgraphBuilder(string graph_name, uint32_t num) {
  75. ut::GraphBuilder builder = ut::GraphBuilder(graph_name);
  76. std::vector<int64_t> shape1 = {2,2};
  77. string data1_name = "data1_" + std::to_string(num);
  78. auto data1 = builder.AddNode(data1_name, "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape1);
  79. auto data1_desc = data1->GetOpDesc();
  80. EXPECT_NE(data1_desc, nullptr);
  81. AttrUtils::SetInt(data1_desc, "_parent_node_index", 0);
  82. std::vector<int64_t> shape2 = {3,3};
  83. string data2_name = "data2_" + std::to_string(num);
  84. auto data2 = builder.AddNode(data2_name, "Data", 1, 1, FORMAT_NCHW, DT_INT32, shape2);
  85. auto data2_desc = data2->GetOpDesc();
  86. EXPECT_NE(data2_desc, nullptr);
  87. AttrUtils::SetInt(data2_desc, "_parent_node_index", 1);
  88. string switch_name = "switch_" + std::to_string(num);
  89. auto switch1 = builder.AddNode(switch_name, "Switch", 2, 2);
  90. string relu1_name = "relu1_" + std::to_string(num);
  91. auto relu1 = builder.AddNode(relu1_name, "Relu", 1, 1);
  92. string relu2_name = "relu2_" + std::to_string(num);
  93. auto relu2 = builder.AddNode(relu2_name, "Relu", 1, 1);
  94. string merge_name = "merge_" + std::to_string(num);
  95. auto merge = builder.AddNode(merge_name, "Merge", 2, 1);
  96. std::vector<int64_t> shape7 = {8,8};
  97. string output_name = "output_" + std::to_string(num);
  98. auto netoutput = builder.AddNode(output_name, NETOUTPUT, 1, 0, FORMAT_NCHW, DT_INT32, shape7);
  99. auto input0_desc = netoutput->GetOpDesc()->MutableInputDesc(0);
  100. EXPECT_NE(input0_desc, nullptr);
  101. AttrUtils::SetInt(input0_desc, "_parent_node_index", 0);
  102. builder.AddDataEdge(data1, 0, switch1, 0);
  103. builder.AddDataEdge(data2, 0, switch1, 1);
  104. builder.AddDataEdge(switch1, 0, relu1, 0);
  105. builder.AddDataEdge(switch1, 1, relu2, 0);
  106. builder.AddDataEdge(relu1, 0, merge, 0);
  107. builder.AddDataEdge(relu2, 0, merge, 1);
  108. builder.AddDataEdge(merge, 0, netoutput, 0);
  109. return builder;
  110. }
  111. void AddCaseSubgraph(ComputeGraphPtr &parent_graph, uint32_t branch_num) {
  112. auto case_node = parent_graph->FindNode("case1");
  113. EXPECT_NE(case_node, nullptr);
  114. for (uint32_t i = 0; i < branch_num; ++i) {
  115. string name = "Branch_Graph_" + std::to_string(i);
  116. auto builder_subgraph = SwitchSubgraphBuilder(name, i);
  117. auto switch_subgraph = builder_subgraph.GetGraph();
  118. case_node->GetOpDesc()->AddSubgraphName(switch_subgraph->GetName());
  119. case_node->GetOpDesc()->SetSubgraphInstanceName(i, switch_subgraph->GetName());
  120. switch_subgraph->SetParentNode(case_node);
  121. switch_subgraph->SetParentGraph(parent_graph);
  122. EXPECT_EQ(parent_graph->AddSubgraph(switch_subgraph->GetName(), switch_subgraph), GRAPH_SUCCESS);
  123. }
  124. }
  125. static NodePtr CreateNode(ComputeGraph &graph, const string &name, const string &type, int in_num, int out_num) {
  126. OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
  127. op_desc->SetStreamId(0);
  128. static int32_t index = 0;
  129. op_desc->SetId(index++);
  130. GeTensorDesc tensor(GeShape(), FORMAT_NCHW, DT_FLOAT);
  131. TensorUtils::SetSize(tensor, 512);
  132. vector<int64_t> input_offset;
  133. for (int i = 0; i < in_num; i++) {
  134. op_desc->AddInputDesc(tensor);
  135. input_offset.emplace_back(1024);
  136. }
  137. op_desc->SetInputOffset(input_offset);
  138. vector<int64_t> output_offset;
  139. for (int i = 0; i < out_num; i++) {
  140. op_desc->AddOutputDesc(tensor);
  141. output_offset.emplace_back(1024);
  142. }
  143. op_desc->SetOutputOffset(output_offset);
  144. op_desc->SetWorkspace({});
  145. op_desc->SetWorkspaceBytes({});
  146. op_desc->SetOpKernelLibName("DNN_VM_RTS_OP_STORE");
  147. const auto stub_func = [](Operator &op) { return GRAPH_SUCCESS; };
  148. op_desc->AddInferFunc(stub_func);
  149. op_desc->AddInferFormatFunc(stub_func);
  150. op_desc->AddVerifierFunc(stub_func);
  151. return graph.AddNode(op_desc);
  152. }
  153. TEST_F(UtestGraphInfershapePass, infershape_pass_failed) {
  154. GeTensorDesc ge_tensor_desc(GeShape({-2, 2, 3, 4}), ge::FORMAT_NCHW, DT_FLOAT16);
  155. string type = "AddN";
  156. auto addn_op_desc = std::make_shared<OpDesc>("AddN", type);
  157. addn_op_desc->AddInputDesc(ge_tensor_desc);
  158. addn_op_desc->AddOutputDesc(ge_tensor_desc);
  159. auto graph = std::make_shared<ComputeGraph>("test");
  160. auto addn_node = std::make_shared<Node>(addn_op_desc, graph);
  161. addn_node->Init();
  162. InferShapePass infershape_pass;
  163. EXPECT_EQ(infershape_pass.Run(addn_node), GE_GRAPH_INFERSHAPE_FAILED);
  164. }
  165. TEST_F(UtestGraphInfershapePass, delete_need_infer_again) {
  166. auto graph = std::make_shared<ComputeGraph>("test");
  167. auto no_op_desc = std::make_shared<OpDesc>("No", "NoOp");
  168. auto no_op_node = graph->AddNode(no_op_desc);
  169. AttrUtils::SetBool(no_op_desc, "_need_infer_again", false);
  170. InferShapePass infershape_pass;
  171. infershape_pass.options_[kOptimizeAfterSubGraph] = "yes";
  172. EXPECT_EQ(infershape_pass.Run(no_op_node), SUCCESS);
  173. }
// Builds a TF-style while-loop frame (Enter/Merge/LoopCond/Switch/Exit with a
// NextIteration back edge) and checks that a whole-graph InferShapePass run
// completes with SUCCESS despite the cycle.
TEST_F(UtestGraphInfershapePass, stop_node_for_while_loop) {
  /*******************************************************************************
   *      Exit         Identify
   *        \         /       \.
   *         \       /         \.
   *          Switch           Add
   *         /     |            |
   *        /      |            |
   *       /       |            |
   *  LoopCond     |            |
   *      \        |            |
   *       \       |            |
   *        \      |            |
   *       Less    |            |
   *          \    |     NextIteration
   *           \   |            |
   *            \  |            |
   *             Merge <--------|
   *               |
   *               |
   *             Enter
   ******************************************************************************/
  auto graph = std::make_shared<ComputeGraph>("test_infer_shape");
  auto data1 = CreateNode(*graph, "data", DATA, 1, 1);
  auto enter1 = CreateNode(*graph, "enter", ENTER, 1, 1);
  auto merge1 = CreateNode(*graph, "merge", MERGE, 2, 2);
  auto less1 = CreateNode(*graph, "less", LESS, 2, 1);
  auto loop1 = CreateNode(*graph, "loopcond", LOOPCOND, 1, 1);
  auto switch1 = CreateNode(*graph, "switch", SWITCH, 2, 2);
  auto ident1 = CreateNode(*graph, "identity", IDENTITY, 1, 1);
  auto add1 = CreateNode(*graph, "add", ADD, 2, 1);
  auto next1 = CreateNode(*graph, "next", NEXTITERATION, 1, 1);
  auto exit1 = CreateNode(*graph, "exit", EXIT, 1, 1);
  // NOTE(review): value0 reuses the name "const" and is never connected to any
  // edge below — presumably intentional (isolated node in the graph); confirm.
  auto value0 = CreateNode(*graph, "const", CONSTANT, 0, 1);
  auto value1 = CreateNode(*graph, "const", CONSTANT, 0, 1);
  auto output1 = CreateNode(*graph, "net_output", NETOUTPUT, 1, 1);
  // Wire the loop: data -> enter -> merge -> less -> loopcond -> switch,
  // switch out0 -> exit (loop exit), switch out1 -> identity -> add -> next,
  // and next -> merge closes the back edge.
  GraphUtils::AddEdge(data1->GetOutDataAnchor(0), enter1->GetInDataAnchor(0));
  GraphUtils::AddEdge(enter1->GetOutDataAnchor(0), merge1->GetInDataAnchor(0));
  GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), less1->GetInDataAnchor(0));
  GraphUtils::AddEdge(value1->GetOutDataAnchor(0), less1->GetInDataAnchor(1));
  GraphUtils::AddEdge(less1->GetOutDataAnchor(0), loop1->GetInDataAnchor(0));
  GraphUtils::AddEdge(loop1->GetOutDataAnchor(0), switch1->GetInDataAnchor(0));
  GraphUtils::AddEdge(merge1->GetOutDataAnchor(0), switch1->GetInDataAnchor(1));
  GraphUtils::AddEdge(switch1->GetOutDataAnchor(0), exit1->GetInDataAnchor(0));
  GraphUtils::AddEdge(switch1->GetOutDataAnchor(1), ident1->GetInDataAnchor(0));
  GraphUtils::AddEdge(ident1->GetOutDataAnchor(0), add1->GetInDataAnchor(0));
  GraphUtils::AddEdge(value1->GetOutDataAnchor(0), add1->GetInDataAnchor(1));
  GraphUtils::AddEdge(add1->GetOutDataAnchor(0), next1->GetInDataAnchor(0));
  GraphUtils::AddEdge(next1->GetOutDataAnchor(0), merge1->GetInDataAnchor(1));
  GraphUtils::AddEdge(exit1->GetOutDataAnchor(0), output1->GetInDataAnchor(0));
  // Run InferShapePass over the whole graph; it is expected to terminate
  // successfully rather than iterate forever on the Merge/NextIteration cycle.
  GEPass ge_passes(graph);
  NamesToPass names_to_passes;
  InferShapePass infer_shape_pass;
  names_to_passes.emplace_back("InferShapePass", &infer_shape_pass);
  EXPECT_EQ(ge_passes.Run(names_to_passes), SUCCESS);
}
  230. TEST_F(UtestGraphInfershapePass, infer_with_case_subgraph) {
  231. auto builder = ParentGraphBuilder();
  232. auto parent_graph = builder.GetGraph();
  233. AddCaseSubgraph(parent_graph, 2);
  234. auto subgraphs = parent_graph->GetAllSubgraphs();
  235. EXPECT_EQ(subgraphs.size(), 2);
  236. auto case_node = parent_graph->FindNode("case1");
  237. EXPECT_NE(case_node, nullptr);
  238. InferShapePass infershape_pass;
  239. EXPECT_EQ(infershape_pass.Run(case_node), SUCCESS);
  240. std::vector<int64_t> target_dims_0 = {1, 1, 224, 224};
  241. std::vector<int64_t> target_dims_1 = {1};
  242. {
  243. auto data_node = subgraphs[0]->FindNode("data1_0");
  244. auto dims = data_node->GetOpDesc()->GetInputDescPtr(0)->GetShape().GetDims();
  245. EXPECT_EQ(dims, target_dims_0);
  246. data_node = subgraphs[0]->FindNode("data2_0");
  247. dims = data_node->GetOpDesc()->GetInputDescPtr(0)->GetShape().GetDims();
  248. EXPECT_EQ(dims, target_dims_1);
  249. }
  250. infershape_pass.options_[kOptimizeAfterSubGraph] = "yes";
  251. EXPECT_EQ(infershape_pass.Run(case_node), SUCCESS);
  252. {
  253. auto dims = case_node->GetOpDesc()->GetOutputDescPtr(0)->GetShape().GetDims();
  254. std::vector<int64_t> out_target_dims = {8, 8};
  255. EXPECT_EQ(out_target_dims, dims);
  256. }
  257. }
  258. /*
  259. * data1 const1
  260. * \ /
  261. * while
  262. * / \
  263. * relu1 netoutput
  264. */
  265. ut::GraphBuilder ParentWhileGraphBuilder() {
  266. ut::GraphBuilder builder = ut::GraphBuilder("g1");
  267. auto data1 = builder.AddNode("data1", "Data", 0, 1);
  268. std::vector<int64_t> const_shape = {1};
  269. auto const1 = builder.AddNode("const1", "Const", 0, 1, FORMAT_NCHW, DT_FLOAT, const_shape);
  270. auto case1 = builder.AddNode("case1", WHILE, 2, 2);
  271. auto relu1 = builder.AddNode("relu1", "Relu", 1, 1);
  272. auto netoutput = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  273. int32_t weight[1] = {1};
  274. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_FLOAT);
  275. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  276. OpDescUtils::SetWeights(const1, {tensor});
  277. builder.AddDataEdge(data1, 0, case1, 0);
  278. builder.AddDataEdge(const1, 0, case1, 1);
  279. builder.AddDataEdge(case1, 0, relu1, 0);
  280. builder.AddDataEdge(case1, 1, netoutput, 0);
  281. return builder;
  282. }
  283. /*
  284. * data1 data2
  285. * \ /
  286. * switch
  287. * | |
  288. * \ /
  289. * netoutput
  290. */
  291. ut::GraphBuilder WhileSubgraphBuilder(string graph_name, uint32_t num) {
  292. ut::GraphBuilder builder = ut::GraphBuilder(graph_name);
  293. std::vector<int64_t> shape1 = {2,2};
  294. string data1_name = "data1_" + std::to_string(num);
  295. auto data1 = builder.AddNode(data1_name, "Data", 1, 1, FORMAT_NCHW, DT_FLOAT, shape1);
  296. auto data1_desc = data1->GetOpDesc();
  297. EXPECT_NE(data1_desc, nullptr);
  298. AttrUtils::SetInt(data1_desc, "_parent_node_index", 0);
  299. std::vector<int64_t> shape2 = {3,3};
  300. string data2_name = "data2_" + std::to_string(num);
  301. auto data2 = builder.AddNode(data2_name, "Data", 1, 1, FORMAT_NCHW, DT_FLOAT, shape2);
  302. auto data2_desc = data2->GetOpDesc();
  303. EXPECT_NE(data2_desc, nullptr);
  304. AttrUtils::SetInt(data2_desc, "_parent_node_index", 1);
  305. string switch_name = "switch_" + std::to_string(num);
  306. auto switch1 = builder.AddNode(switch_name, "Switch", 2, 2);
  307. std::vector<int64_t> shape7 = {8,8,8,8};
  308. string output_name = "output_" + std::to_string(num);
  309. auto netoutput = builder.AddNode(output_name, NETOUTPUT, 2, 0, FORMAT_NCHW, DT_FLOAT, shape7);
  310. auto input0_desc = netoutput->GetOpDesc()->MutableInputDesc(0);
  311. EXPECT_NE(input0_desc, nullptr);
  312. AttrUtils::SetInt(input0_desc, "_parent_node_index", 0);
  313. auto input1_desc = netoutput->GetOpDesc()->MutableInputDesc(1);
  314. EXPECT_NE(input1_desc, nullptr);
  315. AttrUtils::SetInt(input1_desc, "_parent_node_index", 1);
  316. builder.AddDataEdge(data1, 0, switch1, 0);
  317. builder.AddDataEdge(data2, 0, switch1, 1);
  318. builder.AddDataEdge(switch1, 0, netoutput, 0);
  319. builder.AddDataEdge(switch1, 1, netoutput, 1);
  320. return builder;
  321. }
  322. void AddWhileSubgraph(ComputeGraphPtr &parent_graph, uint32_t branch_num) {
  323. auto case_node = parent_graph->FindNode("case1");
  324. EXPECT_NE(case_node, nullptr);
  325. for (uint32_t i = 0; i < branch_num; ++i) {
  326. string name = "Branch_Graph_" + std::to_string(i);
  327. auto builder_subgraph = WhileSubgraphBuilder(name, i);
  328. auto switch_subgraph = builder_subgraph.GetGraph();
  329. case_node->GetOpDesc()->AddSubgraphName(switch_subgraph->GetName());
  330. case_node->GetOpDesc()->SetSubgraphInstanceName(i, switch_subgraph->GetName());
  331. switch_subgraph->SetParentNode(case_node);
  332. switch_subgraph->SetParentGraph(parent_graph);
  333. EXPECT_EQ(parent_graph->AddSubgraph(switch_subgraph->GetName(), switch_subgraph), GRAPH_SUCCESS);
  334. }
  335. }
  336. TEST_F(UtestGraphInfershapePass, infer_with_while_subgraph) {
  337. auto builder = ParentWhileGraphBuilder();
  338. auto parent_graph = builder.GetGraph();
  339. AddWhileSubgraph(parent_graph, 1);
  340. auto subgraphs = parent_graph->GetAllSubgraphs();
  341. EXPECT_EQ(subgraphs.size(), 1);
  342. auto case_node = parent_graph->FindNode("case1");
  343. EXPECT_NE(case_node, nullptr);
  344. InferShapePass infershape_pass;
  345. EXPECT_EQ(infershape_pass.Run(case_node), SUCCESS);
  346. std::vector<int64_t> target_dims_0 = {1, 1, 224, 224};
  347. std::vector<int64_t> target_dims_1 = {1};
  348. {
  349. auto data_node = subgraphs[0]->FindNode("data1_0");
  350. auto dims = data_node->GetOpDesc()->GetInputDescPtr(0)->GetShape().GetDims();
  351. EXPECT_EQ(dims, target_dims_0);
  352. data_node = subgraphs[0]->FindNode("data2_0");
  353. dims = data_node->GetOpDesc()->GetInputDescPtr(0)->GetShape().GetDims();
  354. EXPECT_EQ(dims, target_dims_1);
  355. }
  356. infershape_pass.options_[kOptimizeAfterSubGraph] = "yes";
  357. EXPECT_EQ(infershape_pass.Run(case_node), SUCCESS);
  358. {
  359. auto dims = case_node->GetOpDesc()->GetOutputDescPtr(0)->GetShape().GetDims();
  360. std::vector<int64_t> out_target_dims = {-1, -1, -1, -1};
  361. EXPECT_EQ(out_target_dims, dims);
  362. }
  363. }
  364. TEST_F(UtestGraphInfershapePass, infer_with_while_subgraph_failed) {
  365. auto builder = ParentWhileGraphBuilder();
  366. auto parent_graph = builder.GetGraph();
  367. AddWhileSubgraph(parent_graph, 2);
  368. auto subgraphs = parent_graph->GetAllSubgraphs();
  369. EXPECT_EQ(subgraphs.size(), 2);
  370. auto case_node = parent_graph->FindNode("case1");
  371. EXPECT_NE(case_node, nullptr);
  372. InferShapePass infershape_pass;
  373. infershape_pass.options_[kOptimizeAfterSubGraph] = "yes";
  374. EXPECT_EQ(infershape_pass.Run(case_node), GE_GRAPH_INFERSHAPE_FAILED);
  375. }
  376. auto InferFunc = [&](Operator &op) {
  377. return GRAPH_SUCCESS;
  378. };
  379. TEST_F(UtestGraphInfershapePass, infer_forrunning_with_while_subgraph) {
  380. auto builder = ParentWhileGraphBuilder();
  381. auto parent_graph = builder.GetGraph();
  382. AddWhileSubgraph(parent_graph, 1);
  383. auto subgraphs = parent_graph->GetAllSubgraphs();
  384. EXPECT_EQ(subgraphs.size(), 1);
  385. OperatorFactoryImpl::RegisterInferShapeFunc("Relu", InferFunc);
  386. auto relu_node = parent_graph->FindNode("relu1");
  387. EXPECT_NE(relu_node, nullptr);
  388. InferShapeForRunning infershape_for_running;
  389. EXPECT_EQ(infershape_for_running.Run(relu_node), SUCCESS);
  390. }
  391. TEST_F(UtestGraphInfershapePass, infer_static_func) {
  392. auto builder = ut::GraphBuilder("test_graph");
  393. auto data_1 = builder.AddNode("data_1", DATA, 0, 1);
  394. auto data_2 = builder.AddNode("data_2", DATA, 0, 1);
  395. auto add = builder.AddNode("Add", "Add", 2, 1);
  396. builder.AddDataEdge(data_1, 0, add, 0);
  397. builder.AddDataEdge(data_2, 0, add, 1);
  398. auto test_graph = builder.GetGraph();
  399. // OperatorFactoryImpl::CreateOperator("Add", "Flatten");
  400. auto test_node = test_graph->FindNode("Add");
  401. auto ret = InferShapePass::InferShapeAndType(test_node);
  402. EXPECT_EQ(ret, GRAPH_SUCCESS);
  403. OperatorFactoryImpl::RegisterInferShapeFunc("Add", InferFunc);
  404. ret = InferShapePass::InferShapeAndType(test_node);
  405. EXPECT_EQ(ret, GRAPH_SUCCESS);
  406. ret = InferShapePass::InferShapeAndType(test_node, true);
  407. EXPECT_EQ(ret, GRAPH_SUCCESS);
  408. ret = InferShapeForRunning::InferShapeAndTypeForRunning(test_node, true);
  409. EXPECT_EQ(ret, GRAPH_SUCCESS);
  410. }
  411. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示