You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

net_output_pass_unittest.cc 36 kB

5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/net_output_pass.h"
  17. #include <gtest/gtest.h>
  18. #include "common/ge_inner_error_codes.h"
  19. #include "common/types.h"
  20. #include "ge/ge_api.h"
  21. #include "graph/compute_graph.h"
  22. #include "graph/debug/graph_debug.h"
  23. #include "graph/manager/graph_manager.h"
  24. #include "graph/manager/graph_manager_utils.h"
  25. #include "graph/operator_reg.h"
  26. #include "graph/utils/op_desc_utils.h"
  27. #include "inc/pass_manager.h"
  28. #include "init/gelib.h"
  29. #include "opskernel_manager/ops_kernel_manager.h"
  30. using namespace std;
  31. using namespace testing;
  32. using namespace ge;
  33. class UtestGraphPassesNetOutputPass : public testing::Test {
  34. protected:
  35. void SetUp() {}
  36. void TearDown() {}
  37. };
  38. ge::ComputeGraphPtr BuildClearWeightGraph(void) {
  39. ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
  40. ge::OpDescPtr cast_op = std::make_shared<ge::OpDesc>();
  41. cast_op->SetType(CAST);
  42. cast_op->SetName("Cast1");
  43. cast_op->AddInputDesc(ge::GeTensorDesc());
  44. cast_op->AddOutputDesc(ge::GeTensorDesc());
  45. ge::NodePtr cast_node = graph->AddNode(cast_op);
  46. ge::OpDescPtr const_op = std::make_shared<ge::OpDesc>();
  47. const_op->SetType(CONSTANT);
  48. const_op->SetName("Const1");
  49. const_op->AddOutputDesc(ge::GeTensorDesc());
  50. ge::NodePtr const_node = graph->AddNode(const_op);
  51. ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), cast_node->GetInDataAnchor(0));
  52. return graph;
  53. }
  54. ge::ComputeGraphPtr build_graph(bool with_leaf_node = false) {
  55. ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("default");
  56. ge::OpDescPtr data_op = std::make_shared<ge::OpDesc>();
  57. data_op->SetType(DATA);
  58. data_op->SetName("Data1");
  59. data_op->AddInputDesc(ge::GeTensorDesc());
  60. data_op->AddOutputDesc(ge::GeTensorDesc());
  61. ge::NodePtr data1 = graph->AddNode(data_op);
  62. ge::OpDescPtr relu_op1 = std::make_shared<ge::OpDesc>();
  63. relu_op1->SetType(ACTIVATION);
  64. relu_op1->SetName("Relu1");
  65. relu_op1->AddInputDesc(ge::GeTensorDesc());
  66. relu_op1->AddOutputDesc(ge::GeTensorDesc());
  67. ge::NodePtr relu1 = graph->AddNode(relu_op1);
  68. ge::OpDescPtr relu_op2 = std::make_shared<ge::OpDesc>();
  69. relu_op2->SetType(RELU);
  70. relu_op2->SetName("Relu2");
  71. relu_op2->AddInputDesc(ge::GeTensorDesc());
  72. relu_op2->AddOutputDesc(ge::GeTensorDesc());
  73. relu_op2->AddOutputDesc(ge::GeTensorDesc());
  74. ge::NodePtr relu2 = graph->AddNode(relu_op2);
  75. ge::OpDescPtr relu_op3 = std::make_shared<ge::OpDesc>();
  76. relu_op3->SetType(ACTIVATION);
  77. relu_op3->SetName("Relu3");
  78. relu_op3->AddInputDesc(ge::GeTensorDesc());
  79. relu_op3->AddOutputDesc(ge::GeTensorDesc());
  80. ge::NodePtr relu3;
  81. if (with_leaf_node == true) {
  82. relu3 = graph->AddNode(relu_op3);
  83. }
  84. ge::OpDescPtr mul_op = std::make_shared<ge::OpDesc>();
  85. mul_op->SetType(MUL);
  86. mul_op->SetName("Mul");
  87. mul_op->AddInputDesc(ge::GeTensorDesc());
  88. mul_op->AddInputDesc(ge::GeTensorDesc());
  89. mul_op->AddOutputDesc(ge::GeTensorDesc());
  90. mul_op->AddOutputDesc(ge::GeTensorDesc());
  91. mul_op->AddOutputDesc(ge::GeTensorDesc());
  92. mul_op->AddOutputDesc(ge::GeTensorDesc());
  93. ge::NodePtr mul = graph->AddNode(mul_op);
  94. ge::OpDescPtr mul_op1 = std::make_shared<ge::OpDesc>();
  95. mul_op1->SetType(MUL);
  96. mul_op1->SetName("Mul1");
  97. mul_op1->AddInputDesc(ge::GeTensorDesc());
  98. mul_op1->AddInputDesc(ge::GeTensorDesc());
  99. mul_op1->AddOutputDesc(ge::GeTensorDesc());
  100. ge::NodePtr mul1 = graph->AddNode(mul_op1);
  101. ge::OpDescPtr mul_op2 = std::make_shared<ge::OpDesc>();
  102. mul_op2->SetType(MUL);
  103. mul_op2->SetName("Mul2");
  104. mul_op2->AddInputDesc(ge::GeTensorDesc());
  105. mul_op2->AddInputDesc(ge::GeTensorDesc());
  106. mul_op2->AddOutputDesc(ge::GeTensorDesc());
  107. ge::NodePtr mul2 = graph->AddNode(mul_op2);
  108. ge::OpDescPtr fc_op = std::make_shared<ge::OpDesc>();
  109. fc_op->SetType(FULL_CONNECTION);
  110. fc_op->SetName("FullConnection");
  111. fc_op->AddInputDesc(ge::GeTensorDesc());
  112. fc_op->AddOutputDesc(ge::GeTensorDesc());
  113. fc_op->AddOutputDesc(ge::GeTensorDesc());
  114. ge::NodePtr fc = graph->AddNode(fc_op);
  115. ge::GraphUtils::AddEdge(data1->GetOutDataAnchor(0), relu1->GetInDataAnchor(0));
  116. ge::GraphUtils::AddEdge(relu1->GetOutDataAnchor(0), fc->GetInDataAnchor(0));
  117. ge::GraphUtils::AddEdge(fc->GetOutDataAnchor(0), relu2->GetInDataAnchor(0));
  118. if (with_leaf_node == true) {
  119. ge::GraphUtils::AddEdge(fc->GetOutDataAnchor(1), relu3->GetInDataAnchor(0));
  120. }
  121. ge::GraphUtils::AddEdge(relu2->GetOutDataAnchor(0), mul->GetInDataAnchor(0));
  122. ge::GraphUtils::AddEdge(relu2->GetOutDataAnchor(1), mul->GetInDataAnchor(1));
  123. ge::GraphUtils::AddEdge(mul->GetOutDataAnchor(0), mul1->GetInDataAnchor(0));
  124. ge::GraphUtils::AddEdge(mul->GetOutDataAnchor(1), mul1->GetInDataAnchor(1));
  125. ge::GraphUtils::AddEdge(mul->GetOutDataAnchor(2), mul2->GetInDataAnchor(0));
  126. ge::GraphUtils::AddEdge(mul->GetOutDataAnchor(3), mul2->GetInDataAnchor(1));
  127. return graph;
  128. }
  129. TEST_F(UtestGraphPassesNetOutputPass, add_ctrl_edge_for_netout_from_leaf_success) {
  130. ge::ComputeGraphPtr compute_graph = build_graph(true);
  131. // construct targets
  132. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  133. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  134. ge::NodePtr relu3 = compute_graph->FindNode("Relu3");
  135. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{relu3, 0}};
  136. compute_graph->SetGraphOutNodesInfo(output_nodes);
  137. ge::PassManager pass_managers;
  138. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  139. Status status = pass_managers.Run(compute_graph);
  140. EXPECT_EQ(status, ge::SUCCESS);
  141. // check contain netoutput
  142. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  143. EXPECT_NE(net_out_node, nullptr);
  144. /// check input data node of netoutput
  145. /// when output and targets set conflicts each other , output set is prio
  146. /// Check data input
  147. int input_data_node_num = net_out_node->GetInDataNodes().size();
  148. EXPECT_EQ(input_data_node_num, 1);
  149. std::vector<string> expect_input_data_result{"Relu3"};
  150. for (auto node : net_out_node->GetInDataNodes()) {
  151. auto name = node->GetName();
  152. auto iter = std::find(expect_input_data_result.begin(), expect_input_data_result.end(), name);
  153. if (iter != expect_input_data_result.end()) {
  154. expect_input_data_result.erase(iter);
  155. }
  156. }
  157. input_data_node_num = expect_input_data_result.size();
  158. EXPECT_EQ(input_data_node_num, 0);
  159. // Check control input
  160. int control_node_num = net_out_node->GetInControlNodes().size();
  161. EXPECT_EQ(control_node_num, 2);
  162. std::vector<string> expect_result{"Mul1", "Mul2"};
  163. for (auto node : net_out_node->GetInControlNodes()) {
  164. auto name = node->GetName();
  165. auto iter = std::find(expect_result.begin(), expect_result.end(), name);
  166. if (iter != expect_result.end()) {
  167. expect_result.erase(iter);
  168. }
  169. }
  170. control_node_num = expect_result.size();
  171. EXPECT_EQ(control_node_num, 0);
  172. }
  173. TEST_F(UtestGraphPassesNetOutputPass, only_target_node_success) {
  174. ge::ComputeGraphPtr compute_graph = build_graph();
  175. // construct targets
  176. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  177. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  178. std::vector<ge::NodePtr> target_nodes = {mul1, mul2};
  179. compute_graph->SetGraphTargetNodesInfo(target_nodes);
  180. ge::PassManager pass_managers;
  181. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  182. Status status = pass_managers.Run(compute_graph);
  183. EXPECT_EQ(status, ge::SUCCESS);
  184. // check contain netoutput
  185. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  186. EXPECT_NE(net_out_node, nullptr);
  187. /// check input data node of netoutput
  188. /// Check data input
  189. int input_data_node_num = net_out_node->GetInDataNodes().size();
  190. EXPECT_EQ(input_data_node_num, 0);
  191. // Check control input
  192. int control_node_num = net_out_node->GetInControlNodes().size();
  193. EXPECT_EQ(control_node_num, 2);
  194. std::vector<string> expect_result{"Mul1", "Mul2"};
  195. for (auto node : net_out_node->GetInControlNodes()) {
  196. auto name = node->GetName();
  197. auto iter = std::find(expect_result.begin(), expect_result.end(), name);
  198. if (iter != expect_result.end()) {
  199. expect_result.erase(iter);
  200. }
  201. }
  202. control_node_num = expect_result.size();
  203. EXPECT_EQ(control_node_num, 0);
  204. }
  205. TEST_F(UtestGraphPassesNetOutputPass, targets_with_retval_success) {
  206. ge::ComputeGraphPtr compute_graph = build_graph();
  207. // Imitate the output node of _Retval issued
  208. ge::OpDescPtr retval_node_desc1 = std::make_shared<ge::OpDesc>("reval_node1", FRAMEWORKOP);
  209. retval_node_desc1->AddInputDesc(ge::GeTensorDesc());
  210. (void)ge::AttrUtils::SetStr(retval_node_desc1, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_Retval");
  211. (void)ge::AttrUtils::SetInt(retval_node_desc1, RETVAL_ATTR_NAME_INDEX, 0);
  212. ge::NodePtr retval_node1 = compute_graph->AddNode(retval_node_desc1);
  213. EXPECT_NE(retval_node1, nullptr);
  214. ge::OpDescPtr retval_node_desc2 = std::make_shared<ge::OpDesc>("reval_node2", FRAMEWORKOP);
  215. retval_node_desc2->AddInputDesc(ge::GeTensorDesc());
  216. (void)ge::AttrUtils::SetStr(retval_node_desc2, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_Retval");
  217. (void)ge::AttrUtils::SetInt(retval_node_desc2, RETVAL_ATTR_NAME_INDEX, 1);
  218. ge::NodePtr retval_node2 = compute_graph->AddNode(retval_node_desc2);
  219. EXPECT_NE(retval_node2, nullptr);
  220. // construct targets
  221. std::vector<ge::NodePtr> target_nodes = {retval_node1, retval_node2};
  222. compute_graph->SetGraphTargetNodesInfo(target_nodes);
  223. for (NodePtr node : compute_graph->GetDirectNode()) {
  224. if (node->GetName() == "Mul1") {
  225. GraphUtils::AddEdge(node->GetOutDataAnchor(0), retval_node1->GetInDataAnchor(0));
  226. } else if (node->GetName() == "Mul2") {
  227. GraphUtils::AddEdge(node->GetOutDataAnchor(0), retval_node2->GetInDataAnchor(0));
  228. }
  229. }
  230. ge::PassManager pass_managers;
  231. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  232. Status status = pass_managers.Run(compute_graph);
  233. EXPECT_EQ(status, ge::SUCCESS);
  234. // check contain netoutput
  235. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  236. EXPECT_NE(net_out_node, nullptr);
  237. /// check input data node of netoutput
  238. /// Check data input
  239. int input_data_node_num = net_out_node->GetInDataNodes().size();
  240. EXPECT_EQ(input_data_node_num, 0);
  241. // Check control input
  242. int control_node_num = net_out_node->GetInControlNodes().size();
  243. EXPECT_EQ(control_node_num, 2);
  244. std::vector<string> expect_result{"Mul1", "Mul2"};
  245. for (auto node : net_out_node->GetInControlNodes()) {
  246. auto name = node->GetName();
  247. auto iter = std::find(expect_result.begin(), expect_result.end(), name);
  248. if (iter != expect_result.end()) {
  249. expect_result.erase(iter);
  250. }
  251. }
  252. control_node_num = expect_result.size();
  253. EXPECT_EQ(control_node_num, 0);
  254. // Check the deletion of _Retval node
  255. retval_node1 = compute_graph->FindNode("reval_node1");
  256. EXPECT_EQ(retval_node1, nullptr);
  257. retval_node2 = compute_graph->FindNode("reval_node2");
  258. EXPECT_EQ(retval_node2, nullptr);
  259. }
  260. TEST_F(UtestGraphPassesNetOutputPass, output_node_and_target_node_no_duplicate_success) {
  261. ge::ComputeGraphPtr compute_graph = build_graph(true);
  262. // construct targets
  263. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  264. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  265. std::vector<ge::NodePtr> target_nodes = {mul1, mul2};
  266. compute_graph->SetGraphTargetNodesInfo(target_nodes);
  267. ge::NodePtr relu3 = compute_graph->FindNode("Relu3");
  268. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{relu3, 0}};
  269. compute_graph->SetGraphOutNodesInfo(output_nodes);
  270. ge::PassManager pass_managers;
  271. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  272. Status status = pass_managers.Run(compute_graph);
  273. EXPECT_EQ(status, ge::SUCCESS);
  274. // check contain netoutput
  275. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  276. EXPECT_NE(net_out_node, nullptr);
  277. /// check input data node of netoutput
  278. /// when output and targets set conflicts each other , output set is prio
  279. /// Check data input
  280. int input_data_node_num = net_out_node->GetInDataNodes().size();
  281. EXPECT_EQ(input_data_node_num, 1);
  282. std::vector<string> expect_input_data_result{"Relu3"};
  283. for (auto node : net_out_node->GetInDataNodes()) {
  284. auto name = node->GetName();
  285. auto iter = std::find(expect_input_data_result.begin(), expect_input_data_result.end(), name);
  286. if (iter != expect_input_data_result.end()) {
  287. expect_input_data_result.erase(iter);
  288. }
  289. }
  290. input_data_node_num = expect_input_data_result.size();
  291. EXPECT_EQ(input_data_node_num, 0);
  292. // Check control input
  293. int control_node_num = net_out_node->GetInControlNodes().size();
  294. EXPECT_EQ(control_node_num, 2);
  295. std::vector<string> expect_result{"Mul1", "Mul2"};
  296. for (auto node : net_out_node->GetInControlNodes()) {
  297. auto name = node->GetName();
  298. auto iter = std::find(expect_result.begin(), expect_result.end(), name);
  299. if (iter != expect_result.end()) {
  300. expect_result.erase(iter);
  301. }
  302. }
  303. control_node_num = expect_result.size();
  304. EXPECT_EQ(control_node_num, 0);
  305. }
  306. TEST_F(UtestGraphPassesNetOutputPass, output_node_and_target_node_duplicate_success) {
  307. ge::ComputeGraphPtr compute_graph = build_graph();
  308. // construct targets
  309. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  310. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  311. std::vector<ge::NodePtr> target_nodes = {mul2};
  312. compute_graph->SetGraphTargetNodesInfo(target_nodes);
  313. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}};
  314. compute_graph->SetGraphOutNodesInfo(output_nodes);
  315. ge::PassManager pass_managers;
  316. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  317. Status status = pass_managers.Run(compute_graph);
  318. EXPECT_EQ(status, ge::SUCCESS);
  319. // check contain netoutput
  320. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  321. EXPECT_NE(net_out_node, nullptr);
  322. /// check input data node of netoutput
  323. /// Check data input
  324. int input_data_node_num = net_out_node->GetInDataNodes().size();
  325. EXPECT_EQ(input_data_node_num, 2);
  326. std::vector<string> expect_input_data_result{"Mul1"};
  327. for (auto node : net_out_node->GetInDataNodes()) {
  328. auto name = node->GetName();
  329. auto iter = std::find(expect_input_data_result.begin(), expect_input_data_result.end(), name);
  330. if (iter != expect_input_data_result.end()) {
  331. expect_input_data_result.erase(iter);
  332. }
  333. }
  334. input_data_node_num = expect_input_data_result.size();
  335. EXPECT_EQ(input_data_node_num, 0);
  336. // Check control input
  337. int control_node_num = net_out_node->GetInControlNodes().size();
  338. EXPECT_EQ(control_node_num, 0);
  339. }
  340. TEST_F(UtestGraphPassesNetOutputPass, net_output_node_and_target_node_success) {
  341. ge::ComputeGraphPtr compute_graph = build_graph();
  342. ge::OpDescPtr netout = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
  343. netout->AddInputDesc(ge::GeTensorDesc());
  344. netout->AddInputDesc(ge::GeTensorDesc());
  345. netout->AddOutputDesc(ge::GeTensorDesc());
  346. netout->AddOutputDesc(ge::GeTensorDesc());
  347. ge::NodePtr netout_node = compute_graph->AddNode(netout);
  348. EXPECT_NE(netout_node, nullptr);
  349. for (NodePtr node : compute_graph->GetDirectNode()) {
  350. if (node->GetName() == "Mul1") {
  351. GraphUtils::AddEdge(node->GetOutDataAnchor(0), netout_node->GetInDataAnchor(0));
  352. } else if (node->GetName() == "Mul2") {
  353. GraphUtils::AddEdge(node->GetOutDataAnchor(0), netout_node->GetInDataAnchor(1));
  354. }
  355. }
  356. // construct targets
  357. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  358. std::vector<ge::NodePtr> target_nodes = {mul2};
  359. compute_graph->SetGraphTargetNodesInfo(target_nodes);
  360. ge::PassManager pass_managers;
  361. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  362. Status status = pass_managers.Run(compute_graph);
  363. EXPECT_EQ(status, ge::SUCCESS);
  364. // check contain netoutput
  365. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  366. EXPECT_NE(net_out_node, nullptr);
  367. /// check input data node of netoutput
  368. /// Check data input
  369. int input_data_node_num = net_out_node->GetInDataNodes().size();
  370. EXPECT_EQ(input_data_node_num, 1);
  371. std::vector<string> expect_input_data_result{"Mul1"};
  372. for (auto node : net_out_node->GetInDataNodes()) {
  373. auto name = node->GetName();
  374. auto iter = std::find(expect_input_data_result.begin(), expect_input_data_result.end(), name);
  375. if (iter != expect_input_data_result.end()) {
  376. expect_input_data_result.erase(iter);
  377. }
  378. }
  379. input_data_node_num = expect_input_data_result.size();
  380. EXPECT_EQ(input_data_node_num, 0);
  381. // Check control input
  382. int control_node_num = net_out_node->GetInControlNodes().size();
  383. EXPECT_EQ(control_node_num, 1);
  384. std::vector<string> expect_control_data_result{"Mul2"};
  385. for (auto node : net_out_node->GetInControlNodes()) {
  386. auto name = node->GetName();
  387. auto iter = std::find(expect_control_data_result.begin(), expect_control_data_result.end(), name);
  388. if (iter != expect_control_data_result.end()) {
  389. expect_control_data_result.erase(iter);
  390. }
  391. }
  392. control_node_num = expect_control_data_result.size();
  393. EXPECT_EQ(control_node_num, 0);
  394. }
  395. /// graph have netoutput node.User set outputnodes and target nodes at the same time.output nodes
  396. /// include one common node with target nodes.
  397. /// Notice: output nodes set is more prio
  398. TEST_F(UtestGraphPassesNetOutputPass, net_output_node_and_output_nodes_and_target_node_success_1) {
  399. ge::ComputeGraphPtr compute_graph = build_graph();
  400. ge::OpDescPtr netout = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
  401. netout->AddInputDesc(ge::GeTensorDesc());
  402. netout->AddInputDesc(ge::GeTensorDesc());
  403. netout->AddOutputDesc(ge::GeTensorDesc());
  404. netout->AddOutputDesc(ge::GeTensorDesc());
  405. ge::NodePtr netout_node = compute_graph->AddNode(netout);
  406. EXPECT_NE(netout_node, nullptr);
  407. for (NodePtr node : compute_graph->GetDirectNode()) {
  408. if (node->GetName() == "Mul1") {
  409. GraphUtils::AddEdge(node->GetOutDataAnchor(0), netout_node->GetInDataAnchor(0));
  410. } else if (node->GetName() == "Mul2") {
  411. GraphUtils::AddEdge(node->GetOutDataAnchor(0), netout_node->GetInDataAnchor(1));
  412. }
  413. }
  414. // construct targets
  415. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  416. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  417. std::vector<ge::NodePtr> target_nodes = {mul2};
  418. compute_graph->SetGraphTargetNodesInfo(target_nodes);
  419. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}};
  420. compute_graph->SetGraphOutNodesInfo(output_nodes);
  421. ge::PassManager pass_managers;
  422. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  423. Status status = pass_managers.Run(compute_graph);
  424. EXPECT_EQ(status, ge::SUCCESS);
  425. // check contain netoutput
  426. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  427. EXPECT_NE(net_out_node, nullptr);
  428. /// check input data node of netoutput
  429. /// Check data input
  430. int input_data_node_num = net_out_node->GetInDataNodes().size();
  431. EXPECT_EQ(input_data_node_num, 2);
  432. std::vector<string> expect_input_data_result{"Mul1", "Mul2"};
  433. for (auto node : net_out_node->GetInDataNodes()) {
  434. auto name = node->GetName();
  435. auto iter = std::find(expect_input_data_result.begin(), expect_input_data_result.end(), name);
  436. if (iter != expect_input_data_result.end()) {
  437. expect_input_data_result.erase(iter);
  438. }
  439. }
  440. input_data_node_num = expect_input_data_result.size();
  441. EXPECT_EQ(input_data_node_num, 0);
  442. // Check control input
  443. int control_node_num = net_out_node->GetInControlNodes().size();
  444. EXPECT_EQ(control_node_num, 0);
  445. }
  446. /// graph have netoutput node.User set outputnodes and target nodes at the same time.output nodes
  447. /// include one common node with target nodes.
  448. /// Notice: output nodes set is more prio
  449. TEST_F(UtestGraphPassesNetOutputPass, net_output_node_and_output_nodes_and_target_node_success_2) {
  450. ge::ComputeGraphPtr compute_graph = build_graph(true);
  451. ge::OpDescPtr netout = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
  452. netout->AddInputDesc(ge::GeTensorDesc());
  453. netout->AddOutputDesc(ge::GeTensorDesc());
  454. ge::NodePtr netout_node = compute_graph->AddNode(netout);
  455. EXPECT_NE(netout_node, nullptr);
  456. for (const auto &node : compute_graph->GetDirectNode()) {
  457. if (node->GetName() == "Mul1") {
  458. GraphUtils::AddEdge(node->GetOutDataAnchor(0), netout_node->GetInDataAnchor(0));
  459. }
  460. if (node->GetName() == "Mul2") {
  461. GraphUtils::AddEdge(node->GetOutControlAnchor(), netout_node->GetInControlAnchor());
  462. }
  463. if (node->GetName() == "Relu3") {
  464. GraphUtils::AddEdge(node->GetOutControlAnchor(), netout_node->GetInControlAnchor());
  465. }
  466. }
  467. // construct targets
  468. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  469. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  470. std::vector<ge::NodePtr> target_nodes = {mul2};
  471. compute_graph->SetGraphTargetNodesInfo(target_nodes);
  472. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}};
  473. compute_graph->SetGraphOutNodesInfo(output_nodes);
  474. ge::PassManager pass_managers;
  475. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  476. Status status = pass_managers.Run(compute_graph);
  477. EXPECT_EQ(status, ge::SUCCESS);
  478. // check contain netoutput
  479. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  480. EXPECT_NE(net_out_node, nullptr);
  481. /// check input data node of netoutput
  482. /// Check data input
  483. int input_data_node_num = net_out_node->GetInDataNodes().size();
  484. EXPECT_EQ(input_data_node_num, 1);
  485. std::vector<string> expect_input_data_result{"Mul1"};
  486. for (const auto &node : net_out_node->GetInDataNodes()) {
  487. auto name = node->GetName();
  488. auto iter = std::find(expect_input_data_result.begin(), expect_input_data_result.end(), name);
  489. if (iter != expect_input_data_result.end()) {
  490. expect_input_data_result.erase(iter);
  491. }
  492. }
  493. input_data_node_num = expect_input_data_result.size();
  494. EXPECT_EQ(input_data_node_num, 0);
  495. // Check control input
  496. int control_node_num = net_out_node->GetInControlNodes().size();
  497. EXPECT_EQ(control_node_num, 2);
  498. std::vector<string> expect_control_data_result{"Mul2", "Relu3"};
  499. for (const auto &node : net_out_node->GetInControlNodes()) {
  500. auto name = node->GetName();
  501. auto iter = std::find(expect_control_data_result.begin(), expect_control_data_result.end(), name);
  502. if (iter != expect_control_data_result.end()) {
  503. expect_control_data_result.erase(iter);
  504. }
  505. }
  506. control_node_num = expect_control_data_result.size();
  507. EXPECT_EQ(control_node_num, 0);
  508. }
  509. /// graph have netoutput node.User set outputnodes and target nodes at the same time.output nodes
  510. /// include one common node with target nodes.
  511. /// Notice: output nodes set is more prio
  512. TEST_F(UtestGraphPassesNetOutputPass, net_output_node_and_output_nodes_and_target_node_success_3) {
  513. ge::ComputeGraphPtr compute_graph = build_graph();
  514. ge::OpDescPtr netout = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
  515. netout->AddInputDesc(ge::GeTensorDesc());
  516. netout->AddOutputDesc(ge::GeTensorDesc());
  517. ge::NodePtr netout_node = compute_graph->AddNode(netout);
  518. EXPECT_NE(netout_node, nullptr);
  519. for (const auto &node : compute_graph->GetDirectNode()) {
  520. if (node->GetName() == "Mul1") {
  521. GraphUtils::AddEdge(node->GetOutDataAnchor(0), netout_node->GetInDataAnchor(0));
  522. }
  523. if (node->GetName() == "Mul2") {
  524. GraphUtils::AddEdge(node->GetOutControlAnchor(), netout_node->GetInControlAnchor());
  525. GraphUtils::AddEdge(node->GetOutDataAnchor(0), netout_node->GetInControlAnchor());
  526. }
  527. }
  528. // construct targets
  529. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  530. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  531. std::vector<ge::NodePtr> target_nodes = {mul2};
  532. compute_graph->SetGraphTargetNodesInfo(target_nodes);
  533. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}};
  534. compute_graph->SetGraphOutNodesInfo(output_nodes);
  535. ge::PassManager pass_managers;
  536. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  537. Status status = pass_managers.Run(compute_graph);
  538. EXPECT_EQ(status, ge::SUCCESS);
  539. // check contain netoutput
  540. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  541. EXPECT_NE(net_out_node, nullptr);
  542. /// check input data node of netoutput
  543. /// Check data input
  544. int input_data_node_num = net_out_node->GetInDataNodes().size();
  545. EXPECT_EQ(input_data_node_num, 1);
  546. std::vector<string> expect_input_data_result{"Mul1"};
  547. for (const auto &node : net_out_node->GetInDataNodes()) {
  548. auto name = node->GetName();
  549. auto iter = std::find(expect_input_data_result.begin(), expect_input_data_result.end(), name);
  550. if (iter != expect_input_data_result.end()) {
  551. expect_input_data_result.erase(iter);
  552. }
  553. }
  554. input_data_node_num = expect_input_data_result.size();
  555. EXPECT_EQ(input_data_node_num, 0);
  556. // Check control input
  557. int control_node_num = net_out_node->GetInControlNodes().size();
  558. EXPECT_EQ(control_node_num, 1);
  559. std::vector<string> expect_control_data_result{"Mul2"};
  560. for (const auto &node : net_out_node->GetInControlNodes()) {
  561. auto name = node->GetName();
  562. auto iter = std::find(expect_control_data_result.begin(), expect_control_data_result.end(), name);
  563. if (iter != expect_control_data_result.end()) {
  564. expect_control_data_result.erase(iter);
  565. }
  566. }
  567. control_node_num = expect_control_data_result.size();
  568. EXPECT_EQ(control_node_num, 0);
  569. }
  570. TEST_F(UtestGraphPassesNetOutputPass, no_output_no_target_no_retval_success) {
  571. ge::ComputeGraphPtr compute_graph = build_graph();
  572. // Construct specified output
  573. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  574. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  575. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}};
  576. compute_graph->SetGraphOutNodesInfo(output_nodes);
  577. ge::PassManager pass_managers;
  578. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  579. Status status = pass_managers.Run(compute_graph);
  580. EXPECT_EQ(status, ge::SUCCESS);
  581. }
  582. TEST_F(UtestGraphPassesNetOutputPass, user_out_node_success) {
  583. ge::ComputeGraphPtr compute_graph = build_graph();
  584. // Construct specified output
  585. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  586. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  587. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}};
  588. compute_graph->SetGraphOutNodesInfo(output_nodes);
  589. ge::PassManager pass_managers;
  590. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  591. Status status = pass_managers.Run(compute_graph);
  592. EXPECT_EQ(status, ge::SUCCESS);
  593. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  594. EXPECT_NE(net_out_node, nullptr);
  595. // Check data input
  596. string str;
  597. for (ge::NodePtr input_data_node : net_out_node->GetInDataNodes()) {
  598. str += input_data_node->GetName() + ";";
  599. }
  600. EXPECT_EQ(str, "Mul1;Mul2;");
  601. // Check control input
  602. int control_node_num = net_out_node->GetInControlNodes().size();
  603. EXPECT_EQ(control_node_num, 0);
  604. }
  605. TEST_F(UtestGraphPassesNetOutputPass, retval_node_for_out_success) {
  606. ge::ComputeGraphPtr compute_graph = build_graph();
  607. // Imitate the output node of _Retval issued
  608. ge::OpDescPtr retval_node_desc1 = std::make_shared<ge::OpDesc>("reval_node1", FRAMEWORKOP);
  609. retval_node_desc1->AddInputDesc(ge::GeTensorDesc());
  610. (void)ge::AttrUtils::SetStr(retval_node_desc1, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_Retval");
  611. (void)ge::AttrUtils::SetInt(retval_node_desc1, RETVAL_ATTR_NAME_INDEX, 0);
  612. ge::NodePtr retval_node1 = compute_graph->AddNode(retval_node_desc1);
  613. EXPECT_NE(retval_node1, nullptr);
  614. ge::OpDescPtr retval_node_desc2 = std::make_shared<ge::OpDesc>("reval_node2", FRAMEWORKOP);
  615. retval_node_desc2->AddInputDesc(ge::GeTensorDesc());
  616. (void)ge::AttrUtils::SetStr(retval_node_desc2, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_Retval");
  617. (void)ge::AttrUtils::SetInt(retval_node_desc2, RETVAL_ATTR_NAME_INDEX, 1);
  618. ge::NodePtr retval_node2 = compute_graph->AddNode(retval_node_desc2);
  619. EXPECT_NE(retval_node2, nullptr);
  620. for (NodePtr node : compute_graph->GetDirectNode()) {
  621. if (node->GetName() == "Mul1") {
  622. GraphUtils::AddEdge(node->GetOutDataAnchor(0), retval_node1->GetInDataAnchor(0));
  623. } else if (node->GetName() == "Mul2") {
  624. GraphUtils::AddEdge(node->GetOutDataAnchor(0), retval_node2->GetInDataAnchor(0));
  625. }
  626. }
  627. ge::PassManager pass_managers;
  628. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  629. Status status = pass_managers.Run(compute_graph);
  630. EXPECT_EQ(status, ge::SUCCESS);
  631. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  632. EXPECT_NE(net_out_node, nullptr);
  633. // Check data input
  634. string str;
  635. for (ge::NodePtr input_data_node : net_out_node->GetInDataNodes()) {
  636. str += input_data_node->GetName() + ";";
  637. }
  638. EXPECT_EQ(str, "Mul1;Mul2;");
  639. // Check control input
  640. int control_node_num = net_out_node->GetInControlNodes().size();
  641. EXPECT_EQ(control_node_num, 0);
  642. // Check the deletion of _Retval node
  643. retval_node1 = compute_graph->FindNode("reval_node1");
  644. EXPECT_EQ(retval_node1, nullptr);
  645. retval_node2 = compute_graph->FindNode("reval_node2");
  646. EXPECT_EQ(retval_node2, nullptr);
  647. }
  648. TEST_F(UtestGraphPassesNetOutputPass, check_order_and_const_flag_success) {
  649. ge::ComputeGraphPtr compute_graph = build_graph();
  650. ge::OpDescPtr const_node_desc = std::make_shared<ge::OpDesc>("const_output", CONSTANT);
  651. const_node_desc->AddOutputDesc(ge::GeTensorDesc());
  652. ge::NodePtr const_node = compute_graph->AddNode(const_node_desc);
  653. EXPECT_NE(const_node, nullptr);
  654. NodePtr mul1 = compute_graph->FindNode("Mul1");
  655. EXPECT_NE(mul1, nullptr);
  656. GraphUtils::AddEdge(mul1->GetOutControlAnchor(), const_node->GetInControlAnchor());
  657. // Construct specified output
  658. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{const_node, 0}};
  659. compute_graph->SetGraphOutNodesInfo(output_nodes);
  660. ge::OpDescPtr retval_node_desc2 = std::make_shared<ge::OpDesc>("reval_node2", FRAMEWORKOP);
  661. retval_node_desc2->AddInputDesc(ge::GeTensorDesc());
  662. (void)ge::AttrUtils::SetStr(retval_node_desc2, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_Retval");
  663. (void)ge::AttrUtils::SetInt(retval_node_desc2, RETVAL_ATTR_NAME_INDEX, 0);
  664. ge::NodePtr retval_node2 = compute_graph->AddNode(retval_node_desc2);
  665. EXPECT_NE(retval_node2, nullptr);
  666. NodePtr mul2 = compute_graph->FindNode("Mul2");
  667. EXPECT_NE(mul2, nullptr);
  668. GraphUtils::AddEdge(mul2->GetOutDataAnchor(0), retval_node2->GetInDataAnchor(0));
  669. ge::PassManager pass_managers;
  670. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  671. Status status = pass_managers.Run(compute_graph);
  672. EXPECT_EQ(status, ge::SUCCESS);
  673. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  674. EXPECT_NE(net_out_node, nullptr);
  675. // Check data input
  676. string str;
  677. for (ge::NodePtr input_data_node : net_out_node->GetInDataNodes()) {
  678. str += input_data_node->GetName() + ";";
  679. }
  680. EXPECT_EQ(str, "const_output;Mul2;");
  681. // Check control input
  682. int control_node_num = net_out_node->GetInControlNodes().size();
  683. EXPECT_EQ(control_node_num, 0);
  684. // Check is_input_const flag
  685. std::vector<bool> is_input_const = net_out_node->GetOpDesc()->GetIsInputConst();
  686. EXPECT_EQ(is_input_const.size(), 2);
  687. EXPECT_EQ(is_input_const[0], true);
  688. EXPECT_EQ(is_input_const[1], false);
  689. // Check the deletion of _Retval node
  690. retval_node2 = compute_graph->FindNode("reval_node2");
  691. EXPECT_EQ(retval_node2, nullptr);
  692. }
  693. TEST_F(UtestGraphPassesNetOutputPass, out_node_check_fail) {
  694. ge::ComputeGraphPtr compute_graph = build_graph();
  695. // Construct specified output
  696. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  697. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  698. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes_invalid_name = {{nullptr, 0}, {mul2, 0}};
  699. compute_graph->SetGraphOutNodesInfo(output_nodes_invalid_name);
  700. ge::PassManager pass_managers;
  701. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  702. Status status = pass_managers.Run(compute_graph);
  703. EXPECT_EQ(status, ge::INTERNAL_ERROR);
  704. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  705. EXPECT_EQ(net_out_node, nullptr);
  706. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes_invalid_index = {{mul1, 0}, {mul2, 100}};
  707. compute_graph->SetGraphOutNodesInfo(output_nodes_invalid_index);
  708. status = pass_managers.Run(compute_graph);
  709. EXPECT_EQ(status, ge::INTERNAL_ERROR);
  710. net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  711. EXPECT_EQ(net_out_node, nullptr);
  712. }
  713. TEST_F(UtestGraphPassesNetOutputPass, retval_node_check_fail) {
  714. ge::ComputeGraphPtr compute_graph = build_graph();
  715. // Imitate the output node of _Retval issued
  716. ge::OpDescPtr retval_node_desc1 = std::make_shared<ge::OpDesc>("reval_node1", FRAMEWORKOP);
  717. retval_node_desc1->AddInputDesc(ge::GeTensorDesc());
  718. (void)ge::AttrUtils::SetStr(retval_node_desc1, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_Retval");
  719. (void)ge::AttrUtils::SetInt(retval_node_desc1, RETVAL_ATTR_NAME_INDEX, 0);
  720. ge::NodePtr retval_node1 = compute_graph->AddNode(retval_node_desc1);
  721. EXPECT_NE(retval_node1, nullptr);
  722. ge::OpDescPtr retval_node_desc2 = std::make_shared<ge::OpDesc>("reval_node2", FRAMEWORKOP);
  723. retval_node_desc2->AddInputDesc(ge::GeTensorDesc());
  724. (void)ge::AttrUtils::SetStr(retval_node_desc2, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, "_Retval");
  725. (void)ge::AttrUtils::SetInt(retval_node_desc2, RETVAL_ATTR_NAME_INDEX, 0);
  726. ge::NodePtr retval_node2 = compute_graph->AddNode(retval_node_desc2);
  727. EXPECT_NE(retval_node2, nullptr);
  728. for (NodePtr node : compute_graph->GetDirectNode()) {
  729. if (node->GetName() == "Mul1") {
  730. GraphUtils::AddEdge(node->GetOutDataAnchor(0), retval_node1->GetInDataAnchor(0));
  731. } else if (node->GetName() == "Mul2") {
  732. GraphUtils::AddEdge(node->GetOutDataAnchor(0), retval_node2->GetInDataAnchor(0));
  733. }
  734. }
  735. ge::PassManager pass_managers;
  736. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  737. Status status = pass_managers.Run(compute_graph);
  738. EXPECT_EQ(status, ge::INTERNAL_ERROR);
  739. NodePtr net_out_node = compute_graph->FindNode(NODE_NAME_NET_OUTPUT);
  740. EXPECT_EQ(net_out_node, nullptr);
  741. }
  742. TEST_F(UtestGraphPassesNetOutputPass, out_node_update_desc_check_fail) {
  743. ge::ComputeGraphPtr compute_graph = build_graph();
  744. ge::OpDescPtr netout = std::make_shared<ge::OpDesc>(NODE_NAME_NET_OUTPUT, NETOUTPUT);
  745. ge::NodePtr netout_node = compute_graph->AddNode(netout);
  746. EXPECT_NE(netout_node, nullptr);
  747. ge::PassManager pass_managers;
  748. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  749. Status status = pass_managers.Run(compute_graph);
  750. EXPECT_EQ(status, ge::INTERNAL_ERROR);
  751. }
  752. TEST_F(UtestGraphPassesNetOutputPass, out_node_remove_check_fail) {
  753. ge::ComputeGraphPtr compute_graph = build_graph();
  754. // Construct specified output
  755. ge::NodePtr mul1 = compute_graph->FindNode("Mul1");
  756. ge::NodePtr mul2 = compute_graph->FindNode("Mul2");
  757. std::vector<std::pair<ge::NodePtr, int32_t>> output_nodes = {{mul1, 0}, {mul2, 0}};
  758. compute_graph->SetGraphOutNodesInfo(output_nodes);
  759. // compute_graph->RemoveNode(mul1);
  760. mul1->GetInDataAnchor(0)->UnlinkAll();
  761. mul1->GetInDataAnchor(1)->UnlinkAll();
  762. GraphUtils::RemoveNodeWithoutRelink(compute_graph, mul1);
  763. mul1 = compute_graph->FindNode("Mul1");
  764. EXPECT_EQ(mul1, nullptr);
  765. ge::PassManager pass_managers;
  766. pass_managers.AddPass(new (std::nothrow) NetOutputPass);
  767. Status status = pass_managers.Run(compute_graph);
  768. EXPECT_EQ(status, ge::SUCCESS);
  769. }
  770. TEST_F(UtestGraphPassesNetOutputPass, clear_weight) {
  771. ge::ComputeGraphPtr compute_graph = BuildClearWeightGraph();
  772. auto cast = compute_graph->FindNode("Cast1");
  773. Status ret = ge::OpDescUtils::ClearWeights(cast);
  774. EXPECT_EQ(ge::SUCCESS, ret);
  775. }

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。