You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

base_pass_unittest.cc 32 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <iostream>
  17. #include <map>
  18. #include <set>
  19. #include <vector>
  20. #include "gtest/gtest.h"
  21. #define protected public
  22. #include "graph/passes/base_pass.h"
  23. #undef protected
  24. #include "framework/common/types.h"
  25. #include "graph/node.h"
  26. #include "graph/utils/graph_utils.h"
  27. #include "graph_builder_utils.h"
  28. template class std::unordered_set<ge::NodePtr>;
  29. namespace ge {
  30. class UtestTestPass : public BaseNodePass {
  31. public:
  32. UtestTestPass() = default;
  33. UtestTestPass(bool dead_loop) : dead_loop_(dead_loop), run_times_(0) {}
  34. Status Run(NodePtr &node) override {
  35. ++run_times_;
  36. iter_nodes_.push_back(node);
  37. auto iter = names_to_add_del_.find(node->GetName());
  38. if (iter != names_to_add_del_.end()) {
  39. for (const auto &node_name : iter->second) {
  40. auto del_node = node->GetOwnerComputeGraph()->FindNode(node_name);
  41. GraphUtils::IsolateNode(del_node, {0});
  42. AddNodeDeleted(del_node);
  43. }
  44. }
  45. iter = names_to_add_repass_.find(node->GetName());
  46. if (iter != names_to_add_repass_.end()) {
  47. auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes();
  48. for (const auto &node_name : iter->second) {
  49. for (auto &node_re_pass : all_nodes) {
  50. if (node_re_pass->GetName() == node_name) {
  51. AddRePassNode(node_re_pass);
  52. break;
  53. }
  54. }
  55. }
  56. if (!dead_loop_) {
  57. names_to_add_repass_.erase(iter);
  58. }
  59. }
  60. iter = names_to_add_repass_immediate_.find(node->GetName());
  61. if (iter != names_to_add_repass_immediate_.end()) {
  62. auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes();
  63. for (const auto &node_name : iter->second) {
  64. for (auto &node_re_pass : all_nodes) {
  65. if (node_re_pass->GetName() == node_name) {
  66. AddImmediateRePassNode(node_re_pass);
  67. break;
  68. }
  69. }
  70. }
  71. if (!dead_loop_) {
  72. names_to_add_repass_immediate_.erase(iter);
  73. }
  74. }
  75. iter = names_to_add_suspend_.find(node->GetName());
  76. if (iter != names_to_add_suspend_.end()) {
  77. auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes();
  78. for (const auto &node_name : iter->second) {
  79. for (auto &node_re_pass : all_nodes) {
  80. if (node_re_pass->GetName() == node_name) {
  81. AddNodeSuspend(node_re_pass);
  82. break;
  83. }
  84. }
  85. }
  86. if (!dead_loop_) {
  87. names_to_add_suspend_.erase(iter);
  88. }
  89. }
  90. iter = names_to_add_resume_.find(node->GetName());
  91. if (iter != names_to_add_resume_.end()) {
  92. auto all_nodes = node->GetOwnerComputeGraph()->GetAllNodes();
  93. for (const auto &node_name : iter->second) {
  94. for (auto &node_re_pass : all_nodes) {
  95. if (node_re_pass->GetName() == node_name) {
  96. AddNodeResume(node_re_pass);
  97. break;
  98. }
  99. }
  100. }
  101. if (!dead_loop_) {
  102. names_to_add_resume_.erase(iter);
  103. }
  104. }
  105. // simulate infershape pass
  106. if(node->GetType() == WHILE){
  107. bool need_repass = false;
  108. AttrUtils::GetBool(node->GetOpDesc(),"_need_infer_again", need_repass);
  109. if(!OptionExists(kOptimizeAfterSubGraph)){
  110. return SUCCESS;
  111. }
  112. if(need_repass){
  113. AttrUtils::SetBool(node->GetOpDesc(),"_need_infer_again", false);
  114. AddImmediateRePassNode(node);
  115. }
  116. else{
  117. // clear attr on while
  118. node->GetOpDesc()->DelAttr("_need_infer_again");
  119. }
  120. }
  121. return SUCCESS;
  122. }
  123. Status OnSuspendNodesLeaked() override {
  124. // resume all node remain in suspend_nodes when leaked
  125. auto compute_graph = (iter_nodes_.size() > 0) ? iter_nodes_[0]->GetOwnerComputeGraph() : nullptr;
  126. if (compute_graph == nullptr) {
  127. return SUCCESS;
  128. }
  129. for (const auto &node_name : names_to_add_resume_onleaked_) {
  130. auto node_to_resume = compute_graph->FindNode(node_name);
  131. AddNodeResume(node_to_resume);
  132. }
  133. return SUCCESS;
  134. }
  135. void clear() { iter_nodes_.clear(); }
  136. std::vector<NodePtr> GetIterNodes() { return iter_nodes_; }
  137. void AddRePassNodeName(const std::string &iter_node, const std::string &re_pass_node) {
  138. names_to_add_repass_[iter_node].insert(re_pass_node);
  139. }
  140. void AddDelNodeName(const std::string &iter_node, const std::string &del_node) {
  141. names_to_add_del_[iter_node].insert(del_node);
  142. }
  143. void AddRePassImmediateNodeName(const std::string &iter_node, const std::string &re_pass_node) {
  144. names_to_add_repass_immediate_[iter_node].insert(re_pass_node);
  145. }
  146. void AddSuspendNodeName(const std::string &iter_node, const std::string &suspend_node) {
  147. names_to_add_suspend_[iter_node].insert(suspend_node);
  148. }
  149. void AddResumeNodeName(const std::string &iter_node, const std::string &resume_node) {
  150. names_to_add_resume_[iter_node].insert(resume_node);
  151. }
  152. void AddResumeNodeNameOnLeaked(const std::string &resume_node) {
  153. names_to_add_resume_onleaked_.insert(resume_node);
  154. }
  155. unsigned int GetRunTimes() { return run_times_; }
  156. private:
  157. std::vector<NodePtr> iter_nodes_;
  158. std::map<std::string, std::unordered_set<std::string>> names_to_add_del_;
  159. std::map<std::string, std::unordered_set<std::string>> names_to_add_repass_;
  160. std::map<std::string, std::unordered_set<std::string>> names_to_add_repass_immediate_;
  161. std::map<std::string, std::unordered_set<std::string>> names_to_add_suspend_;
  162. std::map<std::string, std::unordered_set<std::string>> names_to_add_resume_;
  163. std::unordered_set<std::string> names_to_add_resume_onleaked_;
  164. bool dead_loop_;
  165. unsigned int run_times_;
  166. };
  167. class TestDelPass : public BaseNodePass {
  168. public:
  169. Status Run(NodePtr &node) override { return SUCCESS; }
  170. };
  171. class UTESTGraphPassesBasePass : public testing::Test {
  172. protected:
  173. UTESTGraphPassesBasePass() {
  174. auto p1 = new UtestTestPass;
  175. names_to_pass_.push_back(std::make_pair("test1", p1));
  176. }
  177. void SetUp() override {
  178. for (auto &name_to_pass : names_to_pass_) {
  179. dynamic_cast<UtestTestPass *>(name_to_pass.second)->clear();
  180. }
  181. }
  182. ~UTESTGraphPassesBasePass() override {
  183. for (auto &name_to_pass : names_to_pass_) {
  184. delete name_to_pass.second;
  185. }
  186. }
  187. NamesToPass names_to_pass_;
  188. };
  189. /// reshape1
  190. /// |
  191. /// add1
  192. /// / \.
  193. /// | |
  194. /// data1 const1
  195. ComputeGraphPtr BuildGraph1() {
  196. auto builder = ut::GraphBuilder("g1");
  197. auto data = builder.AddNode("data1", DATA, 0, 1);
  198. auto a1 = builder.AddNode("add1", ADD, 2, 1);
  199. auto c1 = builder.AddNode("const1", CONSTANT, 0, 1);
  200. auto r1 = builder.AddNode("reshape1", RESHAPE, 1, 1);
  201. builder.AddDataEdge(data, 0, a1, 0);
  202. builder.AddDataEdge(c1, 0, a1, 1);
  203. builder.AddDataEdge(a1, 0, r1, 0);
  204. return builder.GetGraph();
  205. }
  206. /// sum1
  207. /// / \.
  208. /// / \.
  209. /// / \.
  210. /// reshape1 addn1
  211. /// | c |
  212. /// add1 <--- shape1
  213. /// / \ |
  214. /// | | |
  215. /// data1 const1 const2
  216. ComputeGraphPtr BuildGraph2() {
  217. auto builder = ut::GraphBuilder("g1");
  218. auto data1 = builder.AddNode("data1", DATA, 0, 1);
  219. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  220. auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
  221. auto add1 = builder.AddNode("add1", ADD, 2, 1);
  222. auto shape1 = builder.AddNode("shape1", SHAPE, 1, 1);
  223. auto reshape1 = builder.AddNode("reshape1", RESHAPE, 1, 1);
  224. auto addn1 = builder.AddNode("addn1", ADDN, 1, 1);
  225. auto sum1 = builder.AddNode("sum1", SUM, 2, 1);
  226. builder.AddDataEdge(data1, 0, add1, 0);
  227. builder.AddDataEdge(const1, 0, add1, 1);
  228. builder.AddDataEdge(const2, 0, shape1, 0);
  229. builder.AddControlEdge(shape1, add1);
  230. builder.AddDataEdge(add1, 0, reshape1, 0);
  231. builder.AddDataEdge(shape1, 0, addn1, 0);
  232. builder.AddDataEdge(reshape1, 0, sum1, 0);
  233. builder.AddDataEdge(addn1, 0, sum1, 1);
  234. return builder.GetGraph();
  235. }
  236. /// rnextiteration
  237. /// | |
  238. /// merge
  239. /// |
  240. /// data1
  241. ComputeGraphPtr BuildGraph3() {
  242. auto builder = ut::GraphBuilder("g1");
  243. auto data1 = builder.AddNode("data1", DATA, 0, 1);
  244. auto merge1 = builder.AddNode("merge1", MERGE, 2, 1);
  245. auto next1 = builder.AddNode("next1", NEXTITERATION, 1, 1);
  246. builder.AddDataEdge(data1, 0, merge1, 0);
  247. builder.AddDataEdge(merge1, 0, next1, 0);
  248. builder.AddDataEdge(next1, 0, merge1, 1);
  249. builder.AddControlEdge(merge1, next1);
  250. builder.AddControlEdge(next1, merge1);
  251. return builder.GetGraph();
  252. }
  253. /// cast1--shape1
  254. /// /
  255. /// data1
  256. /// \
  257. /// transdata1--shape2
  258. ComputeGraphPtr BuildGraph4() {
  259. auto builder = ut::GraphBuilder("g1");
  260. auto data1 = builder.AddNode("data1", DATA, 0, 1);
  261. auto cast1 = builder.AddNode("cast1", CAST, 1, 1);
  262. auto shape1 = builder.AddNode("shape1", SHAPE, 1, 1);
  263. auto transdata1 = builder.AddNode("transdata1", TRANSDATA, 1, 1);
  264. auto shape2 = builder.AddNode("shape2", SHAPE, 1, 1);
  265. builder.AddDataEdge(data1, 0, cast1, 0);
  266. builder.AddDataEdge(data1, 0, transdata1, 0);
  267. builder.AddDataEdge(cast1, 0, shape1, 0);
  268. builder.AddDataEdge(transdata1, 0, shape2, 0);
  269. return builder.GetGraph();
  270. }
  271. void CheckIterOrder(UtestTestPass *pass, std::vector<std::unordered_set<std::string>> &nodes_layers) {
  272. std::unordered_set<std::string> layer_nodes;
  273. size_t layer_index = 0;
  274. for (const auto &node : pass->GetIterNodes()) {
  275. layer_nodes.insert(node->GetName());
  276. EXPECT_LT(layer_index, nodes_layers.size());
  277. if (layer_nodes == nodes_layers[layer_index]) {
  278. layer_index++;
  279. layer_nodes.clear();
  280. }
  281. }
  282. EXPECT_EQ(layer_index, nodes_layers.size());
  283. }
  284. /// Op1
  285. /// |
  286. /// Merge
  287. /// / \.
  288. /// Op2 Op3
  289. TEST_F(UTESTGraphPassesBasePass, del_isolate_fail) {
  290. auto builder = ut::GraphBuilder("g1");
  291. auto merge_node = builder.AddNode("Merge", MERGE, 1, 1);
  292. auto node1 = builder.AddNode("Op1", RELU, 1, 1);
  293. auto node2 = builder.AddNode("Op2", CONVOLUTION, 1, 1);
  294. auto node3 = builder.AddNode("Op3", CONVOLUTION, 1, 1);
  295. GraphUtils::AddEdge(node1->GetOutDataAnchor(0), merge_node->GetInDataAnchor(0));
  296. GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node2->GetInDataAnchor(0));
  297. GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node3->GetInDataAnchor(0));
  298. EXPECT_EQ(node1->GetOutDataNodes().size(), 1);
  299. TestDelPass del_pass;
  300. auto ret = del_pass.IsolateAndDeleteNode(merge_node, {0, -1});
  301. EXPECT_EQ(ret, FAILED);
  302. OpDescPtr op_desc = std::make_shared<OpDesc>("merge", MERGE);
  303. NodePtr node = shared_ptr<Node>(new (std::nothrow) Node(op_desc, nullptr));
  304. ret = del_pass.IsolateAndDeleteNode(node, {0, -1});
  305. EXPECT_EQ(ret, FAILED);
  306. }
  307. /// Op1
  308. /// |
  309. /// Merge
  310. /// / \.
  311. /// Op2 Op3
  312. TEST_F(UTESTGraphPassesBasePass, del_isolate_success) {
  313. auto builder = ut::GraphBuilder("g1");
  314. auto merge_node = builder.AddNode("Merge", MERGE, 1, 2);
  315. auto node1 = builder.AddNode("Op1", RELU, 1, 1);
  316. auto node2 = builder.AddNode("Op2", CONVOLUTION, 1, 1);
  317. auto node3 = builder.AddNode("Op3", CONVOLUTION, 1, 1);
  318. GraphUtils::AddEdge(node1->GetOutDataAnchor(0), merge_node->GetInDataAnchor(0));
  319. GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node2->GetInDataAnchor(0));
  320. GraphUtils::AddEdge(merge_node->GetOutDataAnchor(0), node3->GetInDataAnchor(0));
  321. EXPECT_EQ(node1->GetOutDataNodes().size(), 1);
  322. TestDelPass del_pass;
  323. auto ret = del_pass.IsolateAndDeleteNode(merge_node, {0, -1});
  324. EXPECT_EQ(ret, SUCCESS);
  325. }
  326. TEST_F(UTESTGraphPassesBasePass, data_graph) {
  327. auto graph = BuildGraph1();
  328. auto ge_pass = GEPass(graph);
  329. EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
  330. auto *pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
  331. EXPECT_EQ(pass->GetIterNodes().size(), 4);
  332. std::vector<std::unordered_set<std::string>> layers;
  333. layers.push_back({"data1", "const1"});
  334. layers.push_back({"add1"});
  335. layers.push_back({"reshape1"});
  336. CheckIterOrder(pass, layers);
  337. }
  338. TEST_F(UTESTGraphPassesBasePass, graph_with_control_link) {
  339. auto graph = BuildGraph2();
  340. auto ge_pass = GEPass(graph);
  341. EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
  342. auto *pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
  343. EXPECT_EQ(pass->GetIterNodes().size(), 8);
  344. EXPECT_EQ(pass->GetIterNodes().at(3)->GetName(), "shape1");
  345. std::vector<std::unordered_set<std::string>> layers;
  346. layers.push_back({"data1", "const1", "const2"});
  347. layers.push_back({"shape1"});
  348. layers.push_back({"add1", "addn1", "reshape1"});
  349. layers.push_back({"sum1"});
  350. CheckIterOrder(pass, layers);
  351. }
  352. TEST_F(UTESTGraphPassesBasePass, re_pass_after) {
  353. NamesToPass names_to_pass;
  354. auto test_pass = UtestTestPass();
  355. names_to_pass.push_back(std::make_pair("test", &test_pass));
  356. test_pass.AddRePassNodeName("add1", "sum1");
  357. test_pass.AddRePassNodeName("shape1", "sum1");
  358. test_pass.AddRePassNodeName("shape1", "add1");
  359. test_pass.AddRePassNodeName("data1", "add1");
  360. auto graph = BuildGraph2();
  361. auto ge_pass = GEPass(graph);
  362. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  363. EXPECT_EQ(test_pass.GetIterNodes().size(), 8);
  364. }
  365. TEST_F(UTESTGraphPassesBasePass, re_pass_before) {
  366. NamesToPass names_to_pass;
  367. auto test_pass = UtestTestPass();
  368. names_to_pass.push_back(std::make_pair("test", &test_pass));
  369. test_pass.AddRePassNodeName("add1", "data1");
  370. auto graph = BuildGraph1();
  371. auto ge_pass = GEPass(graph);
  372. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  373. EXPECT_EQ(test_pass.GetIterNodes().size(), 5);
  374. EXPECT_EQ(test_pass.GetIterNodes().at(2)->GetName(), "add1");
  375. EXPECT_EQ(test_pass.GetIterNodes().at(3)->GetName(), "reshape1");
  376. EXPECT_EQ(test_pass.GetIterNodes().at(4)->GetName(), "data1");
  377. }
  378. TEST_F(UTESTGraphPassesBasePass, re_pass_before_multi_times) {
  379. NamesToPass names_to_pass;
  380. auto test_pass = UtestTestPass();
  381. names_to_pass.push_back(std::make_pair("test", &test_pass));
  382. test_pass.AddRePassNodeName("add1", "data1");
  383. test_pass.AddRePassNodeName("add1", "const1");
  384. test_pass.AddRePassNodeName("reshape1", "data1");
  385. auto graph = BuildGraph1();
  386. auto ge_pass = GEPass(graph);
  387. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  388. EXPECT_EQ(test_pass.GetIterNodes().size(), 6);
  389. EXPECT_EQ(test_pass.GetIterNodes().at(2)->GetName(), "add1");
  390. EXPECT_EQ(test_pass.GetIterNodes().at(3)->GetName(), "reshape1");
  391. }
  392. TEST_F(UTESTGraphPassesBasePass, del_after) {
  393. NamesToPass names_to_pass;
  394. auto test_pass = UtestTestPass();
  395. names_to_pass.push_back(std::make_pair("test", &test_pass));
  396. test_pass.AddDelNodeName("add1", "sum1");
  397. auto graph = BuildGraph2();
  398. auto ge_pass = GEPass(graph);
  399. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  400. EXPECT_EQ(test_pass.GetIterNodes().size(), 7);
  401. }
  402. TEST_F(UTESTGraphPassesBasePass, del_after_multiple) {
  403. NamesToPass names_to_pass;
  404. auto test_pass = UtestTestPass();
  405. names_to_pass.push_back(std::make_pair("test", &test_pass));
  406. test_pass.AddDelNodeName("add1", "sum1");
  407. test_pass.AddDelNodeName("add1", "reshape1");
  408. auto graph = BuildGraph2();
  409. auto ge_pass = GEPass(graph);
  410. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  411. EXPECT_EQ(test_pass.GetIterNodes().size(), 6);
  412. }
  413. TEST_F(UTESTGraphPassesBasePass, del_after_break_link) {
  414. NamesToPass names_to_pass;
  415. auto test_pass = UtestTestPass();
  416. names_to_pass.push_back(std::make_pair("test", &test_pass));
  417. test_pass.AddDelNodeName("shape1", "add1");
  418. test_pass.AddDelNodeName("shape1", "addn1");
  419. test_pass.AddRePassNodeName("shape1", "shape1");
  420. test_pass.AddRePassNodeName("shape1", "reshape1");
  421. test_pass.AddRePassNodeName("shape1", "sum1");
  422. auto graph = BuildGraph2();
  423. auto ge_pass = GEPass(graph);
  424. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  425. EXPECT_EQ(test_pass.GetIterNodes().size(), 7);
  426. }
  427. TEST_F(UTESTGraphPassesBasePass, del_self_and_after) {
  428. NamesToPass names_to_pass;
  429. auto test_pass = UtestTestPass();
  430. names_to_pass.push_back(std::make_pair("test", &test_pass));
  431. test_pass.AddDelNodeName("shape1", "add1");
  432. test_pass.AddDelNodeName("shape1", "addn1");
  433. auto graph = BuildGraph2();
  434. auto ge_pass = GEPass(graph);
  435. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  436. EXPECT_EQ(test_pass.GetIterNodes().size(), 6);
  437. }
  438. TEST_F(UTESTGraphPassesBasePass, del_before) {
  439. NamesToPass names_to_pass;
  440. auto test_pass = UtestTestPass();
  441. names_to_pass.push_back(std::make_pair("test", &test_pass));
  442. test_pass.AddDelNodeName("reshape1", "add1");
  443. test_pass.AddDelNodeName("sum1", "addn1");
  444. auto graph = BuildGraph2();
  445. auto ge_pass = GEPass(graph);
  446. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  447. EXPECT_EQ(test_pass.GetIterNodes().size(), 8);
  448. }
  449. TEST_F(UTESTGraphPassesBasePass, re_pass_and_del) {
  450. NamesToPass names_to_pass;
  451. auto test_pass = UtestTestPass();
  452. names_to_pass.push_back(std::make_pair("test", &test_pass));
  453. test_pass.AddRePassNodeName("add1", "sum1");
  454. test_pass.AddDelNodeName("reshape1", "sum1");
  455. auto graph = BuildGraph2();
  456. auto ge_pass = GEPass(graph);
  457. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  458. EXPECT_EQ(test_pass.GetIterNodes().size(), 7);
  459. }
  460. /*
  461. TEST_F(UTESTGraphPassesBasePass, dead_loop) {
  462. NamesToPass names_to_pass;
  463. auto test_pass = UtestTestPass(true);
  464. names_to_pass.push_back(std::make_pair("test", &test_pass));
  465. test_pass.AddRePassNodeName("add1", "sum1");
  466. test_pass.AddRePassNodeName("sum1", "add1");
  467. auto graph = BuildGraph2();
  468. auto ge_pass = GEPass(graph);
  469. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  470. EXPECT_EQ(test_pass.GetRunTimes(), 1007);
  471. }
  472. */
  473. TEST_F(UTESTGraphPassesBasePass, while_loop) {
  474. NamesToPass names_to_pass;
  475. auto test_pass = UtestTestPass(true);
  476. names_to_pass.push_back(std::make_pair("test", &test_pass));
  477. auto graph = BuildGraph3();
  478. auto ge_pass = GEPass(graph);
  479. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  480. }
  481. /// data1 const
  482. /// \ /
  483. /// while
  484. /// / \.
  485. /// | |
  486. /// cast1 cast2
  487. ComputeGraphPtr BuildWhileGraph1() {
  488. // build sub graph
  489. auto builder_sub = ut::GraphBuilder("sub");
  490. auto data_1 = builder_sub.AddNode("data_1", DATA, 0, 1);
  491. auto data_2 = builder_sub.AddNode("data_2", DATA, 0, 1);
  492. auto add = builder_sub.AddNode("add", ADD, 2, 1);
  493. builder_sub.AddDataEdge(data_1, 0, add, 0);
  494. builder_sub.AddDataEdge(data_2, 0, add, 1);
  495. auto sub_graph = builder_sub.GetGraph();
  496. sub_graph->SetName("while_sub");
  497. // build root graph
  498. auto builder = ut::GraphBuilder("g1");
  499. auto data = builder.AddNode("data1", DATA, 0, 1);
  500. auto const_op = builder.AddNode("const_op", CONSTANT, 0, 1);
  501. auto c1 = builder.AddNode("cast1", CAST, 1, 1);
  502. auto c2 = builder.AddNode("cast2", CAST, 1, 1);
  503. // add while op
  504. auto tensor_desc = std::make_shared<GeTensorDesc>();
  505. tensor_desc->SetShape(GeShape({1,1,1,1}));
  506. tensor_desc->SetFormat(FORMAT_ND);
  507. tensor_desc->SetDataType(DT_INT32);
  508. auto op_desc = std::make_shared<OpDesc>("while", WHILE);
  509. for (int i = 0; i < 2; ++i) {
  510. op_desc->AddInputDesc(tensor_desc->Clone());
  511. }
  512. for (int i = 0; i < 2; ++i) {
  513. op_desc->AddOutputDesc(tensor_desc->Clone());
  514. }
  515. AttrUtils::SetBool(op_desc,"_need_infer_again", true);
  516. op_desc->AddSubgraphName(sub_graph->GetName());
  517. op_desc->SetSubgraphInstanceName(0,sub_graph->GetName());
  518. auto root_graph = builder.GetGraph();
  519. auto while_op = root_graph->AddNode(op_desc);
  520. builder.AddDataEdge(data, 0, while_op, 0);
  521. builder.AddDataEdge(const_op, 0, while_op, 1);
  522. builder.AddDataEdge(while_op, 0, c1, 0);
  523. builder.AddDataEdge(while_op, 1, c2, 0);
  524. sub_graph->SetParentGraph(root_graph);
  525. sub_graph->SetParentNode(while_op);
  526. root_graph->AddSubgraph(sub_graph);
  527. return root_graph;
  528. }
  529. TEST_F(UTESTGraphPassesBasePass, while_infershape) {
  530. NamesToPass names_to_pass;
  531. auto test_pass = UtestTestPass();
  532. names_to_pass.push_back(std::make_pair("test", &test_pass));
  533. auto graph = BuildWhileGraph1();
  534. auto ge_pass = GEPass(graph);
  535. auto while_node = graph->FindNode("while");
  536. EXPECT_EQ(while_node->GetOpDesc()->GetSubgraphInstanceNames().size(),1);
  537. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  538. }
  539. TEST_F(UTESTGraphPassesBasePass, re_pass_pre_node_immediately) {
  540. auto graph = BuildGraph2();
  541. auto ge_pass = GEPass(graph);
  542. auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
  543. // repass pre_node immediately
  544. test_pass->AddRePassImmediateNodeName("reshape1", "add1");
  545. EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
  546. EXPECT_EQ(test_pass->GetIterNodes().size(), 9);// todo
  547. std::vector<std::unordered_set<std::string>> layers;
  548. layers.push_back({"data1", "const1", "const2"});
  549. layers.push_back({"shape1"});
  550. layers.push_back({"add1", "addn1"});
  551. layers.push_back({"reshape1", "add1", "sum1"});
  552. CheckIterOrder(test_pass, layers);
  553. }
  554. TEST_F(UTESTGraphPassesBasePass, re_pass_cur_node_immediately) {
  555. auto graph = BuildGraph2();
  556. auto ge_pass = GEPass(graph);
  557. auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
  558. // repass cur_node immediately
  559. test_pass->AddRePassImmediateNodeName("reshape1", "reshape1");
  560. EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
  561. EXPECT_EQ(test_pass->GetIterNodes().size(), 9);
  562. std::vector<std::unordered_set<std::string>> layers;
  563. layers.push_back({"data1", "const1", "const2"});
  564. layers.push_back({"shape1"});
  565. layers.push_back({"add1", "addn1"});
  566. layers.push_back({"reshape1"});
  567. layers.push_back({"reshape1", "sum1"});
  568. CheckIterOrder(test_pass, layers);
  569. }
  570. TEST_F(UTESTGraphPassesBasePass, re_pass_next_node_immediately) {
  571. auto graph = BuildGraph2();
  572. auto ge_pass = GEPass(graph);
  573. auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
  574. // repass next_node immediately
  575. test_pass->AddRePassImmediateNodeName("reshape1", "sum1");
  576. // repass node after next_node immediately
  577. test_pass->AddRePassImmediateNodeName("add1", "sum1");
  578. EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
  579. EXPECT_EQ(test_pass->GetIterNodes().size(), 8);
  580. std::vector<std::unordered_set<std::string>> layers;
  581. layers.push_back({"data1", "const1", "const2"});
  582. layers.push_back({"shape1"});
  583. layers.push_back({"add1", "addn1"});
  584. layers.push_back({"reshape1", "sum1"});
  585. CheckIterOrder(test_pass, layers);
  586. }
  587. /**
  588. * A->B->C
  589. * if node B suspend its pre_node A, and C resume A, it is a useless operation, so iter_order should follow normal order
  590. * when C resuem A, A will pass again.
  591. */
  592. TEST_F(UTESTGraphPassesBasePass, B_suspend_pre_node_A_then_C_resume_A) {
  593. auto graph = BuildGraph2();
  594. auto ge_pass = GEPass(graph);
  595. auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
  596. // add1->reshape1->sum1
  597. test_pass->AddSuspendNodeName("reshape1", "add1");
  598. test_pass->AddResumeNodeName("sum1", "add1");
  599. EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
  600. EXPECT_EQ(test_pass->GetIterNodes().size(), 9);
  601. std::vector<std::unordered_set<std::string>> layers;
  602. layers.push_back({"data1", "const1", "const2"});
  603. layers.push_back({"shape1"});
  604. layers.push_back({"add1", "addn1"});
  605. layers.push_back({"reshape1", "sum1"});
  606. layers.push_back({"add1"});
  607. CheckIterOrder(test_pass, layers);
  608. }
  609. /**
  610. * A->B->C
  611. * if node B suspend its pre_node A, and B resume A, it is a useless operation, so iter_order should follow normal order
  612. * when B resuem A, A will pass again.
  613. */
  614. TEST_F(UTESTGraphPassesBasePass, B_suspend_pre_node_A_then_B_resume_A) {
  615. auto graph = BuildGraph2();
  616. auto ge_pass = GEPass(graph);
  617. auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
  618. // add1->reshape1->sum1
  619. test_pass->AddSuspendNodeName("reshape1", "add1");
  620. test_pass->AddResumeNodeName("reshape1", "add1");
  621. EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
  622. EXPECT_EQ(test_pass->GetIterNodes().size(), 9);
  623. std::vector<std::unordered_set<std::string>> layers;
  624. layers.push_back({"data1", "const1", "const2"});
  625. layers.push_back({"shape1"});
  626. layers.push_back({"add1", "addn1"});
  627. layers.push_back({"reshape1", "sum1", "add1"});
  628. CheckIterOrder(test_pass, layers);
  629. }
  630. /**
  631. * A->B->C
  632. * if node B resume C(which is not suspended), it is a useless operation, C will not pass.
  633. */
  634. TEST_F(UTESTGraphPassesBasePass, B_resume_node_not_suspended) {
  635. auto graph = BuildGraph2();
  636. auto ge_pass = GEPass(graph);
  637. auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
  638. // add1->reshape1->sum1
  639. test_pass->AddResumeNodeName("reshape1", "sum1");
  640. EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
  641. EXPECT_EQ(test_pass->GetIterNodes().size(), 8);
  642. std::vector<std::unordered_set<std::string>> layers;
  643. layers.push_back({"data1", "const1", "const2"});
  644. layers.push_back({"shape1"});
  645. layers.push_back({"add1", "addn1"});
  646. layers.push_back({"reshape1", "sum1"});
  647. CheckIterOrder(test_pass, layers);
  648. }
  649. /**
  650. * A->B->C
  651. * if node B suspend its pre_node A, it is a useless operation, so iter_order should follow normal order
  652. * because nobody resume it ,which means A is a leaked node, so return fail
  653. */
  654. TEST_F(UTESTGraphPassesBasePass, suspend_pre_node_nobody_resume_it_return_failed) {
  655. NamesToPass names_to_pass;
  656. auto test_pass = UtestTestPass();
  657. names_to_pass.push_back(std::make_pair("test", &test_pass));
  658. // suspend pre_node immediately
  659. test_pass.AddSuspendNodeName("reshape1", "add1");
  660. auto graph = BuildGraph2();
  661. auto ge_pass = GEPass(graph);
  662. EXPECT_EQ(ge_pass.Run(names_to_pass), INTERNAL_ERROR);
  663. }
  664. /**
  665. * A->B->C
  666. * if node B suspend its pre_node A, it is a useless operation,
  667. * so iter_order should follow normal order
  668. * resume A on leaked, which means A will pass again
  669. */
  670. TEST_F(UTESTGraphPassesBasePass, suspend_pre_node_resume_it_onleaked) {
  671. auto graph = BuildGraph2();
  672. auto ge_pass = GEPass(graph);
  673. auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
  674. // suspend pre_node immediately
  675. test_pass->AddSuspendNodeName("reshape1", "add1");
  676. test_pass->AddResumeNodeNameOnLeaked("add1");
  677. EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
  678. std::vector<std::unordered_set<std::string>> layers;
  679. layers.push_back({"data1", "const1", "const2"});
  680. layers.push_back({"shape1"});
  681. layers.push_back({"add1", "addn1"});
  682. layers.push_back({"reshape1", "sum1"});
  683. layers.push_back({"add1"});
  684. CheckIterOrder(test_pass, layers);
  685. }
  686. /// cast1--shape1
  687. /// /
  688. /// data1
  689. /// \
  690. /// transdata1--shape2
  691. /**
  692. * suspend cur node
  693. * cast1 suspend itself, shape2 resume cast1
  694. * iter order follows : data1; cast1,transdata1; shape2; cast1 ; shape1
  695. */
  696. TEST_F(UTESTGraphPassesBasePass, cast1_suspend_cur_node_shape2_resume_cast1) {
  697. auto graph = BuildGraph4();
  698. auto ge_pass = GEPass(graph);
  699. auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
  700. // suspend pre_node immediately
  701. test_pass->AddSuspendNodeName("cast1", "cast1");
  702. test_pass->AddResumeNodeName("shape2", "cast1");
  703. EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
  704. EXPECT_EQ(test_pass->GetIterNodes().size(), 6);
  705. std::vector<std::unordered_set<std::string>> layers;
  706. layers.push_back({"data1"});
  707. layers.push_back({"cast1","transdata1"});
  708. layers.push_back({"shape2"});
  709. layers.push_back({"cast1", "shape1"});
  710. CheckIterOrder(test_pass, layers);
  711. }
  712. /**
  713. * suspend cur node
  714. * cast1 suspend itself, then resume cast1
  715. * iter order follows : data1; cast1,cast1,transdata1; shape2; shape1.
  716. */
  717. TEST_F(UTESTGraphPassesBasePass, cast1_suspend_itslef_then_resume_itself) {
  718. auto graph = BuildGraph4();
  719. auto ge_pass = GEPass(graph);
  720. auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
  721. // suspend pre_node immediately
  722. test_pass->AddSuspendNodeName("cast1", "cast1");
  723. test_pass->AddResumeNodeName("cast1", "cast1");
  724. EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
  725. EXPECT_EQ(test_pass->GetIterNodes().size(), 6);
  726. std::vector<std::unordered_set<std::string>> layers;
  727. layers.push_back({"data1"});
  728. layers.push_back({"cast1","transdata1","cast1","shape1", "shape2"});
  729. CheckIterOrder(test_pass, layers);
  730. }
  731. /**
  732. * suspend cur node
  733. * cast1 suspend itself, then resume cast1 on leaked
  734. * iter order follows : data1; cast1,cast1,transdata1; shape2; shape1.
  735. */
  736. TEST_F(UTESTGraphPassesBasePass, cast1_suspend_itslef_then_resume_onleaked) {
  737. auto graph = BuildGraph4();
  738. auto ge_pass = GEPass(graph);
  739. auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
  740. // suspend pre_node immediately
  741. test_pass->AddSuspendNodeName("cast1", "cast1");
  742. test_pass->AddResumeNodeNameOnLeaked("cast1");
  743. EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
  744. EXPECT_EQ(test_pass->GetIterNodes().size(), 6);
  745. std::vector<std::unordered_set<std::string>> layers;
  746. layers.push_back({"data1"});
  747. layers.push_back({"cast1","transdata1", "shape2"});
  748. layers.push_back({"cast1","shape1"});
  749. CheckIterOrder(test_pass, layers);
  750. }
  751. /**
  752. * suspend next node
  753. * data1 suspend cast1, then resume cast1 on leaked
  754. * iter order follows : data1; transdata1, shape2; cast1, shape1.
  755. */
  756. TEST_F(UTESTGraphPassesBasePass, data1_suspend_cast1_resume_cast1_onleaked) {
  757. auto graph = BuildGraph4();
  758. auto ge_pass = GEPass(graph);
  759. auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
  760. // suspend pre_node immediately
  761. test_pass->AddSuspendNodeName("data1", "cast1");
  762. test_pass->AddResumeNodeNameOnLeaked("cast1");
  763. EXPECT_EQ(ge_pass.Run(names_to_pass_), SUCCESS);
  764. EXPECT_EQ(test_pass->GetIterNodes().size(), 5);
  765. std::vector<std::unordered_set<std::string>> layers;
  766. layers.push_back({"data1"});
  767. layers.push_back({"transdata1", "shape2"});
  768. layers.push_back({"cast1","shape1"});
  769. CheckIterOrder(test_pass, layers);
  770. }
  771. /**
  772. * suspend next node
  773. * data1 suspend cast1, nobody resume it
  774. * iter order follows : data1; transdata1, shape2;
  775. * run ret is failed ,because node leaked
  776. */
  777. TEST_F(UTESTGraphPassesBasePass, data1_suspend_cast1_nobody_resume) {
  778. auto graph = BuildGraph4();
  779. auto ge_pass = GEPass(graph);
  780. auto *test_pass = dynamic_cast<UtestTestPass *>(names_to_pass_[0].second);
  781. // suspend pre_node immediately
  782. test_pass->AddSuspendNodeName("data1", "cast1");
  783. EXPECT_EQ(ge_pass.Run(names_to_pass_), INTERNAL_ERROR);
  784. EXPECT_EQ(test_pass->GetIterNodes().size(), 3);
  785. }
  786. /*
  787. TEST_F(UTESTGraphPassesBasePass, suspend_pre_node) {
  788. NamesToPass names_to_pass;
  789. auto test_pass = UtestTestPass();
  790. names_to_pass.push_back(std::make_pair("test", &test_pass));
  791. // repass next_node immediately
  792. test_pass.AddRePassNodeName("reshape1", "sum1");
  793. // repass node after next_node immediately
  794. test_pass.AddRePassNodeName("add1", "sum1");
  795. auto graph = BuildGraph2();
  796. auto ge_pass = GEPass(graph);
  797. EXPECT_EQ(ge_pass.Run(names_to_pass), SUCCESS);
  798. EXPECT_EQ(test_pass.GetIterNodes().size(), 8);// todo
  799. std::vector<std::unordered_set<std::string>> layers;
  800. layers.push_back({"data1", "const1", "const2"});
  801. layers.push_back({"shape1"});
  802. layers.push_back({"add1", "addn1"});
  803. layers.push_back({"reshape1", "sum1"});
  804. CheckIterOrder(&test_pass, layers);
  805. }*/
  806. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。