You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

prune_pass_unittest.cc 20 kB

5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago

  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <gtest/gtest.h>
  17. #include <vector>
  18. #include "omg/omg_inner_types.h"
  19. #define protected public
  20. #define private public
  21. #include "graph/passes/prune_pass.h"
  22. #include "anchor.h"
  23. #include "common/debug/log.h"
  24. #include "common/debug/memory_dumper.h"
  25. #include "common/op/attr_value_util.h"
  26. #include "common/types.h"
  27. #include "framework/common/ge_inner_error_codes.h"
  28. #include "graph/attr_value.h"
  29. #include "graph/debug/ge_attr_define.h"
  30. #include "inc/pass_manager.h"
  31. #undef protected
  32. #undef private
  33. using namespace testing;
  34. using namespace ge;
  35. using namespace std;
  36. class UtestGraphPassesPrunePass : public testing::Test {
  37. protected:
  38. void SetUp() {}
  39. void TearDown() {}
  40. };
  41. // case1:no net_out_put_node
  42. TEST_F(UtestGraphPassesPrunePass, no_net_out_put_node) {
  43. ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
  44. ge::OpDescPtr reverse_op = std::make_shared<ge::OpDesc>();
  45. reverse_op->SetType(REVERSE);
  46. reverse_op->SetName("Reverse");
  47. reverse_op->AddOutputDesc(ge::GeTensorDesc());
  48. ge::NodePtr reverse_node = graph->AddNode(reverse_op);
  49. ge::OpDescPtr floor_op = std::make_shared<ge::OpDesc>();
  50. floor_op->SetType(FLOOR);
  51. floor_op->SetName("Floor");
  52. floor_op->AddInputDesc(ge::GeTensorDesc());
  53. floor_op->AddOutputDesc(ge::GeTensorDesc());
  54. ge::NodePtr floor_node = graph->AddNode(floor_op);
  55. ge::GraphUtils::AddEdge(reverse_node->GetOutDataAnchor(0), floor_node->GetInDataAnchor(0));
  56. uint64_t size_ori = graph->GetDirectNode().size();
  57. PrunePass prune_pass;
  58. std::vector<std::pair<string, GraphPass*>> passes = { {"prune_pass", &prune_pass} };
  59. Status status = PassManager::Run(graph, passes);
  60. EXPECT_EQ(ge::SUCCESS, status);
  61. uint64_t size = graph->GetDirectNode().size();
  62. EXPECT_EQ(size, size_ori);
  63. }
  64. // case2: one net path with one bypass branch
  65. TEST_F(UtestGraphPassesPrunePass, has_net_out_put_node_with_only_one_path) {
  66. ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
  67. ge::OpDescPtr reverse_op = std::make_shared<ge::OpDesc>();
  68. reverse_op->SetType(REVERSE);
  69. reverse_op->SetName("Reverse");
  70. reverse_op->AddOutputDesc(ge::GeTensorDesc());
  71. ge::NodePtr reverse_node = graph->AddNode(reverse_op);
  72. ge::OpDescPtr floor_op = std::make_shared<ge::OpDesc>();
  73. floor_op->SetType(FLOOR);
  74. floor_op->SetName("Floor");
  75. floor_op->AddInputDesc(ge::GeTensorDesc());
  76. floor_op->AddOutputDesc(ge::GeTensorDesc());
  77. ge::NodePtr floor_node = graph->AddNode(floor_op);
  78. ge::OpDescPtr net_output_op = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
  79. net_output_op->AddInputDesc(ge::GeTensorDesc());
  80. net_output_op->AddOutputDesc(ge::GeTensorDesc());
  81. ge::AttrUtils::SetBool(net_output_op, "identity_add_netoutput", true);
  82. ge::NodePtr netoutput_node = graph->AddNode(net_output_op);
  83. ge::OpDescPtr reverse_op1 = std::make_shared<ge::OpDesc>();
  84. reverse_op->SetType(REVERSE);
  85. reverse_op->SetName("Reverse1");
  86. reverse_op->AddOutputDesc(ge::GeTensorDesc());
  87. ge::NodePtr reverse_node1 = graph->AddNode(reverse_op1);
  88. ge::GraphUtils::AddEdge(reverse_node->GetOutDataAnchor(0), floor_node->GetInDataAnchor(0));
  89. ge::GraphUtils::AddEdge(floor_node->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(0));
  90. uint64_t size_ori = graph->GetDirectNode().size();
  91. PrunePass prune_pass;
  92. std::vector<std::pair<string, GraphPass*>> passes = { {"prune_pass", &prune_pass} };
  93. Status status = PassManager::Run(graph, passes);
  94. uint64_t size = graph->GetDirectNode().size();
  95. int diff = size_ori - size;
  96. EXPECT_EQ(ge::SUCCESS, status);
  97. EXPECT_EQ(diff, 1);
  98. }
  99. // case3: one net path with one bypass branch
  100. TEST_F(UtestGraphPassesPrunePass, has_net_out_put_node_with_one_valid_path_and_one_bypass_path) {
  101. ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
  102. // valid path construct (reverse->floor->net_out)
  103. ge::OpDescPtr reverse_op = std::make_shared<ge::OpDesc>();
  104. reverse_op->SetType(REVERSE);
  105. reverse_op->SetName("Reverse");
  106. reverse_op->AddOutputDesc(ge::GeTensorDesc());
  107. reverse_op->AddOutputDesc(ge::GeTensorDesc());
  108. ge::NodePtr reverse_node = graph->AddNode(reverse_op);
  109. ge::OpDescPtr floor_op = std::make_shared<ge::OpDesc>();
  110. floor_op->SetType(FLOOR);
  111. floor_op->SetName("Floor");
  112. floor_op->AddInputDesc(ge::GeTensorDesc());
  113. floor_op->AddOutputDesc(ge::GeTensorDesc());
  114. ge::NodePtr floor_node = graph->AddNode(floor_op);
  115. ge::OpDescPtr net_output_op = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
  116. net_output_op->AddInputDesc(ge::GeTensorDesc());
  117. net_output_op->AddOutputDesc(ge::GeTensorDesc());
  118. ge::AttrUtils::SetBool(net_output_op, "identity_add_netoutput", true);
  119. ge::NodePtr netoutput_node = graph->AddNode(net_output_op);
  120. ge::GraphUtils::AddEdge(reverse_node->GetOutDataAnchor(0), floor_node->GetInDataAnchor(0));
  121. ge::GraphUtils::AddEdge(floor_node->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(0));
  122. // incvalid path construct (reverse->floor1->floor2)
  123. ge::OpDescPtr floor_op1 = std::make_shared<ge::OpDesc>();
  124. floor_op1->SetType(FLOOR);
  125. floor_op1->SetName("Floor1");
  126. floor_op1->AddInputDesc(ge::GeTensorDesc());
  127. floor_op1->AddOutputDesc(ge::GeTensorDesc());
  128. ge::NodePtr floor_node1 = graph->AddNode(floor_op1);
  129. ge::OpDescPtr floor_op2 = std::make_shared<ge::OpDesc>();
  130. floor_op2->SetType(FLOOR);
  131. floor_op2->SetName("Floor2");
  132. floor_op2->AddInputDesc(ge::GeTensorDesc());
  133. floor_op2->AddOutputDesc(ge::GeTensorDesc());
  134. ge::NodePtr floor_node2 = graph->AddNode(floor_op2);
  135. // isolated node
  136. ge::OpDescPtr floor_op3 = std::make_shared<ge::OpDesc>();
  137. floor_op3->SetType(FLOOR);
  138. floor_op3->SetName("Floor3");
  139. floor_op3->AddInputDesc(ge::GeTensorDesc());
  140. floor_op3->AddOutputDesc(ge::GeTensorDesc());
  141. ge::NodePtr floor_node3 = graph->AddNode(floor_op3);
  142. ge::GraphUtils::AddEdge(reverse_node->GetOutDataAnchor(1), floor_node1->GetInDataAnchor(0));
  143. ge::GraphUtils::AddEdge(floor_node1->GetOutDataAnchor(0), floor_node2->GetInDataAnchor(0));
  144. uint64_t size_ori = graph->GetDirectNode().size();
  145. PrunePass prune_pass;
  146. vector<GraphPass *> passes = {&prune_pass};
  147. }
  148. // case 4: multi net path with one common netout(1:multi:1)
  149. TEST_F(UtestGraphPassesPrunePass, has_net_out_put_node_with_multi_path) {
  150. ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
  151. ge::OpDescPtr data_op = std::make_shared<ge::OpDesc>();
  152. data_op->SetType(DATA);
  153. data_op->SetName("data");
  154. data_op->AddOutputDesc(ge::GeTensorDesc());
  155. data_op->AddOutputDesc(ge::GeTensorDesc());
  156. data_op->AddOutputDesc(ge::GeTensorDesc());
  157. ge::NodePtr data_node = graph->AddNode(data_op);
  158. ge::OpDescPtr reverse_op1 = std::make_shared<ge::OpDesc>();
  159. reverse_op1->SetType(REVERSE);
  160. reverse_op1->SetName("Reverse1");
  161. reverse_op1->AddInputDesc(ge::GeTensorDesc());
  162. reverse_op1->AddOutputDesc(ge::GeTensorDesc());
  163. ge::NodePtr reverse_node1 = graph->AddNode(reverse_op1);
  164. ge::OpDescPtr floor_op1 = std::make_shared<ge::OpDesc>();
  165. floor_op1->SetType(FLOOR);
  166. floor_op1->SetName("Floor1");
  167. floor_op1->AddInputDesc(ge::GeTensorDesc());
  168. floor_op1->AddOutputDesc(ge::GeTensorDesc());
  169. ge::NodePtr floor_node1 = graph->AddNode(floor_op1);
  170. ge::OpDescPtr reverse_op2 = std::make_shared<ge::OpDesc>();
  171. reverse_op2->SetType(REVERSE);
  172. reverse_op2->SetName("Reverse2");
  173. reverse_op2->AddInputDesc(ge::GeTensorDesc());
  174. reverse_op2->AddOutputDesc(ge::GeTensorDesc());
  175. ge::NodePtr reverse_node2 = graph->AddNode(reverse_op2);
  176. ge::OpDescPtr floor_op2 = std::make_shared<ge::OpDesc>();
  177. floor_op2->SetType(FLOOR);
  178. floor_op2->SetName("Floor2");
  179. floor_op2->AddInputDesc(ge::GeTensorDesc());
  180. floor_op2->AddOutputDesc(ge::GeTensorDesc());
  181. ge::NodePtr floor_node2 = graph->AddNode(floor_op2);
  182. ge::OpDescPtr reverse_op3 = std::make_shared<ge::OpDesc>();
  183. reverse_op3->SetType(REVERSE);
  184. reverse_op3->SetName("Reverse3");
  185. reverse_op3->AddInputDesc(ge::GeTensorDesc());
  186. reverse_op3->AddOutputDesc(ge::GeTensorDesc());
  187. ge::NodePtr reverse_node3 = graph->AddNode(reverse_op3);
  188. ge::OpDescPtr floor_op3 = std::make_shared<ge::OpDesc>();
  189. floor_op3->SetType(FLOOR);
  190. floor_op3->SetName("Floor3");
  191. floor_op3->AddInputDesc(ge::GeTensorDesc());
  192. floor_op3->AddOutputDesc(ge::GeTensorDesc());
  193. ge::NodePtr floor_node3 = graph->AddNode(floor_op3);
  194. ge::OpDescPtr net_output_op = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
  195. net_output_op->AddInputDesc(ge::GeTensorDesc());
  196. net_output_op->AddInputDesc(ge::GeTensorDesc());
  197. net_output_op->AddInputDesc(ge::GeTensorDesc());
  198. net_output_op->AddOutputDesc(ge::GeTensorDesc());
  199. ge::AttrUtils::SetBool(net_output_op, "identity_add_netoutput", true);
  200. ge::NodePtr netoutput_node = graph->AddNode(net_output_op);
  201. ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), reverse_node1->GetInDataAnchor(0));
  202. ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(1), reverse_node2->GetInDataAnchor(0));
  203. ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(2), reverse_node3->GetInDataAnchor(0));
  204. ge::GraphUtils::AddEdge(reverse_node1->GetOutDataAnchor(0), floor_node1->GetInDataAnchor(0));
  205. ge::GraphUtils::AddEdge(floor_node1->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(0));
  206. ge::GraphUtils::AddEdge(reverse_node2->GetOutDataAnchor(0), floor_node2->GetInDataAnchor(0));
  207. ge::GraphUtils::AddEdge(floor_node2->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(1));
  208. ge::GraphUtils::AddEdge(reverse_node3->GetOutDataAnchor(0), floor_node3->GetInDataAnchor(0));
  209. ge::GraphUtils::AddEdge(floor_node3->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(2));
  210. uint64_t size_ori = graph->GetDirectNode().size();
  211. PrunePass prune_pass;
  212. std::vector<std::pair<string, GraphPass*>> passes = { {"prune_pass", &prune_pass} };
  213. Status status = PassManager::Run(graph, passes);
  214. uint64_t size_after_proc = graph->GetDirectNode().size();
  215. EXPECT_EQ(size_ori, size_after_proc);
  216. }
  217. // case 5: circle,diamand style
  218. TEST_F(UtestGraphPassesPrunePass, multi_net_out_put_node_with_circle_net) {
  219. ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
  220. ge::OpDescPtr data_op = std::make_shared<ge::OpDesc>();
  221. data_op->SetType(DATA);
  222. data_op->SetName("data");
  223. data_op->AddOutputDesc(ge::GeTensorDesc());
  224. data_op->AddOutputDesc(ge::GeTensorDesc());
  225. ge::NodePtr data_node = graph->AddNode(data_op);
  226. ge::OpDescPtr op_1 = std::make_shared<ge::OpDesc>();
  227. op_1->SetType(REVERSE);
  228. op_1->SetName("Reverse1");
  229. op_1->AddInputDesc(ge::GeTensorDesc());
  230. op_1->AddOutputDesc(ge::GeTensorDesc());
  231. op_1->AddOutputDesc(ge::GeTensorDesc());
  232. ge::NodePtr node_1 = graph->AddNode(op_1);
  233. ge::OpDescPtr op_2 = std::make_shared<ge::OpDesc>();
  234. op_2->SetType(REVERSE);
  235. op_2->SetName("Reverse2");
  236. op_2->AddInputDesc(ge::GeTensorDesc());
  237. op_2->AddInputDesc(ge::GeTensorDesc());
  238. op_2->AddOutputDesc(ge::GeTensorDesc());
  239. ge::NodePtr node_2 = graph->AddNode(op_2);
  240. ge::OpDescPtr op_3 = std::make_shared<ge::OpDesc>();
  241. op_3->SetType(REVERSE);
  242. op_3->SetName("Reverse3");
  243. op_3->AddInputDesc(ge::GeTensorDesc());
  244. op_3->AddInputDesc(ge::GeTensorDesc());
  245. op_3->AddOutputDesc(ge::GeTensorDesc());
  246. ge::NodePtr node_3 = graph->AddNode(op_3);
  247. ge::OpDescPtr op_4 = std::make_shared<ge::OpDesc>();
  248. op_4->SetType(REVERSE);
  249. op_4->SetName("Reverse4");
  250. op_4->AddInputDesc(ge::GeTensorDesc());
  251. op_4->AddOutputDesc(ge::GeTensorDesc());
  252. ge::NodePtr node_4 = graph->AddNode(op_4);
  253. ge::OpDescPtr op_5 = std::make_shared<ge::OpDesc>();
  254. op_5->SetType(REVERSE);
  255. op_5->SetName("Reverse5");
  256. op_5->AddInputDesc(ge::GeTensorDesc());
  257. op_5->AddOutputDesc(ge::GeTensorDesc());
  258. ge::NodePtr node_5 = graph->AddNode(op_5);
  259. ge::OpDescPtr net_output_op = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
  260. net_output_op->AddInputDesc(ge::GeTensorDesc());
  261. net_output_op->AddOutputDesc(ge::GeTensorDesc());
  262. ge::AttrUtils::SetBool(net_output_op, "identity_add_netoutput", true);
  263. ge::NodePtr netoutput_node = graph->AddNode(net_output_op);
  264. ge::GraphUtils::AddEdge(node_1->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(0));
  265. ge::GraphUtils::AddEdge(node_2->GetOutDataAnchor(0), node_1->GetInDataAnchor(0));
  266. ge::GraphUtils::AddEdge(node_3->GetOutDataAnchor(0), node_2->GetInDataAnchor(0));
  267. ge::GraphUtils::AddEdge(node_4->GetOutDataAnchor(0), node_3->GetInDataAnchor(0));
  268. ge::GraphUtils::AddEdge(node_1->GetOutDataAnchor(1), node_4->GetInDataAnchor(0));
  269. ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node_2->GetInDataAnchor(1));
  270. ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(1), node_5->GetInDataAnchor(0));
  271. ge::GraphUtils::AddEdge(node_5->GetOutDataAnchor(0), node_3->GetInDataAnchor(1));
  272. uint64_t size_ori = graph->GetDirectNode().size();
  273. PrunePass prune_pass;
  274. std::vector<std::pair<string, GraphPass*>> passes = { {"prune_pass", &prune_pass} };
  275. Status status = PassManager::Run(graph, passes);
  276. EXPECT_EQ(ge::SUCCESS, status);
  277. uint64_t size_after_proc = graph->GetDirectNode().size();
  278. EXPECT_EQ(size_ori, size_after_proc);
  279. }
  280. // case 6: two mix circle and multi path,diamand style
  281. TEST_F(UtestGraphPassesPrunePass, mix_two_circle_net) {
  282. ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
  283. ge::OpDescPtr data_op = std::make_shared<ge::OpDesc>();
  284. data_op->SetType(DATA);
  285. data_op->SetName("data");
  286. data_op->AddOutputDesc(ge::GeTensorDesc());
  287. data_op->AddOutputDesc(ge::GeTensorDesc());
  288. ge::NodePtr data_node = graph->AddNode(data_op);
  289. ge::OpDescPtr op_1 = std::make_shared<ge::OpDesc>();
  290. op_1->SetType(REVERSE);
  291. op_1->SetName("Reverse1");
  292. op_1->AddInputDesc(ge::GeTensorDesc());
  293. op_1->AddInputDesc(ge::GeTensorDesc());
  294. op_1->AddOutputDesc(ge::GeTensorDesc());
  295. ge::NodePtr node_1 = graph->AddNode(op_1);
  296. ge::OpDescPtr op_2 = std::make_shared<ge::OpDesc>();
  297. op_2->SetType(REVERSE);
  298. op_2->SetName("Reverse2");
  299. op_2->AddInputDesc(ge::GeTensorDesc());
  300. op_2->AddOutputDesc(ge::GeTensorDesc());
  301. op_2->AddOutputDesc(ge::GeTensorDesc());
  302. ge::NodePtr node_2 = graph->AddNode(op_2);
  303. ge::OpDescPtr op_3 = std::make_shared<ge::OpDesc>();
  304. op_3->SetType(REVERSE);
  305. op_3->SetName("Reverse3");
  306. op_3->AddInputDesc(ge::GeTensorDesc());
  307. op_3->AddInputDesc(ge::GeTensorDesc());
  308. op_3->AddOutputDesc(ge::GeTensorDesc());
  309. ge::NodePtr node_3 = graph->AddNode(op_3);
  310. ge::OpDescPtr op_4 = std::make_shared<ge::OpDesc>();
  311. op_4->SetType(REVERSE);
  312. op_4->SetName("Reverse4");
  313. op_4->AddInputDesc(ge::GeTensorDesc());
  314. op_4->AddInputDesc(ge::GeTensorDesc());
  315. op_4->AddOutputDesc(ge::GeTensorDesc());
  316. ge::NodePtr node_4 = graph->AddNode(op_4);
  317. ge::OpDescPtr op_5 = std::make_shared<ge::OpDesc>();
  318. op_5->SetType(REVERSE);
  319. op_5->SetName("Reverse5");
  320. op_5->AddInputDesc(ge::GeTensorDesc());
  321. op_5->AddOutputDesc(ge::GeTensorDesc());
  322. op_5->AddOutputDesc(ge::GeTensorDesc());
  323. ge::NodePtr node_5 = graph->AddNode(op_5);
  324. ge::OpDescPtr net_output_op = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
  325. net_output_op->AddInputDesc(ge::GeTensorDesc());
  326. net_output_op->AddOutputDesc(ge::GeTensorDesc());
  327. ge::AttrUtils::SetBool(net_output_op, "identity_add_netoutput", true);
  328. ge::NodePtr netoutput_node = graph->AddNode(net_output_op);
  329. ge::GraphUtils::AddEdge(node_1->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(0));
  330. ge::GraphUtils::AddEdge(node_2->GetOutDataAnchor(0), node_1->GetInDataAnchor(0));
  331. ge::GraphUtils::AddEdge(node_5->GetOutDataAnchor(0), node_1->GetInDataAnchor(1));
  332. ge::GraphUtils::AddEdge(node_4->GetOutDataAnchor(0), node_2->GetInDataAnchor(0));
  333. ge::GraphUtils::AddEdge(node_2->GetOutDataAnchor(1), node_3->GetInDataAnchor(0));
  334. ge::GraphUtils::AddEdge(node_5->GetOutDataAnchor(1), node_3->GetInDataAnchor(1));
  335. ge::GraphUtils::AddEdge(node_3->GetOutDataAnchor(0), node_4->GetInDataAnchor(0));
  336. ge::GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), node_4->GetInDataAnchor(1));
  337. ge::GraphUtils::AddEdge(node_4->GetOutDataAnchor(1), node_5->GetInDataAnchor(0));
  338. // construct two isolated node
  339. ge::OpDescPtr op_6 = std::make_shared<ge::OpDesc>();
  340. op_6->SetType(REVERSE);
  341. op_6->SetName("Reverse");
  342. op_6->AddInputDesc(ge::GeTensorDesc());
  343. op_6->AddOutputDesc(ge::GeTensorDesc());
  344. ge::NodePtr node_6 = graph->AddNode(op_6);
  345. ge::OpDescPtr op_7 = std::make_shared<ge::OpDesc>();
  346. op_7->SetType(REVERSE);
  347. op_7->SetName("Reverse");
  348. op_7->AddInputDesc(ge::GeTensorDesc());
  349. op_7->AddOutputDesc(ge::GeTensorDesc());
  350. ge::NodePtr node_7 = graph->AddNode(op_7);
  351. uint64_t size_ori = graph->GetDirectNode().size();
  352. PrunePass prune_pass;
  353. vector<GraphPass *> passes = {&prune_pass};
  354. }
  355. // case7: one net path with two DATA node
  356. TEST_F(UtestGraphPassesPrunePass, has_net_out_put_node_with_two_isolate_data_node) {
  357. ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
  358. ge::OpDescPtr reverse_op = std::make_shared<ge::OpDesc>();
  359. reverse_op->SetType(REVERSE);
  360. reverse_op->SetName("Reverse");
  361. reverse_op->AddOutputDesc(ge::GeTensorDesc());
  362. ge::NodePtr reverse_node = graph->AddNode(reverse_op);
  363. ge::OpDescPtr floor_op = std::make_shared<ge::OpDesc>();
  364. floor_op->SetType(FLOOR);
  365. floor_op->SetName("Floor");
  366. floor_op->AddInputDesc(ge::GeTensorDesc());
  367. floor_op->AddOutputDesc(ge::GeTensorDesc());
  368. ge::NodePtr floor_node = graph->AddNode(floor_op);
  369. ge::OpDescPtr net_output_op = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
  370. net_output_op->AddInputDesc(ge::GeTensorDesc());
  371. net_output_op->AddOutputDesc(ge::GeTensorDesc());
  372. ge::AttrUtils::SetBool(net_output_op, "identity_add_netoutput", true);
  373. ge::NodePtr netoutput_node = graph->AddNode(net_output_op);
  374. // construct one isolated DATA node (to be deleted)
  375. ge::OpDescPtr reverse_op_1 = std::make_shared<ge::OpDesc>();
  376. reverse_op_1->SetType(REVERSE);
  377. reverse_op_1->SetName("Reverse1");
  378. reverse_op_1->AddOutputDesc(ge::GeTensorDesc());
  379. ge::NodePtr reverse_node_1 = graph->AddNode(reverse_op_1);
  380. ge::GraphUtils::AddEdge(reverse_node->GetOutDataAnchor(0), floor_node->GetInDataAnchor(0));
  381. ge::GraphUtils::AddEdge(floor_node->GetOutDataAnchor(0), netoutput_node->GetInDataAnchor(0));
  382. // construct two isolated DATA nodes(to be not deleted)
  383. ge::OpDescPtr data_op_1 = std::make_shared<ge::OpDesc>();
  384. data_op_1->SetType(DATA);
  385. data_op_1->SetName("data");
  386. data_op_1->AddOutputDesc(ge::GeTensorDesc());
  387. data_op_1->AddOutputDesc(ge::GeTensorDesc());
  388. ge::NodePtr data_node_1 = graph->AddNode(data_op_1);
  389. ge::OpDescPtr data_op_2 = std::make_shared<ge::OpDesc>();
  390. data_op_2->SetType(DATA);
  391. data_op_2->SetName("data1");
  392. data_op_2->AddOutputDesc(ge::GeTensorDesc());
  393. data_op_2->AddOutputDesc(ge::GeTensorDesc());
  394. ge::NodePtr data_node = graph->AddNode(data_op_2);
  395. uint64_t size_ori = graph->GetDirectNode().size();
  396. PrunePass prune_pass;
  397. std::vector<std::pair<string, GraphPass*>> passes = { {"prune_pass", &prune_pass} };
  398. Status status = PassManager::Run(graph, passes);
  399. uint64_t size = graph->GetDirectNode().size();
  400. EXPECT_EQ(ge::SUCCESS, status);
  401. EXPECT_EQ(size_ori, (size + 1));
  402. // it should check net_out_put's input data node and input control node
  403. auto control_vec = netoutput_node->GetInControlNodes();
  404. EXPECT_EQ(control_vec.size(), 2);
  405. // check control_vec contains only data node
  406. for (auto node : control_vec) {
  407. bool result = (node->GetName() == "data" || node->GetName() == "data1") ? true : false;
  408. EXPECT_EQ(result, true);
  409. }
  410. auto data_vec = netoutput_node->GetInDataNodes();
  411. EXPECT_EQ(data_vec.size(), 1);
  412. // check data_vec contains only Floor node
  413. for (auto node : data_vec) {
  414. bool result = (node->GetName() == "Floor") ? true : false;
  415. EXPECT_EQ(result, true);
  416. }
  417. }

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示