You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

parallel_group_pass_unittest.cc 18 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <gtest/gtest.h>
  17. #include <cstdint>
  18. #include <string>
  19. #define private public
  20. #include "inc/graph/ge_local_context.h"
  21. #include "inc/external/ge/ge_api_types.h"
  22. #include "common/ge_inner_error_codes.h"
  23. #include "inc/pass_manager.h"
  24. #include "utils/graph_utils.h"
  25. #include "graph/passes/parallel_group_pass.h"
  26. #undef private
  27. namespace ge {
  28. namespace {
  29. class UtestGraphPassesParallelGgroupPass : public testing::Test {
  30. protected:
  31. UtestGraphPassesParallelGgroupPass() {
  32. graph_ = std::make_shared<ComputeGraph>("test");
  33. sub_graph_ = std::make_shared<ComputeGraph>("test_subgraph");
  34. vector<int64_t> shape_vec{1, 1, 1, 1};
  35. GeShape shape = GeShape(shape_vec);
  36. default_tensor_desc_ = std::make_shared<GeTensorDesc>();
  37. default_tensor_desc_->SetShape(shape);
  38. default_tensor_desc_->SetFormat(FORMAT_NCHW);
  39. default_tensor_desc_->SetDataType(DT_FLOAT);
  40. }
  41. NodePtr NewNode(const std::string &name, const std::string &type,
  42. int input_cnt, int output_cnt, bool isSubgraph = false) {
  43. OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
  44. for (int i = 0; i < input_cnt; ++i) {
  45. op_desc->AddInputDesc(default_tensor_desc_->Clone());
  46. }
  47. for (int i = 0; i < output_cnt; ++i) {
  48. op_desc->AddOutputDesc(default_tensor_desc_->Clone());
  49. }
  50. NodePtr node = nullptr;
  51. if (isSubgraph) {
  52. node = sub_graph_->AddNode(op_desc);
  53. (void)node->SetOwnerComputeGraph(sub_graph_);
  54. } else {
  55. node = graph_->AddNode(op_desc);
  56. (void)node->SetOwnerComputeGraph(graph_);
  57. }
  58. return node;
  59. }
  60. void BuildDefaultGraph() {
  61. /// input
  62. /// \.
  63. /// sqrt pred
  64. /// \ /
  65. /// cast
  66. /// / \.
  67. /// switch_t switch_f
  68. /// | |
  69. /// F T
  70. /// | |
  71. /// Merge
  72. /// |
  73. /// relu
  74. /// |
  75. /// sqrt1
  76. input_node_ = NewNode("input", RELU, 0, 1);
  77. sqrt_node_ = NewNode("sqrt", SQRT, 1, 1);
  78. pred_node_ = NewNode("pred", GREATER, 2, 1);
  79. cast_node_ = NewNode("cast", CAST, 2, 2);
  80. AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  81. switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1);
  82. AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true);
  83. switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1);
  84. AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false);
  85. output_false_node_ = NewNode("false_output", RELU, 1, 1);
  86. AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  87. output_true_node_ = NewNode("true_output", RELU, 1, 1);
  88. AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  89. merge_node_ = NewNode("merge", STREAMMERGE, 2, 1);
  90. relu_node_ = NewNode("relu", RELU, 1, 1);
  91. sqrt_node1_ = NewNode("sqrt1", SQRT, 1, 1);
  92. AttrUtils::SetStr(sqrt_node1_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  93. GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0));
  94. GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0));
  95. GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1));
  96. GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0));
  97. GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0));
  98. GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
  99. GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0));
  100. GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0));
  101. GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
  102. GraphUtils::AddEdge(merge_node_->GetOutDataAnchor(0), relu_node_->GetInDataAnchor(0));
  103. GraphUtils::AddEdge(relu_node_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0));
  104. output_false_node_->GetOpDesc()->SetIsInputConst({false});
  105. output_true_node_->GetOpDesc()->SetIsInputConst({false});
  106. }
  107. void BuildDefaultGraph1() {
  108. /// input
  109. /// \.
  110. /// sqrt pred
  111. /// \ /
  112. /// Switch
  113. /// | |
  114. /// ----F T----
  115. /// \ | / \.
  116. /// \ Merge1 Merge2
  117. /// \_________|
  118. input_node_ = NewNode("input", RELU, 0, 1);
  119. AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  120. pred_node_ = NewNode("pred", GREATER, 2, 1);
  121. sqrt_node_ = NewNode("sqrt", SQRT, 1, 1);
  122. cast_node_ = NewNode("cast", CAST, 2, 2);
  123. switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1);
  124. AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true);
  125. switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1);
  126. AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false);
  127. output_false_node_ = NewNode("false_output", RELU, 1, 2);
  128. AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  129. output_true_node_ = NewNode("true_output", RELU, 1, 2);
  130. AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  131. merge_node_ = NewNode("merge", STREAMMERGE, 2, 1);
  132. merge_node1_ = NewNode("merge1", STREAMMERGE, 2, 1);
  133. GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0));
  134. GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0));
  135. GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1));
  136. GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0));
  137. GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0));
  138. GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
  139. GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0));
  140. GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0));
  141. GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
  142. GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(0));
  143. GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(1));
  144. output_false_node_->GetOpDesc()->SetIsInputConst({false});
  145. output_true_node_->GetOpDesc()->SetIsInputConst({false});
  146. }
  147. void BuildDefaultGraph2() {
  148. /// input input1
  149. /// \ \.
  150. /// sqrt pred sqrt1 pred1
  151. /// \ / \ /
  152. /// Switch Switch1
  153. /// | | _______|
  154. /// | | /
  155. /// ____F T____
  156. /// \ | / \.
  157. /// \ Merge1 Merge2
  158. /// \__________|
  159. input_node_ = NewNode("input", RELU, 0, 2);
  160. input_node1_ = NewNode("input_1", RELU, 0, 2);
  161. sqrt_node_ = NewNode("sqrt", SQRT, 1, 1);
  162. pred_node_ = NewNode("pred", GREATER, 2, 1);
  163. sqrt_node1_ = NewNode("sqrt_1", SQRT, 1, 1);
  164. pred_node1_ = NewNode("pred_1", LESS, 2, 1);
  165. cast_node_ = NewNode("cast", CAST, 2, 2);
  166. cast_node1_ = NewNode("cast_1", CAST, 2, 2);
  167. AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  168. AttrUtils::SetStr(input_node1_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "2");
  169. switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1);
  170. AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true);
  171. switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1);
  172. AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false);
  173. switch_node1_t = NewNode("switch1_t", STREAMSWITCH, 1, 1);
  174. AttrUtils::SetBool(switch_node1_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true);
  175. switch_node1_f = NewNode("switch1_f", STREAMSWITCH, 1, 1);
  176. AttrUtils::SetBool(switch_node1_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false);
  177. output_false_node_ = NewNode("false_output", RELU, 2, 2);
  178. AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  179. output_true_node_ = NewNode("true_output", RELU, 2, 2);
  180. AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "2");
  181. merge_node_ = NewNode("merge", STREAMMERGE, 2, 1);
  182. merge_node1_ = NewNode("merge1", STREAMMERGE, 2, 1);
  183. GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0));
  184. GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0));
  185. GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1));
  186. GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0));
  187. GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0));
  188. GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
  189. GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0));
  190. GraphUtils::AddEdge(input_node1_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0));
  191. GraphUtils::AddEdge(pred_node1_->GetOutDataAnchor(0), cast_node1_->GetInDataAnchor(0));
  192. GraphUtils::AddEdge(sqrt_node1_->GetOutDataAnchor(0), cast_node1_->GetInDataAnchor(1));
  193. GraphUtils::AddEdge(cast_node1_->GetOutDataAnchor(0), switch_node1_t->GetInDataAnchor(0));
  194. GraphUtils::AddEdge(cast_node1_->GetOutDataAnchor(1), switch_node1_f->GetInDataAnchor(0));
  195. GraphUtils::AddEdge(switch_node1_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(1));
  196. GraphUtils::AddEdge(switch_node1_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(1));
  197. GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0));
  198. GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
  199. GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(0));
  200. GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), merge_node1_->GetInDataAnchor(1));
  201. output_false_node_->GetOpDesc()->SetIsInputConst({false});
  202. output_true_node_->GetOpDesc()->SetIsInputConst({false});
  203. }
  204. void BuildDefaultGraph3() {
  205. /// input
  206. /// \
  207. /// sqrt pred
  208. /// \ /
  209. /// Switch
  210. /// | |
  211. /// F T ------
  212. /// / \_/_ \
  213. /// / / \ \
  214. /// Merge sqrt2 sqrt3
  215. /// / \ \
  216. /// sqrt1 \ relu
  217. /// \ \
  218. /// \ sqrt4
  219. /// \ /
  220. /// Merge1
  221. input_node_ = NewNode("input", RELU, 0, 1);
  222. AttrUtils::SetStr(input_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  223. pred_node_ = NewNode("pred", GREATER, 2, 1);
  224. sqrt_node_ = NewNode("sqrt", SQRT, 1, 1);
  225. cast_node_ = NewNode("cast", CAST, 2, 2);
  226. switch_node_t = NewNode("switch_t", STREAMSWITCH, 1, 1);
  227. AttrUtils::SetBool(switch_node_t->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, true);
  228. switch_node_f = NewNode("switch_f", STREAMSWITCH, 1, 1);
  229. AttrUtils::SetBool(switch_node_f->GetOpDesc(), ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, false);
  230. output_false_node_ = NewNode("false_output", RELU, 1, 2);
  231. AttrUtils::SetStr(output_false_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  232. output_true_node_ = NewNode("true_output", RELU, 1, 2);
  233. AttrUtils::SetStr(output_true_node_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  234. merge_node_ = NewNode("merge", STREAMMERGE, 2, 1);
  235. sqrt_node1_ = NewNode("sqrt1", SQRT, 1, 1);
  236. AttrUtils::SetStr(sqrt_node1_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  237. sqrt_node2_ = NewNode("sqrt2", SQRT, 1, 1);
  238. AttrUtils::SetStr(sqrt_node2_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  239. sqrt_node3_ = NewNode("sqrt3", SQRT, 1, 1);
  240. relu_node_ = NewNode("relu", RELU, 1, 1);
  241. sqrt_node4_ = NewNode("sqrt4", SQRT, 1, 1);
  242. AttrUtils::SetStr(sqrt_node4_->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  243. merge_node1_ = NewNode("merge1", STREAMMERGE, 2, 1);
  244. GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), sqrt_node_->GetInDataAnchor(0));
  245. GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(0));
  246. GraphUtils::AddEdge(sqrt_node_->GetOutDataAnchor(0), cast_node_->GetInDataAnchor(1));
  247. GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(0), switch_node_t->GetInDataAnchor(0));
  248. GraphUtils::AddEdge(cast_node_->GetOutDataAnchor(1), switch_node_f->GetInDataAnchor(0));
  249. GraphUtils::AddEdge(switch_node_f->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
  250. GraphUtils::AddEdge(switch_node_t->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0));
  251. GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0));
  252. GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
  253. GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(1), sqrt_node2_->GetInDataAnchor(0));
  254. GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(1), sqrt_node3_->GetInDataAnchor(0));
  255. GraphUtils::AddEdge(merge_node_->GetOutDataAnchor(0), sqrt_node1_->GetInDataAnchor(0));
  256. GraphUtils::AddEdge(sqrt_node3_->GetOutDataAnchor(0), relu_node_->GetInDataAnchor(0));
  257. GraphUtils::AddEdge(relu_node_->GetOutDataAnchor(0), sqrt_node4_->GetInDataAnchor(0));
  258. GraphUtils::AddEdge(sqrt_node2_->GetOutDataAnchor(0), merge_node1_->GetInDataAnchor(0));
  259. GraphUtils::AddEdge(sqrt_node4_->GetOutDataAnchor(0), merge_node1_->GetInDataAnchor(1));
  260. output_false_node_->GetOpDesc()->SetIsInputConst({false});
  261. output_true_node_->GetOpDesc()->SetIsInputConst({false});
  262. }
  263. ComputeGraphPtr graph_;
  264. ComputeGraphPtr sub_graph_;
  265. GeTensorDescPtr default_tensor_desc_;
  266. ParallelGroupPass pass_;
  267. NodePtr pred_node_;
  268. NodePtr pred_node1_;
  269. NodePtr cast_node_;
  270. NodePtr cast_node1_;
  271. NodePtr sqrt_node_;
  272. NodePtr sqrt_node1_;
  273. NodePtr sqrt_node2_;
  274. NodePtr sqrt_node3_;
  275. NodePtr sqrt_node4_;
  276. NodePtr input_node_;
  277. NodePtr input_node1_;
  278. NodePtr switch_node_t;
  279. NodePtr switch_node_f;
  280. NodePtr switch_node1_t;
  281. NodePtr switch_node1_f;
  282. NodePtr output_false_node_;
  283. NodePtr output_true_node_;
  284. NodePtr merge_node_;
  285. NodePtr merge_node1_;
  286. NodePtr relu_node_;
  287. };
  288. TEST_F(UtestGraphPassesParallelGgroupPass, null_graph) {
  289. ComputeGraphPtr graph = nullptr;
  290. auto ret = pass_.Run(graph);
  291. EXPECT_EQ(ret, PARAM_INVALID);
  292. }
  293. TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph) {
  294. BuildDefaultGraph();
  295. auto ret = pass_.Run(graph_);
  296. EXPECT_EQ(ret, GRAPH_SUCCESS);
  297. EXPECT_EQ(true, input_node_->GetOutControlAnchor()->IsLinkedWith(cast_node_->GetInControlAnchor()));
  298. EXPECT_EQ(true, merge_node_->GetOutControlAnchor()->IsLinkedWith(sqrt_node1_->GetInControlAnchor()));
  299. EXPECT_EQ(false, output_false_node_->GetOutControlAnchor()->IsLinkedWith(output_true_node_->GetInControlAnchor()));
  300. }
  301. TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph1) {
  302. BuildDefaultGraph1();
  303. auto ret = pass_.Run(graph_);
  304. EXPECT_EQ(ret, GRAPH_SUCCESS);
  305. EXPECT_EQ(true, input_node_->GetOutControlAnchor()->IsLinkedWith(cast_node_->GetInControlAnchor()));
  306. }
  307. TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph2) {
  308. BuildDefaultGraph2();
  309. auto ret = pass_.Run(graph_);
  310. EXPECT_EQ(ret, GRAPH_SUCCESS);
  311. EXPECT_EQ(true, input_node_->GetOutControlAnchor()->IsLinkedWith(cast_node_->GetInControlAnchor()));
  312. EXPECT_EQ(true, input_node1_->GetOutControlAnchor()->IsLinkedWith(cast_node1_->GetInControlAnchor()));
  313. }
  314. TEST_F(UtestGraphPassesParallelGgroupPass, normal_graph3) {
  315. std::map<std::string, std::string> options;
  316. options.emplace(OPTION_GRAPH_RUN_MODE, "1");
  317. GetThreadLocalContext().SetGraphOption(options);
  318. BuildDefaultGraph3();
  319. auto ret = pass_.Run(graph_);
  320. EXPECT_EQ(ret, GRAPH_SUCCESS);
  321. EXPECT_EQ(true, merge_node1_->GetOutControlAnchor()->IsLinkedWith(sqrt_node1_->GetInControlAnchor()));
  322. }
  323. TEST_F(UtestGraphPassesParallelGgroupPass, normal_subgraph) {
  324. BuildDefaultGraph1();
  325. NodePtr input_node1 = NewNode("input1", RELU, 0, 1, true);
  326. NodePtr input_node2 = NewNode("input2", RELU, 0, 1, true);
  327. NodePtr add = NewNode("add", ADD, 2, 1, true);
  328. AttrUtils::SetStr(input_node1->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  329. AttrUtils::SetStr(input_node2->GetOpDesc(), ATTR_NAME_PARALLEL_GROUP, "1");
  330. sub_graph_->SetParentNode(input_node_);
  331. sub_graph_->SetParentGraph(graph_);
  332. auto ret = graph_->AddSubgraph(sub_graph_->GetName(), sub_graph_);
  333. EXPECT_EQ(ret, GRAPH_SUCCESS);
  334. ret = input_node_->GetOpDesc()->AddSubgraphName(sub_graph_->GetName());
  335. EXPECT_EQ(ret, GRAPH_SUCCESS);
  336. ret = input_node_->GetOpDesc()->SetSubgraphInstanceName(0, sub_graph_->GetName());
  337. EXPECT_EQ(ret, GRAPH_SUCCESS);
  338. ret = pass_.Run(sub_graph_);
  339. EXPECT_EQ(ret, GRAPH_SUCCESS);
  340. ret = pass_.Run(graph_);
  341. EXPECT_EQ(ret, GRAPH_SUCCESS);
  342. }
  343. } // namespace
  344. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：