You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

switch_op_pass_unittest.cc 22 kB

5 years ago

  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <gtest/gtest.h>
  17. #include "omg/omg_inner_types.h"
  18. #define protected public
  19. #define private public
  20. #include "common/debug/log.h"
  21. #include "common/debug/memory_dumper.h"
  22. #include "common/op/attr_value_util.h"
  23. #include "common/types.h"
  24. #include "graph/debug/ge_attr_define.h"
  25. #include "graph/graph.h"
  26. #include "inc/pass_manager.h"
  27. #undef protected
  28. #undef private
  29. using namespace testing;
  30. using namespace ge;
  31. class UtestGraphPassesSwitchOpPass : public testing::Test {
  32. protected:
  33. void SetUp() {}
  34. void TearDown() {}
  35. public:
  36. void make_graph(ComputeGraphPtr graph, bool match = true) {
  37. GeTensorDesc bool_tensor_desc(GeShape(), ge::FORMAT_NCHW, ge::DT_BOOL);
  38. GeTensorDesc int_tensor_desc(GeShape(), ge::FORMAT_NCHW, ge::DT_INT32);
  39. GeTensorDesc scalar_tensor_desc(GeShape(), ge::FORMAT_NCHW, ge::DT_FLOAT);
  40. auto xOpDef = std::make_shared<OpDesc>("x", VARIABLEV2);
  41. xOpDef->AddOutputDesc(scalar_tensor_desc);
  42. auto xNode = graph->AddNode(xOpDef);
  43. auto yOpDef = std::make_shared<OpDesc>("y", VARIABLEV2);
  44. yOpDef->AddOutputDesc(scalar_tensor_desc);
  45. auto yNode = graph->AddNode(yOpDef);
  46. auto zOpDef = std::make_shared<OpDesc>("z", VARIABLEV2);
  47. zOpDef->AddOutputDesc(scalar_tensor_desc);
  48. auto zNode = graph->AddNode(zOpDef);
  49. auto condOpDef = std::make_shared<OpDesc>("Less", "Less");
  50. condOpDef->AddInputDesc(scalar_tensor_desc);
  51. condOpDef->AddInputDesc(scalar_tensor_desc);
  52. condOpDef->AddOutputDesc(bool_tensor_desc);
  53. auto condNode = graph->AddNode(condOpDef);
  54. auto switch_op_def1 = std::make_shared<OpDesc>("Add/Switch", SWITCH);
  55. switch_op_def1->AddInputDesc(scalar_tensor_desc);
  56. switch_op_def1->AddInputDesc(bool_tensor_desc);
  57. switch_op_def1->AddOutputDesc(scalar_tensor_desc);
  58. switch_op_def1->AddOutputDesc(scalar_tensor_desc);
  59. auto switch_node1 = graph->AddNode(switch_op_def1);
  60. auto switch_op_def2 = std::make_shared<OpDesc>("Add/Switch_1", SWITCH);
  61. switch_op_def2->AddInputDesc(scalar_tensor_desc);
  62. switch_op_def2->AddInputDesc(bool_tensor_desc);
  63. switch_op_def2->AddOutputDesc(scalar_tensor_desc);
  64. switch_op_def2->AddOutputDesc(scalar_tensor_desc);
  65. auto switch_node2 = graph->AddNode(switch_op_def2);
  66. auto switch_op_def3 = std::make_shared<OpDesc>("Square/Switch", SWITCH);
  67. switch_op_def3->AddInputDesc(scalar_tensor_desc);
  68. switch_op_def3->AddInputDesc(bool_tensor_desc);
  69. switch_op_def3->AddOutputDesc(scalar_tensor_desc);
  70. switch_op_def3->AddOutputDesc(scalar_tensor_desc);
  71. auto switch_node3 = graph->AddNode(switch_op_def3);
  72. auto addOpDef = std::make_shared<OpDesc>("Add", "ADD");
  73. addOpDef->AddInputDesc(scalar_tensor_desc);
  74. addOpDef->AddInputDesc(scalar_tensor_desc);
  75. addOpDef->AddOutputDesc(scalar_tensor_desc);
  76. auto addNode = graph->AddNode(addOpDef);
  77. auto mergeOpDef = std::make_shared<OpDesc>("Merge", "Merge");
  78. mergeOpDef->AddInputDesc(scalar_tensor_desc);
  79. mergeOpDef->AddInputDesc(scalar_tensor_desc);
  80. mergeOpDef->AddOutputDesc(scalar_tensor_desc);
  81. mergeOpDef->AddOutputDesc(int_tensor_desc);
  82. auto mergeNode = graph->AddNode(mergeOpDef);
  83. auto output_op_def = std::make_shared<OpDesc>("NetOutput", "NetOutput");
  84. output_op_def->AddInputDesc(scalar_tensor_desc);
  85. output_op_def->AddOutputDesc(scalar_tensor_desc);
  86. auto output_node = graph->AddNode(output_op_def);
  87. (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), condNode->GetInDataAnchor(0));
  88. (void)GraphUtils::AddEdge(yNode->GetOutDataAnchor(0), condNode->GetInDataAnchor(1));
  89. (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), switch_node1->GetInDataAnchor(0));
  90. (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), switch_node1->GetInDataAnchor(1));
  91. (void)GraphUtils::AddEdge(yNode->GetOutDataAnchor(0), switch_node2->GetInDataAnchor(0));
  92. (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), switch_node2->GetInDataAnchor(1));
  93. (void)GraphUtils::AddEdge(zNode->GetOutDataAnchor(0), switch_node3->GetInDataAnchor(0));
  94. (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), switch_node3->GetInDataAnchor(1));
  95. (void)GraphUtils::AddEdge(switch_node1->GetOutDataAnchor(1), addNode->GetInDataAnchor(0));
  96. (void)GraphUtils::AddEdge(switch_node2->GetOutDataAnchor(1), addNode->GetInDataAnchor(1));
  97. (void)GraphUtils::AddEdge(addNode->GetOutDataAnchor(0), mergeNode->GetInDataAnchor(1));
  98. (void)GraphUtils::AddEdge(switch_node3->GetOutDataAnchor(0), mergeNode->GetInDataAnchor(0));
  99. (void)GraphUtils::AddEdge(mergeNode->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));
  100. }
  101. void make_graph_const(ComputeGraphPtr graph, bool match = true) {
  102. // resnet50 PolynomialDecay
  103. GeTensorDesc scalar_tensor_desc(GeShape({1, 1, 1, 1}));
  104. GeTensorDesc bool_tensor_desc(GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_BOOL);
  105. GeTensorDesc int_tensor_desc(GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_INT32);
  106. auto xOpDef = std::make_shared<OpDesc>("x", VARIABLEV2);
  107. xOpDef->AddOutputDesc(scalar_tensor_desc);
  108. auto xNode = graph->AddNode(xOpDef);
  109. auto yOpDef = std::make_shared<OpDesc>("y", "Const");
  110. yOpDef->AddOutputDesc(scalar_tensor_desc);
  111. auto yNode = graph->AddNode(yOpDef);
  112. auto zOpDef = std::make_shared<OpDesc>("z", VARIABLEV2);
  113. zOpDef->AddOutputDesc(scalar_tensor_desc);
  114. auto zNode = graph->AddNode(zOpDef);
  115. auto constOpDef = std::make_shared<OpDesc>("Const", "Const");
  116. constOpDef->AddOutputDesc(scalar_tensor_desc);
  117. auto constNode = graph->AddNode(constOpDef);
  118. auto condOpDef = std::make_shared<OpDesc>("Equal", "Equal");
  119. condOpDef->AddInputDesc(scalar_tensor_desc);
  120. condOpDef->AddInputDesc(scalar_tensor_desc);
  121. condOpDef->AddOutputDesc(bool_tensor_desc);
  122. auto condNode = graph->AddNode(condOpDef);
  123. auto identityOpDef = std::make_shared<OpDesc>("identity", "Identity");
  124. identityOpDef->AddInputDesc(bool_tensor_desc);
  125. identityOpDef->AddOutputDesc(bool_tensor_desc);
  126. auto identityNode = graph->AddNode(identityOpDef);
  127. auto switch_op_def1 = std::make_shared<OpDesc>("Switch", SWITCH);
  128. switch_op_def1->AddInputDesc(bool_tensor_desc);
  129. switch_op_def1->AddInputDesc(bool_tensor_desc);
  130. switch_op_def1->AddOutputDesc(bool_tensor_desc);
  131. switch_op_def1->AddOutputDesc(bool_tensor_desc);
  132. auto switch_node1 = graph->AddNode(switch_op_def1);
  133. auto tIdentityOpDef = std::make_shared<OpDesc>("switch_t", "Identity");
  134. tIdentityOpDef->AddInputDesc(scalar_tensor_desc);
  135. tIdentityOpDef->AddOutputDesc(scalar_tensor_desc);
  136. auto tIdentityNode = graph->AddNode(tIdentityOpDef);
  137. auto fIdentityOpDef = std::make_shared<OpDesc>("switch_f", "Identity");
  138. fIdentityOpDef->AddInputDesc(scalar_tensor_desc);
  139. fIdentityOpDef->AddOutputDesc(scalar_tensor_desc);
  140. auto fIdentityNode = graph->AddNode(fIdentityOpDef);
  141. auto switch_op_def2 = std::make_shared<OpDesc>("Switch_1", SWITCH);
  142. switch_op_def2->AddInputDesc(scalar_tensor_desc);
  143. switch_op_def2->AddInputDesc(bool_tensor_desc);
  144. switch_op_def2->AddOutputDesc(scalar_tensor_desc);
  145. switch_op_def2->AddOutputDesc(scalar_tensor_desc);
  146. auto switch_node2 = graph->AddNode(switch_op_def2);
  147. auto mulOpDef = std::make_shared<OpDesc>("truediv", "Mul");
  148. mulOpDef->AddInputDesc(scalar_tensor_desc);
  149. mulOpDef->AddInputDesc(scalar_tensor_desc);
  150. mulOpDef->AddOutputDesc(scalar_tensor_desc);
  151. auto mulNode = graph->AddNode(mulOpDef);
  152. auto ceilOpDef = std::make_shared<OpDesc>("Ceil", "Ceil");
  153. ceilOpDef->AddInputDesc(scalar_tensor_desc);
  154. ceilOpDef->AddOutputDesc(scalar_tensor_desc);
  155. auto ceilNode = graph->AddNode(ceilOpDef);
  156. auto mergeOpDef = std::make_shared<OpDesc>("Merge", "Merge");
  157. mergeOpDef->AddInputDesc(scalar_tensor_desc);
  158. mergeOpDef->AddInputDesc(scalar_tensor_desc);
  159. mergeOpDef->AddOutputDesc(scalar_tensor_desc);
  160. mergeOpDef->AddOutputDesc(int_tensor_desc);
  161. auto mergeNode = graph->AddNode(mergeOpDef);
  162. auto output_op_def = std::make_shared<OpDesc>("NetOutput", "NetOutput");
  163. output_op_def->AddInputDesc(scalar_tensor_desc);
  164. output_op_def->AddOutputDesc(scalar_tensor_desc);
  165. auto output_node = graph->AddNode(output_op_def);
  166. (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), condNode->GetInDataAnchor(0));
  167. (void)GraphUtils::AddEdge(yNode->GetOutDataAnchor(0), condNode->GetInDataAnchor(1));
  168. (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), identityNode->GetInDataAnchor(0));
  169. (void)GraphUtils::AddEdge(identityNode->GetOutDataAnchor(0), switch_node1->GetInDataAnchor(0));
  170. (void)GraphUtils::AddEdge(identityNode->GetOutDataAnchor(0), switch_node1->GetInDataAnchor(1));
  171. (void)GraphUtils::AddEdge(switch_node1->GetOutDataAnchor(0), fIdentityNode->GetInDataAnchor(0));
  172. (void)GraphUtils::AddEdge(switch_node1->GetOutDataAnchor(1), tIdentityNode->GetInDataAnchor(0));
  173. (void)GraphUtils::AddEdge(fIdentityNode->GetOutControlAnchor(), zNode->GetInControlAnchor());
  174. (void)GraphUtils::AddEdge(tIdentityNode->GetOutControlAnchor(), constNode->GetInControlAnchor());
  175. (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), switch_node2->GetInDataAnchor(0));
  176. (void)GraphUtils::AddEdge(identityNode->GetOutDataAnchor(0), switch_node2->GetInDataAnchor(1));
  177. (void)GraphUtils::AddEdge(zNode->GetOutDataAnchor(0), mulNode->GetInDataAnchor(0));
  178. (void)GraphUtils::AddEdge(switch_node2->GetOutDataAnchor(0), mulNode->GetInDataAnchor(1));
  179. (void)GraphUtils::AddEdge(mulNode->GetOutDataAnchor(0), ceilNode->GetInDataAnchor(0));
  180. (void)GraphUtils::AddEdge(constNode->GetOutDataAnchor(0), mergeNode->GetInDataAnchor(1));
  181. (void)GraphUtils::AddEdge(ceilNode->GetOutDataAnchor(0), mergeNode->GetInDataAnchor(0));
  182. (void)GraphUtils::AddEdge(mergeNode->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));
  183. }
  184. void make_graph_cyclic_dependence(ComputeGraphPtr graph, bool match = true) {
  185. GeTensorDesc scalar_tensor_desc(GeShape({1, 1, 1, 1}));
  186. GeTensorDesc bool_tensor_desc(GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_BOOL);
  187. GeTensorDesc int_tensor_desc(GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_INT32);
  188. auto xOpDef = std::make_shared<OpDesc>("x", VARIABLEV2);
  189. xOpDef->AddOutputDesc(scalar_tensor_desc);
  190. auto xNode = graph->AddNode(xOpDef);
  191. auto yOpDef = std::make_shared<OpDesc>("y", VARIABLEV2);
  192. yOpDef->AddOutputDesc(scalar_tensor_desc);
  193. auto yNode = graph->AddNode(yOpDef);
  194. auto zOpDef = std::make_shared<OpDesc>("z", VARIABLEV2);
  195. zOpDef->AddOutputDesc(scalar_tensor_desc);
  196. auto zNode = graph->AddNode(zOpDef);
  197. auto condOpDef = std::make_shared<OpDesc>("Less", "Less");
  198. condOpDef->AddInputDesc(scalar_tensor_desc);
  199. condOpDef->AddInputDesc(scalar_tensor_desc);
  200. condOpDef->AddOutputDesc(bool_tensor_desc);
  201. auto condNode = graph->AddNode(condOpDef);
  202. auto switch_op_def1 = std::make_shared<OpDesc>("Switch_f_1", SWITCH);
  203. switch_op_def1->AddInputDesc(scalar_tensor_desc);
  204. switch_op_def1->AddInputDesc(bool_tensor_desc);
  205. switch_op_def1->AddOutputDesc(scalar_tensor_desc);
  206. switch_op_def1->AddOutputDesc(scalar_tensor_desc);
  207. auto switch_node1 = graph->AddNode(switch_op_def1);
  208. auto switch_op_def2 = std::make_shared<OpDesc>("Switch_t_1", SWITCH);
  209. switch_op_def2->AddInputDesc(scalar_tensor_desc);
  210. switch_op_def2->AddInputDesc(bool_tensor_desc);
  211. switch_op_def2->AddOutputDesc(scalar_tensor_desc);
  212. switch_op_def2->AddOutputDesc(scalar_tensor_desc);
  213. auto switch_node2 = graph->AddNode(switch_op_def2);
  214. auto switch_op_def3 = std::make_shared<OpDesc>("Switch_f_2", SWITCH);
  215. switch_op_def3->AddInputDesc(scalar_tensor_desc);
  216. switch_op_def3->AddInputDesc(bool_tensor_desc);
  217. switch_op_def3->AddOutputDesc(scalar_tensor_desc);
  218. switch_op_def3->AddOutputDesc(scalar_tensor_desc);
  219. auto switch_node3 = graph->AddNode(switch_op_def3);
  220. auto switch_op_def4 = std::make_shared<OpDesc>("Switch_t_2", SWITCH);
  221. switch_op_def4->AddInputDesc(scalar_tensor_desc);
  222. switch_op_def4->AddInputDesc(bool_tensor_desc);
  223. switch_op_def4->AddOutputDesc(scalar_tensor_desc);
  224. switch_op_def4->AddOutputDesc(scalar_tensor_desc);
  225. auto switch_node4 = graph->AddNode(switch_op_def4);
  226. auto squareOpDef1 = std::make_shared<OpDesc>("Square1", "Square");
  227. squareOpDef1->AddInputDesc(scalar_tensor_desc);
  228. squareOpDef1->AddOutputDesc(scalar_tensor_desc);
  229. auto squareNode1 = graph->AddNode(squareOpDef1);
  230. auto squareOpDef2 = std::make_shared<OpDesc>("Square2", "Square");
  231. squareOpDef2->AddInputDesc(scalar_tensor_desc);
  232. squareOpDef2->AddOutputDesc(scalar_tensor_desc);
  233. auto squareNode2 = graph->AddNode(squareOpDef2);
  234. auto squareOpDef3 = std::make_shared<OpDesc>("Square3", "Square");
  235. squareOpDef3->AddInputDesc(scalar_tensor_desc);
  236. squareOpDef3->AddOutputDesc(scalar_tensor_desc);
  237. auto squareNode3 = graph->AddNode(squareOpDef3);
  238. auto squareOpDef4 = std::make_shared<OpDesc>("Square4", "Square");
  239. squareOpDef4->AddInputDesc(scalar_tensor_desc);
  240. squareOpDef4->AddOutputDesc(scalar_tensor_desc);
  241. auto squareNode4 = graph->AddNode(squareOpDef4);
  242. auto merge_op_def1 = std::make_shared<OpDesc>("Merge1", "Merge");
  243. merge_op_def1->AddInputDesc(scalar_tensor_desc);
  244. merge_op_def1->AddInputDesc(scalar_tensor_desc);
  245. merge_op_def1->AddOutputDesc(scalar_tensor_desc);
  246. merge_op_def1->AddOutputDesc(int_tensor_desc);
  247. auto merge_node1 = graph->AddNode(merge_op_def1);
  248. auto merge_op_def2 = std::make_shared<OpDesc>("Merge2", "Merge");
  249. merge_op_def2->AddInputDesc(scalar_tensor_desc);
  250. merge_op_def2->AddInputDesc(scalar_tensor_desc);
  251. merge_op_def2->AddOutputDesc(scalar_tensor_desc);
  252. merge_op_def2->AddOutputDesc(int_tensor_desc);
  253. auto merge_node2 = graph->AddNode(merge_op_def2);
  254. auto output_op_def = std::make_shared<OpDesc>("NetOutput", "NetOutput");
  255. output_op_def->AddInputDesc(scalar_tensor_desc);
  256. output_op_def->AddOutputDesc(scalar_tensor_desc);
  257. auto output_node = graph->AddNode(output_op_def);
  258. (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), condNode->GetInDataAnchor(0));
  259. (void)GraphUtils::AddEdge(yNode->GetOutDataAnchor(0), condNode->GetInDataAnchor(1));
  260. (void)GraphUtils::AddEdge(zNode->GetOutDataAnchor(0), switch_node1->GetInDataAnchor(0));
  261. (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), switch_node1->GetInDataAnchor(1));
  262. (void)GraphUtils::AddEdge(zNode->GetOutDataAnchor(0), switch_node2->GetInDataAnchor(0));
  263. (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), switch_node2->GetInDataAnchor(1));
  264. (void)GraphUtils::AddEdge(switch_node1->GetOutDataAnchor(0), squareNode1->GetInDataAnchor(0));
  265. (void)GraphUtils::AddEdge(switch_node2->GetOutDataAnchor(1), squareNode2->GetInDataAnchor(0));
  266. (void)GraphUtils::AddEdge(squareNode1->GetOutDataAnchor(0), merge_node1->GetInDataAnchor(0));
  267. (void)GraphUtils::AddEdge(squareNode2->GetOutDataAnchor(0), merge_node1->GetInDataAnchor(1));
  268. (void)GraphUtils::AddEdge(merge_node1->GetOutDataAnchor(0), switch_node3->GetInDataAnchor(0));
  269. (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), switch_node3->GetInDataAnchor(1));
  270. (void)GraphUtils::AddEdge(zNode->GetOutDataAnchor(0), switch_node4->GetInDataAnchor(0));
  271. (void)GraphUtils::AddEdge(condNode->GetOutDataAnchor(0), switch_node4->GetInDataAnchor(1));
  272. (void)GraphUtils::AddEdge(switch_node3->GetOutDataAnchor(0), squareNode3->GetInDataAnchor(0));
  273. (void)GraphUtils::AddEdge(switch_node4->GetOutDataAnchor(1), squareNode4->GetInDataAnchor(0));
  274. (void)GraphUtils::AddEdge(squareNode3->GetOutDataAnchor(0), merge_node2->GetInDataAnchor(0));
  275. (void)GraphUtils::AddEdge(squareNode4->GetOutDataAnchor(0), merge_node2->GetInDataAnchor(1));
  276. (void)GraphUtils::AddEdge(merge_node2->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));
  277. }
  278. void make_graph_case(ComputeGraphPtr graph, bool match = true) {
  279. GeTensorDesc scalar_tensor_desc(GeShape({1, 1, 1, 1}));
  280. GeTensorDesc bool_tensor_desc(GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_BOOL);
  281. GeTensorDesc int_tensor_desc(GeShape({1, 1, 1, 1}), ge::FORMAT_NCHW, ge::DT_INT32);
  282. auto xOpDef = std::make_shared<OpDesc>("x", VARIABLEV2);
  283. xOpDef->AddOutputDesc(scalar_tensor_desc);
  284. auto xNode = graph->AddNode(xOpDef);
  285. auto yOpDef = std::make_shared<OpDesc>("y", VARIABLEV2);
  286. yOpDef->AddOutputDesc(scalar_tensor_desc);
  287. auto yNode = graph->AddNode(yOpDef);
  288. auto zOpDef = std::make_shared<OpDesc>("z", VARIABLEV2);
  289. zOpDef->AddOutputDesc(scalar_tensor_desc);
  290. auto zNode = graph->AddNode(zOpDef);
  291. auto greater_op_def = std::make_shared<OpDesc>("Greater", "Greater");
  292. greater_op_def->AddInputDesc(scalar_tensor_desc);
  293. greater_op_def->AddInputDesc(scalar_tensor_desc);
  294. greater_op_def->AddOutputDesc(bool_tensor_desc);
  295. auto greaterNode = graph->AddNode(greater_op_def);
  296. auto less_op_def = std::make_shared<OpDesc>("Less", "Less");
  297. less_op_def->AddInputDesc(scalar_tensor_desc);
  298. less_op_def->AddInputDesc(scalar_tensor_desc);
  299. less_op_def->AddOutputDesc(bool_tensor_desc);
  300. auto less_node = graph->AddNode(less_op_def);
  301. auto switch_op_def1 = std::make_shared<OpDesc>("greater/Switch_t", SWITCH);
  302. switch_op_def1->AddInputDesc(bool_tensor_desc);
  303. switch_op_def1->AddInputDesc(bool_tensor_desc);
  304. switch_op_def1->AddOutputDesc(bool_tensor_desc);
  305. switch_op_def1->AddOutputDesc(bool_tensor_desc);
  306. auto switch_node1 = graph->AddNode(switch_op_def1);
  307. auto switch_op_def2 = std::make_shared<OpDesc>("greater/Switch_f", SWITCH);
  308. switch_op_def2->AddInputDesc(scalar_tensor_desc);
  309. switch_op_def2->AddInputDesc(bool_tensor_desc);
  310. switch_op_def2->AddOutputDesc(scalar_tensor_desc);
  311. switch_op_def2->AddOutputDesc(scalar_tensor_desc);
  312. auto switch_node2 = graph->AddNode(switch_op_def2);
  313. auto switch_op_def3 = std::make_shared<OpDesc>("less/Switch_t", SWITCH);
  314. switch_op_def3->AddInputDesc(scalar_tensor_desc);
  315. switch_op_def3->AddInputDesc(bool_tensor_desc);
  316. switch_op_def3->AddOutputDesc(scalar_tensor_desc);
  317. switch_op_def3->AddOutputDesc(scalar_tensor_desc);
  318. auto switch_node3 = graph->AddNode(switch_op_def3);
  319. auto switch_op_def4 = std::make_shared<OpDesc>("less/Switch_f", SWITCH);
  320. switch_op_def4->AddInputDesc(scalar_tensor_desc);
  321. switch_op_def4->AddInputDesc(bool_tensor_desc);
  322. switch_op_def4->AddOutputDesc(scalar_tensor_desc);
  323. switch_op_def4->AddOutputDesc(scalar_tensor_desc);
  324. auto switch_node4 = graph->AddNode(switch_op_def4);
  325. auto merge_op_def1 = std::make_shared<OpDesc>("Merge1", "Merge");
  326. merge_op_def1->AddInputDesc(scalar_tensor_desc);
  327. merge_op_def1->AddInputDesc(scalar_tensor_desc);
  328. merge_op_def1->AddOutputDesc(scalar_tensor_desc);
  329. merge_op_def1->AddOutputDesc(int_tensor_desc);
  330. auto merge_node1 = graph->AddNode(merge_op_def1);
  331. auto merge_op_def2 = std::make_shared<OpDesc>("Merge2", "Merge");
  332. merge_op_def2->AddInputDesc(scalar_tensor_desc);
  333. merge_op_def2->AddInputDesc(scalar_tensor_desc);
  334. merge_op_def2->AddOutputDesc(scalar_tensor_desc);
  335. merge_op_def2->AddOutputDesc(int_tensor_desc);
  336. auto merge_node2 = graph->AddNode(merge_op_def2);
  337. auto output_op_def = std::make_shared<OpDesc>("NetOutput", "NetOutput");
  338. output_op_def->AddInputDesc(scalar_tensor_desc);
  339. output_op_def->AddOutputDesc(scalar_tensor_desc);
  340. auto output_node = graph->AddNode(output_op_def);
  341. (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), greaterNode->GetInDataAnchor(0));
  342. (void)GraphUtils::AddEdge(yNode->GetOutDataAnchor(0), greaterNode->GetInDataAnchor(1));
  343. (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), less_node->GetInDataAnchor(0));
  344. (void)GraphUtils::AddEdge(yNode->GetOutDataAnchor(0), less_node->GetInDataAnchor(1));
  345. (void)GraphUtils::AddEdge(xNode->GetOutDataAnchor(0), switch_node1->GetInDataAnchor(0));
  346. (void)GraphUtils::AddEdge(greaterNode->GetOutDataAnchor(0), switch_node1->GetInDataAnchor(1));
  347. (void)GraphUtils::AddEdge(less_node->GetOutDataAnchor(0), switch_node2->GetInDataAnchor(0));
  348. (void)GraphUtils::AddEdge(greaterNode->GetOutDataAnchor(0), switch_node2->GetInDataAnchor(1));
  349. (void)GraphUtils::AddEdge(yNode->GetOutDataAnchor(0), switch_node3->GetInDataAnchor(0));
  350. (void)GraphUtils::AddEdge(switch_node2->GetOutDataAnchor(0), switch_node3->GetInDataAnchor(1));
  351. (void)GraphUtils::AddEdge(zNode->GetOutDataAnchor(0), switch_node4->GetInDataAnchor(0));
  352. (void)GraphUtils::AddEdge(switch_node2->GetOutDataAnchor(0), switch_node4->GetInDataAnchor(1));
  353. (void)GraphUtils::AddEdge(switch_node3->GetOutDataAnchor(1), merge_node1->GetInDataAnchor(0));
  354. (void)GraphUtils::AddEdge(switch_node4->GetOutDataAnchor(0), merge_node1->GetInDataAnchor(1));
  355. (void)GraphUtils::AddEdge(switch_node1->GetOutDataAnchor(1), merge_node2->GetInDataAnchor(0));
  356. (void)GraphUtils::AddEdge(merge_node1->GetOutDataAnchor(0), merge_node2->GetInDataAnchor(1));
  357. (void)GraphUtils::AddEdge(merge_node2->GetOutDataAnchor(0), output_node->GetInDataAnchor(0));
  358. }
  359. };

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：