You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

flow_ctrl_pass_unittest.cc 18 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <gtest/gtest.h>
  17. #include "common/ge_inner_error_codes.h"
  18. #include "common/types.h"
  19. #include "graph/manager/graph_var_manager.h"
  20. #include "graph/utils/attr_utils.h"
  21. #include "graph/utils/graph_utils.h"
  22. #include "inc/pass_manager.h"
  23. #define private public
  24. #include "graph/passes/flow_ctrl_pass.h"
  25. #undef private
  26. namespace ge {
  27. class UtestGraphPassesFlowCtrlPass : public testing::Test {
  28. protected:
  29. void SetUp() {
  30. uint64_t session_id = 0;
  31. uint32_t device_id = 0;
  32. uint64_t job_id = 0;
  33. uint32_t session_version = 0;
  34. EXPECT_EQ(SUCCESS, ge::VarManager::Instance(0)->Init(session_version, session_id, device_id, job_id));
  35. }
  36. void TearDown() { VarManagerPool::Instance().Destroy(); }
  37. public:
  38. /// Set up a graph with the following network structure
  39. /// IteratorGetNext
  40. /// |
  41. /// MemcpyAsync
  42. /// |
  43. /// A
  44. /// |
  45. /// NetOutput
  46. void MakeGraph(ge::ComputeGraphPtr &graph) {
  47. auto desc_ptr = make_shared<ge::GeTensorDesc>();
  48. auto desc = *desc_ptr;
  49. ge::OpDescPtr op_desc_get_next = make_shared<ge::OpDesc>("IteratorGetNext", FRAMEWORKOP);
  50. op_desc_get_next->AddOutputDesc(desc);
  51. ge::OpDescPtr op_desc_memcpy = make_shared<ge::OpDesc>("MemcpyAsync", MEMCPYASYNC);
  52. op_desc_memcpy->AddInputDesc(desc);
  53. op_desc_memcpy->AddOutputDesc(desc);
  54. ge::AttrUtils::SetBool(op_desc_memcpy, ATTR_NAME_STREAM_CYCLE_EVENT_FLAG, true);
  55. ge::OpDescPtr op_desc_a = make_shared<ge::OpDesc>("A", RESOURCEAPPLYMOMENTUM);
  56. op_desc_a->AddInputDesc(desc);
  57. op_desc_a->AddOutputDesc(desc);
  58. ge::OpDescPtr op_desc_gatherv2 = make_shared<ge::OpDesc>("GatherV2", GATHERV2);
  59. op_desc_gatherv2->AddInputDesc(desc);
  60. op_desc_gatherv2->AddOutputDesc(desc);
  61. ge::OpDescPtr op_desc_global_step = make_shared<ge::OpDesc>("global_step", VARIABLE);
  62. op_desc_global_step->AddOutputDesc(desc);
  63. ge::OpDescPtr op_desc_netout = make_shared<ge::OpDesc>("NetOutput", NETOUTPUT);
  64. ge::AttrUtils::SetInt(op_desc_netout, ATTR_NAME_TRUE_BRANCH_STREAM, TRUE_STREAM_ID);
  65. op_desc_netout->AddInputDesc(desc);
  66. op_desc_netout->AddInputDesc(desc);
  67. // add node
  68. ge::NodePtr get_next_node = graph->AddNode(op_desc_get_next);
  69. ge::NodePtr memcpy_node = graph->AddNode(op_desc_memcpy);
  70. ge::NodePtr node_a = graph->AddNode(op_desc_a);
  71. ge::NodePtr global_step = graph->AddNode(op_desc_global_step);
  72. ge::NodePtr gatherv2 = graph->AddNode(op_desc_gatherv2);
  73. ge::NodePtr netoutput = graph->AddNode(op_desc_netout);
  74. // add edge
  75. ge::GraphUtils::AddEdge(get_next_node->GetOutDataAnchor(0), memcpy_node->GetInDataAnchor(0));
  76. ge::GraphUtils::AddEdge(memcpy_node->GetOutDataAnchor(0), node_a->GetInDataAnchor(0));
  77. ge::GraphUtils::AddEdge(node_a->GetOutDataAnchor(0), netoutput->GetInDataAnchor(0));
  78. ge::GraphUtils::AddEdge(gatherv2->GetOutDataAnchor(0), netoutput->GetInDataAnchor(1));
  79. ge::GraphUtils::AddEdge(global_step->GetOutDataAnchor(0), gatherv2->GetInDataAnchor(0));
  80. }
  81. void AddSessionVariables(void) {
  82. static std::set<std::string> var_list = {
  83. NODE_NAME_FLOWCTRL_LOOP_PER_ITER,
  84. NODE_NAME_FLOWCTRL_LOOP_COND,
  85. NODE_NAME_FLOWCTRL_LOOP_INCREMENT,
  86. NODE_NAME_FLOWCTRL_LOOP_RESETVALUE,
  87. NODE_NAME_GLOBAL_STEP,
  88. };
  89. uint8_t *dev_ptr = nullptr;
  90. ge::GeTensorDesc tensor_desc(ge::GeShape({1}), ge::FORMAT_NHWC, ge::DT_UINT64);
  91. for (std::string var_name : var_list) {
  92. EXPECT_EQ(SUCCESS, ge::VarManager::Instance(0)->SetVarAddr(var_name, tensor_desc, dev_ptr, RT_MEMORY_HBM));
  93. }
  94. }
  95. };
  96. TEST_F(UtestGraphPassesFlowCtrlPass, flow_ctrl_pass_success_test) {
  97. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("FlowCtrlPassSuccess");
  98. graph->SetNeedIteration(true);
  99. // Create graph
  100. MakeGraph(graph);
  101. graph->TopologicalSorting();
  102. AddSessionVariables();
  103. FlowCtrlPass flow_ctrl_pass;
  104. Status ret = flow_ctrl_pass.Run(graph);
  105. EXPECT_EQ(ret, SUCCESS);
  106. EXPECT_EQ(16, graph->GetDirectNodesSize());
  107. int stream_switch_cnt = 0;
  108. int stream_activeCnt = 0;
  109. for (ge::NodePtr node : graph->GetDirectNode()) {
  110. if (node->GetOpDesc()->GetType() == STREAMSWITCH) {
  111. stream_switch_cnt++;
  112. } else if (node->GetOpDesc()->GetType() == STREAMACTIVE) {
  113. stream_activeCnt++;
  114. }
  115. }
  116. EXPECT_EQ(stream_switch_cnt, 2);
  117. EXPECT_EQ(stream_activeCnt, 2);
  118. }
  119. TEST_F(UtestGraphPassesFlowCtrlPass, flow_ctrl_pass_success_var_node_add_before) {
  120. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("FlowCtrlPassSuccess");
  121. graph->SetNeedIteration(true);
  122. // Create graph
  123. MakeGraph(graph);
  124. graph->TopologicalSorting();
  125. AddSessionVariables();
  126. FlowCtrlPass flow_ctrl_pass;
  127. NodePtr loop_cond_node = flow_ctrl_pass.AddVariableNode(graph, NODE_NAME_FLOWCTRL_LOOP_COND);
  128. EXPECT_NE(loop_cond_node, nullptr);
  129. NodePtr loop_increment_node = flow_ctrl_pass.AddVariableNode(graph, NODE_NAME_FLOWCTRL_LOOP_INCREMENT);
  130. EXPECT_NE(loop_increment_node, nullptr);
  131. NodePtr loop_reset_node = flow_ctrl_pass.AddVariableNode(graph, NODE_NAME_FLOWCTRL_LOOP_RESETVALUE);
  132. EXPECT_NE(loop_reset_node, nullptr);
  133. NodePtr iter_per_loop_node = flow_ctrl_pass.AddVariableNode(graph, NODE_NAME_FLOWCTRL_LOOP_PER_ITER);
  134. EXPECT_NE(iter_per_loop_node, nullptr);
  135. Status ret = flow_ctrl_pass.Run(graph);
  136. EXPECT_EQ(ret, ge::SUCCESS);
  137. }
  138. TEST_F(UtestGraphPassesFlowCtrlPass, flow_ctrl_pass_not_train) {
  139. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("TestNotChange");
  140. graph->SetNeedIteration(false);
  141. FlowCtrlPass flow_ctrl_pass;
  142. Status ret = flow_ctrl_pass.Run(graph);
  143. EXPECT_EQ(ret, NOT_CHANGED);
  144. }
  145. TEST_F(UtestGraphPassesFlowCtrlPass, add_fpbp_iterator_ctrl_without_var) {
  146. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("TestNotChange");
  147. graph->SetNeedIteration(true);
  148. // Create graph
  149. MakeGraph(graph);
  150. graph->TopologicalSorting();
  151. // must have NODE_NAME_FLOWCTRL_LOOP_PER_ITER
  152. ge::GeTensorDesc tensor_desc(ge::GeShape({1}), ge::FORMAT_NHWC, ge::DT_UINT64);
  153. uint8_t *dev_ptr = nullptr;
  154. EXPECT_EQ(SUCCESS, ge::VarManager::Instance(0)->SetVarAddr(NODE_NAME_FLOWCTRL_LOOP_PER_ITER, tensor_desc,
  155. dev_ptr, RT_MEMORY_HBM));
  156. // not add var
  157. FlowCtrlPass flow_ctrl_pass;
  158. Status ret = flow_ctrl_pass.Run(graph);
  159. EXPECT_NE(ret, ge::SUCCESS);
  160. }
  161. TEST_F(UtestGraphPassesFlowCtrlPass, run_add_special_node_iterator_ctrl_no_inanchor) {
  162. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_WITHOUT_LOOP_PER_ITER");
  163. graph->SetNeedIteration(true);
  164. // Create graph
  165. MakeGraph(graph);
  166. graph->TopologicalSorting();
  167. AddSessionVariables();
  168. FlowCtrlPass flow_ctrl_pass;
  169. NodePtr getnext_node = graph->FindNode("IteratorGetNext");
  170. NodePtr memcpy_node = graph->FindNode("MemcpyAsync");
  171. GraphUtils::RemoveEdge(getnext_node->GetOutDataAnchor(0), memcpy_node->GetInDataAnchor(0));
  172. Status ret = flow_ctrl_pass.Run(graph);
  173. EXPECT_NE(ret, ge::SUCCESS);
  174. }
  175. TEST_F(UtestGraphPassesFlowCtrlPass, add_fpbp_iterator_ctrl_without_loop_cond) {
  176. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_WITHOUT_LOOP_COND");
  177. graph->SetNeedIteration(true);
  178. // Create graph
  179. MakeGraph(graph);
  180. graph->TopologicalSorting();
  181. std::set<std::string> var_list = {
  182. NODE_NAME_FLOWCTRL_LOOP_PER_ITER,
  183. NODE_NAME_FLOWCTRL_LOOP_INCREMENT,
  184. NODE_NAME_FLOWCTRL_LOOP_RESETVALUE,
  185. NODE_NAME_GLOBAL_STEP,
  186. };
  187. // must have NODE_NAME_FLOWCTRL_LOOP_PER_ITER
  188. ge::GeTensorDesc tensor_desc(ge::GeShape({1}), ge::FORMAT_NHWC, ge::DT_UINT64);
  189. uint8_t *dev_ptr = nullptr;
  190. for (std::string var_name : var_list) {
  191. EXPECT_EQ(SUCCESS, ge::VarManager::Instance(0)->SetVarAddr(var_name, tensor_desc, dev_ptr, RT_MEMORY_HBM));
  192. }
  193. // not add var
  194. FlowCtrlPass flow_ctrl_pass;
  195. NodePtr pre_node = graph->FindNode("NetOutput");
  196. Status ret = flow_ctrl_pass.AddFpBpIteratorCtrl(graph, pre_node);
  197. EXPECT_EQ(ret, FAILED);
  198. }
  199. TEST_F(UtestGraphPassesFlowCtrlPass, add_fpbp_iterator_ctrl_without_loop_increment) {
  200. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_WITHOUT_LOOP_INCREMENT");
  201. graph->SetNeedIteration(true);
  202. // Create graph
  203. MakeGraph(graph);
  204. graph->TopologicalSorting();
  205. std::set<std::string> var_list = {
  206. NODE_NAME_FLOWCTRL_LOOP_PER_ITER,
  207. NODE_NAME_FLOWCTRL_LOOP_COND,
  208. NODE_NAME_FLOWCTRL_LOOP_RESETVALUE,
  209. NODE_NAME_GLOBAL_STEP,
  210. };
  211. // must have NODE_NAME_FLOWCTRL_LOOP_PER_ITER
  212. ge::GeTensorDesc tensor_desc(ge::GeShape({1}), ge::FORMAT_NHWC, ge::DT_UINT64);
  213. uint8_t *dev_ptr = nullptr;
  214. for (std::string var_name : var_list) {
  215. EXPECT_EQ(SUCCESS, ge::VarManager::Instance(0)->SetVarAddr(var_name, tensor_desc, dev_ptr, RT_MEMORY_HBM));
  216. }
  217. // not add var
  218. FlowCtrlPass flow_ctrl_pass;
  219. NodePtr pre_node = graph->FindNode("NetOutput");
  220. Status ret = flow_ctrl_pass.AddFpBpIteratorCtrl(graph, pre_node);
  221. EXPECT_EQ(ret, FAILED);
  222. }
  223. TEST_F(UtestGraphPassesFlowCtrlPass, add_fpbp_iterator_ctrl_without_loop_reset_value) {
  224. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_WITHOUT_LOOP_RESETVALUE");
  225. graph->SetNeedIteration(true);
  226. // Create graph
  227. MakeGraph(graph);
  228. graph->TopologicalSorting();
  229. std::set<std::string> var_list = {
  230. NODE_NAME_FLOWCTRL_LOOP_PER_ITER,
  231. NODE_NAME_FLOWCTRL_LOOP_COND,
  232. NODE_NAME_FLOWCTRL_LOOP_INCREMENT,
  233. NODE_NAME_GLOBAL_STEP,
  234. };
  235. // must have NODE_NAME_FLOWCTRL_LOOP_PER_ITER
  236. ge::GeTensorDesc tensor_desc(ge::GeShape({1}), ge::FORMAT_NHWC, ge::DT_UINT64);
  237. uint8_t *dev_ptr = nullptr;
  238. for (std::string var_name : var_list) {
  239. EXPECT_EQ(SUCCESS, ge::VarManager::Instance(0)->SetVarAddr(var_name, tensor_desc, dev_ptr, RT_MEMORY_HBM));
  240. }
  241. // not add var
  242. FlowCtrlPass flow_ctrl_pass;
  243. NodePtr pre_node = graph->FindNode("NetOutput");
  244. Status ret = flow_ctrl_pass.AddFpBpIteratorCtrl(graph, pre_node);
  245. EXPECT_EQ(ret, FAILED);
  246. }
  247. TEST_F(UtestGraphPassesFlowCtrlPass, add_fpbp_iterator_ctrl_without_loop_ref_iter) {
  248. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_WITHOUT_LOOP_PER_ITER");
  249. graph->SetNeedIteration(true);
  250. // Create graph
  251. MakeGraph(graph);
  252. graph->TopologicalSorting();
  253. std::set<std::string> var_list = {
  254. NODE_NAME_FLOWCTRL_LOOP_COND,
  255. NODE_NAME_FLOWCTRL_LOOP_INCREMENT,
  256. NODE_NAME_FLOWCTRL_LOOP_RESETVALUE,
  257. NODE_NAME_GLOBAL_STEP,
  258. };
  259. // must have NODE_NAME_FLOWCTRL_LOOP_PER_ITER
  260. ge::GeTensorDesc tensor_desc(ge::GeShape({1}), ge::FORMAT_NHWC, ge::DT_UINT64);
  261. uint8_t *dev_ptr = nullptr;
  262. for (std::string var_name : var_list) {
  263. EXPECT_EQ(SUCCESS, ge::VarManager::Instance(0)->SetVarAddr(var_name, tensor_desc, dev_ptr, RT_MEMORY_HBM));
  264. }
  265. FlowCtrlPass flow_ctrl_pass;
  266. NodePtr pre_node = graph->FindNode("NetOutput");
  267. Status ret = flow_ctrl_pass.AddFpBpIteratorCtrl(graph, pre_node);
  268. EXPECT_EQ(ret, FAILED);
  269. }
  270. TEST_F(UtestGraphPassesFlowCtrlPass, add_special_node_iterator_ctrl_without_loop_cond) {
  271. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_WITHOUT_LOOP_COND");
  272. graph->SetNeedIteration(true);
  273. // Create graph
  274. MakeGraph(graph);
  275. graph->TopologicalSorting();
  276. std::set<std::string> var_list = {
  277. NODE_NAME_FLOWCTRL_LOOP_PER_ITER,
  278. NODE_NAME_FLOWCTRL_LOOP_INCREMENT,
  279. NODE_NAME_FLOWCTRL_LOOP_RESETVALUE,
  280. NODE_NAME_GLOBAL_STEP,
  281. };
  282. // must have NODE_NAME_FLOWCTRL_LOOP_PER_ITER
  283. ge::GeTensorDesc tensor_desc(ge::GeShape({1}), ge::FORMAT_NHWC, ge::DT_UINT64);
  284. uint8_t *dev_ptr = nullptr;
  285. for (std::string var_name : var_list) {
  286. EXPECT_EQ(SUCCESS, ge::VarManager::Instance(0)->SetVarAddr(var_name, tensor_desc, dev_ptr, RT_MEMORY_HBM));
  287. }
  288. FlowCtrlPass flow_ctrl_pass;
  289. NodePtr iter_per_loop_node = flow_ctrl_pass.AddVariableNode(graph, NODE_NAME_FLOWCTRL_LOOP_PER_ITER);
  290. EXPECT_NE(iter_per_loop_node, nullptr);
  291. NodePtr memcpy_node = graph->FindNode("MemcpyAsync");
  292. Status ret = flow_ctrl_pass.AddSpecialNodeIteratorCtrl(graph, memcpy_node);
  293. EXPECT_EQ(ret, FAILED);
  294. }
  295. TEST_F(UtestGraphPassesFlowCtrlPass, add_special_node_iterator_ctrl_without_loop_ref_iter) {
  296. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_WITHOUT_LOOP_PER_ITER");
  297. graph->SetNeedIteration(true);
  298. // Create graph
  299. MakeGraph(graph);
  300. graph->TopologicalSorting();
  301. std::set<std::string> var_list = {
  302. NODE_NAME_FLOWCTRL_LOOP_COND,
  303. NODE_NAME_FLOWCTRL_LOOP_INCREMENT,
  304. NODE_NAME_FLOWCTRL_LOOP_RESETVALUE,
  305. NODE_NAME_GLOBAL_STEP,
  306. };
  307. ge::GeTensorDesc tensor_desc(ge::GeShape({1}), ge::FORMAT_NHWC, ge::DT_UINT64);
  308. uint8_t *dev_ptr = nullptr;
  309. for (std::string var_name : var_list) {
  310. EXPECT_EQ(SUCCESS, ge::VarManager::Instance(0)->SetVarAddr(var_name, tensor_desc, dev_ptr, RT_MEMORY_HBM));
  311. }
  312. FlowCtrlPass flow_ctrl_pass;
  313. NodePtr loop_cond_node = flow_ctrl_pass.AddVariableNode(graph, NODE_NAME_FLOWCTRL_LOOP_COND);
  314. EXPECT_NE(loop_cond_node, nullptr);
  315. NodePtr memcpy_node = graph->FindNode("MemcpyAsync");
  316. Status ret = flow_ctrl_pass.AddSpecialNodeIteratorCtrl(graph, memcpy_node);
  317. EXPECT_EQ(ret, FAILED);
  318. }
  319. TEST_F(UtestGraphPassesFlowCtrlPass, add_special_node_iterator_ctrl_no_inchor) {
  320. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_WITHOUT_LOOP_PER_ITER");
  321. graph->SetNeedIteration(true);
  322. // Create graph
  323. MakeGraph(graph);
  324. graph->TopologicalSorting();
  325. FlowCtrlPass flow_ctrl_pass;
  326. NodePtr getnext_node = graph->FindNode("IteratorGetNext");
  327. NodePtr memcpy_node = graph->FindNode("MemcpyAsync");
  328. GraphUtils::RemoveEdge(getnext_node->GetOutDataAnchor(0), memcpy_node->GetInDataAnchor(0));
  329. Status ret = flow_ctrl_pass.AddSpecialNodeIteratorCtrl(graph, memcpy_node);
  330. EXPECT_EQ(ret, FAILED);
  331. }
  332. TEST_F(UtestGraphPassesFlowCtrlPass, insert_assign_op_success) {
  333. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_InsertAssignOp");
  334. FlowCtrlPass flow_ctrl_pass;
  335. GeTensorDesc tmp_geT_tensor_desc;
  336. NodePtr ref_node = flow_ctrl_pass.InsertOp(graph, VARIABLE, "ref_node", {}, {tmp_geT_tensor_desc});
  337. NodePtr value_node = flow_ctrl_pass.InsertOp(graph, VARIABLE, "ref_node", {}, {tmp_geT_tensor_desc});
  338. NodePtr add_node = flow_ctrl_pass.InsertAssignOp(graph, ASSIGNADD, "add_node", ref_node, value_node);
  339. EXPECT_NE(add_node, nullptr);
  340. }
  341. TEST_F(UtestGraphPassesFlowCtrlPass, insert_assign_op_ref_node_no_outanchor) {
  342. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_InsertAssignOp");
  343. FlowCtrlPass flow_ctrl_pass;
  344. GeTensorDesc tmp_geT_tensor_desc;
  345. NodePtr ref_node = flow_ctrl_pass.InsertOp(graph, VARIABLE, "ref_node", {}, {});
  346. NodePtr value_node = flow_ctrl_pass.InsertOp(graph, VARIABLE, "ref_node", {}, {tmp_geT_tensor_desc});
  347. NodePtr add_node = flow_ctrl_pass.InsertAssignOp(graph, ASSIGNADD, "add_node", ref_node, value_node);
  348. EXPECT_EQ(add_node, nullptr);
  349. }
  350. TEST_F(UtestGraphPassesFlowCtrlPass, insert_assign_op_value_node_no_outanchor) {
  351. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_InsertAssignOp");
  352. FlowCtrlPass flow_ctrl_pass;
  353. GeTensorDesc tmp_geT_tensor_desc;
  354. NodePtr ref_node = flow_ctrl_pass.InsertOp(graph, VARIABLE, "ref_node", {}, {tmp_geT_tensor_desc});
  355. NodePtr value_node = flow_ctrl_pass.InsertOp(graph, VARIABLE, "ref_node", {}, {});
  356. NodePtr add_node = flow_ctrl_pass.InsertAssignOp(graph, ASSIGNADD, "add_node", ref_node, value_node);
  357. EXPECT_EQ(add_node, nullptr);
  358. }
  359. TEST_F(UtestGraphPassesFlowCtrlPass, create_iter_ctrl_false_branch_insert_assign_op_failed) {
  360. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("Test_CreateIterCtrlFalseBranch_InsertAssignOp_FAILED");
  361. FlowCtrlPass flow_ctrl_pass;
  362. GeTensorDesc tmp_geT_tensor_desc;
  363. NodePtr ref_node = flow_ctrl_pass.InsertOp(graph, VARIABLE, "ref_node", {}, {tmp_geT_tensor_desc});
  364. NodePtr value_node = flow_ctrl_pass.InsertOp(graph, VARIABLE, "ref_node", {}, {});
  365. NodePtr switch_node = flow_ctrl_pass.InsertOp(graph, STREAMSWITCH, "switch_node", {}, {});
  366. Status ret = flow_ctrl_pass.CreateIterCtrlFalseBranch(graph, ref_node, value_node, switch_node);
  367. EXPECT_EQ(ret, FAILED);
  368. }
  369. TEST_F(UtestGraphPassesFlowCtrlPass, create_iter_ctrl_true_branch_insert_assign_op_failed) {
  370. ge::ComputeGraphPtr graph = make_shared<ge::ComputeGraph>("CreateIterCtrlTrueBranch_InsertAssignOp_FAILED");
  371. FlowCtrlPass flow_ctrl_pass;
  372. GeTensorDesc tmp_geT_tensor_desc;
  373. NodePtr ref_node = flow_ctrl_pass.InsertOp(graph, VARIABLE, "ref_node", {}, {tmp_geT_tensor_desc});
  374. NodePtr value_node = flow_ctrl_pass.InsertOp(graph, VARIABLE, "ref_node", {}, {});
  375. NodePtr switch_node = flow_ctrl_pass.InsertOp(graph, STREAMSWITCH, "switch_node", {}, {});
  376. Status ret = flow_ctrl_pass.CreateIterCtrlTrueBranch(graph, ref_node, value_node, switch_node);
  377. EXPECT_EQ(ret, FAILED);
  378. }
  379. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示