You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

constant_folding_pass_unittest.cc 25 kB

5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/constant_folding_pass.h"
  17. #include <string>
  18. #include <vector>
  19. #include <gtest/gtest.h>
  20. #include "common/types.h"
  21. #include "ge/common/ge/ge_util.h"
  22. #include "graph/passes/base_pass.h"
  23. #include "graph/passes/dimension_compute_pass.h"
  24. #include "graph_builder_utils.h"
  25. #include "inc/kernel.h"
  26. #include "inc/kernel_factory.h"
  27. namespace ge {
  // Op type names used by the fake kernels and the test graphs in this file. Names ending in
  // "Yes" get a fake kernel registered below via REGISTER_KERNEL; names ending in "No" do not.
  // Declared `const char *const`: the names are immutable and, being const objects at namespace
  // scope, get internal linkage instead of becoming mutable external symbols in namespace ge
  // (other test files linked into the same binary could otherwise collide with them).
  const char *const AddYesDim = "AddYesDim";
  const char *const AddNYes = "AddNYes";
  const char *const AddNNo = "AddNNo";
  const char *const AddYes = "AddYes";
  const char *const HuberLossYes = "HuberLossYes";
  const char *const ShapeNo = "ShapeNo";
  const char *const DataNo = "dataNo";
  const char *const WrongYes = "WrongYes";
  const char *const WrongYes1 = "WrongYes1";
  const char *const WrongYes2 = "WrongYes2";
  const char *const WrongYes3 = "WrongYes3";
  39. class TestAddNKernel : public Kernel {
  40. public:
  41. Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge::ConstGeTensorPtr> &input,
  42. std::vector<ge::GeTensorPtr> &v_output) override {
  43. auto output = std::make_shared<GeTensor>();
  44. std::vector<uint8_t> data{1, 2, 3};
  45. std::vector<int64_t> shape{3};
  46. output->MutableTensorDesc().SetShape(GeShape(shape));
  47. output->SetData(data);
  48. output->MutableTensorDesc().SetDataType(DT_UINT8);
  49. v_output.push_back(output);
  50. return SUCCESS;
  51. }
  52. };
  53. REGISTER_KERNEL(AddNYes, TestAddNKernel);
  54. class TestHuberLossKernel : public Kernel {
  55. public:
  56. Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge::ConstGeTensorPtr> &input,
  57. std::vector<ge::GeTensorPtr> &v_output) override {
  58. auto output1 = std::make_shared<GeTensor>();
  59. std::vector<uint8_t> data{1, 2, 3, 4, 5};
  60. std::vector<int64_t> shape{5};
  61. output1->MutableTensorDesc().SetShape(GeShape(shape));
  62. output1->SetData(data);
  63. output1->MutableTensorDesc().SetDataType(DT_UINT8);
  64. v_output.push_back(output1);
  65. auto output2 = std::make_shared<GeTensor>();
  66. std::vector<uint8_t> data2{1, 2, 3, 4, 5, 6};
  67. std::vector<int64_t> shape2{2, 3};
  68. output2->MutableTensorDesc().SetShape(GeShape(shape2));
  69. output2->SetData(data2);
  70. output2->MutableTensorDesc().SetDataType(DT_UINT8);
  71. v_output.push_back(output2);
  72. return SUCCESS;
  73. }
  74. };
  75. REGISTER_KERNEL(HuberLossYes, TestHuberLossKernel);
  76. class TestAddKernel : public Kernel {
  77. public:
  78. Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge::ConstGeTensorPtr> &input,
  79. std::vector<ge::GeTensorPtr> &v_output) override {
  80. auto output = std::make_shared<GeTensor>();
  81. std::vector<uint8_t> data{1, 2, 3, 4, 5};
  82. std::vector<int64_t> shape{5};
  83. output->MutableTensorDesc().SetShape(GeShape(shape));
  84. output->SetData(data);
  85. output->MutableTensorDesc().SetDataType(DT_UINT8);
  86. v_output.push_back(output);
  87. return SUCCESS;
  88. }
  89. };
  90. REGISTER_KERNEL(AddYes, TestAddKernel);
  91. class TestAddDimKernel : public Kernel {
  92. public:
  93. Status Compute(const ge::NodePtr &node, std::vector<ge::GeTensorPtr> &v_output) {
  94. auto output = std::make_shared<GeTensor>();
  95. std::vector<uint8_t> data{1, 2, 3, 4, 5};
  96. std::vector<int64_t> shape{5};
  97. output->MutableTensorDesc().SetShape(GeShape(shape));
  98. output->SetData(data);
  99. output->MutableTensorDesc().SetDataType(DT_UINT8);
  100. v_output.push_back(output);
  101. return SUCCESS;
  102. }
  103. };
  104. REGISTER_KERNEL(AddYesDim, TestAddDimKernel);
  105. class TestWrongKernel : public Kernel {
  106. public:
  107. Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge::ConstGeTensorPtr> &input,
  108. std::vector<ge::GeTensorPtr> &v_output) override {
  109. // for test: output weights is null
  110. v_output.push_back(nullptr);
  111. return SUCCESS;
  112. }
  113. };
  114. REGISTER_KERNEL(WrongYes, TestWrongKernel);
  115. class TestWrongKernel1 : public Kernel {
  116. public:
  117. Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge::ConstGeTensorPtr> &input,
  118. std::vector<ge::GeTensorPtr> &v_output) override {
  119. // for test: no output weights
  120. return SUCCESS;
  121. }
  122. };
  123. REGISTER_KERNEL(WrongYes1, TestWrongKernel1);
  124. class TestWrongKernel2 : public Kernel {
  125. public:
  126. Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge::ConstGeTensorPtr> &input,
  127. std::vector<ge::GeTensorPtr> &v_output) override {
  128. auto output1 = std::make_shared<GeTensor>();
  129. std::vector<uint8_t> data{1, 2, 3, 4, 5};
  130. std::vector<int64_t> shape{5};
  131. output1->MutableTensorDesc().SetShape(GeShape(shape));
  132. output1->SetData(data);
  133. output1->MutableTensorDesc().SetDataType(DT_UINT8);
  134. v_output.push_back(output1);
  135. // for test: output weights < output size
  136. return SUCCESS;
  137. }
  138. };
  139. REGISTER_KERNEL(WrongYes2, TestWrongKernel2);
  140. class TestWrongKernel3 : public Kernel {
  141. public:
  142. Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector<ge::ConstGeTensorPtr> &input,
  143. std::vector<ge::GeTensorPtr> &v_output) override {
  144. // for test: return NOT_CHANGED
  145. return NOT_CHANGED;
  146. }
  147. };
  148. REGISTER_KERNEL(WrongYes3, TestWrongKernel3);
  149. class UtestGraphPassesConstantFoldingPass : public testing::Test {
  150. protected:
  151. UtestGraphPassesConstantFoldingPass() = default;
  152. };
  153. namespace {
  154. /// netoutput1
  155. /// |
  156. /// shapeNo1
  157. /// |
  158. /// addnYes1
  159. /// / \
  160. /// / \
  161. /// const1 const2
  162. ComputeGraphPtr BuildGraph1() {
  163. auto builder = ut::GraphBuilder("test");
  164. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  165. auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
  166. auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1);
  167. auto shape1 = builder.AddNode("shape1", ShapeNo, 1, 1);
  168. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  169. builder.AddDataEdge(const1, 0, addn1, 0);
  170. builder.AddDataEdge(const2, 0, addn1, 1);
  171. builder.AddDataEdge(addn1, 0, shape1, 0);
  172. builder.AddDataEdge(shape1, 0, netoutput1, 0);
  173. return builder.GetGraph();
  174. }
  175. /// netoutput1
  176. /// |
  177. /// shapeNo1
  178. /// |
  179. /// addnYes1 shapeNo2
  180. /// / \ /
  181. /// / \ /
  182. /// const1 const2
  183. ComputeGraphPtr BuildGraph2() {
  184. auto builder = ut::GraphBuilder("test");
  185. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  186. auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
  187. auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1);
  188. auto shape1 = builder.AddNode("shape1", ShapeNo, 1, 1);
  189. auto shape2 = builder.AddNode("shape2", ShapeNo, 1, 1);
  190. auto netoutput1 = builder.AddNode("netoutput", DataNo, 1, 0);
  191. builder.AddDataEdge(const1, 0, addn1, 0);
  192. builder.AddDataEdge(const2, 0, addn1, 1);
  193. builder.AddDataEdge(const2, 0, shape2, 0);
  194. builder.AddDataEdge(addn1, 0, shape1, 0);
  195. builder.AddDataEdge(shape1, 0, netoutput1, 0);
  196. return builder.GetGraph();
  197. }
  198. /// netoutput1
  199. /// |
  200. /// shapeNo1
  201. /// | c
  202. /// addnYes1 <----- dataNo1
  203. /// / \
  204. /// / \
  205. /// const1 const2
  206. ComputeGraphPtr BuildGraph3() {
  207. auto builder = ut::GraphBuilder("test");
  208. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  209. auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
  210. auto data1 = builder.AddNode("data1", DataNo, 0, 1);
  211. auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1);
  212. auto shape1 = builder.AddNode("shape1", ShapeNo, 1, 1);
  213. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  214. builder.AddDataEdge(const1, 0, addn1, 0);
  215. builder.AddDataEdge(const2, 0, addn1, 1);
  216. builder.AddControlEdge(data1, addn1);
  217. builder.AddDataEdge(addn1, 0, shape1, 0);
  218. builder.AddDataEdge(shape1, 0, netoutput1, 0);
  219. return builder.GetGraph();
  220. }
  221. /// netoutput1
  222. /// |
  223. /// shapeNo1
  224. /// | c
  225. /// addnYes1 <---------
  226. /// / \ \
  227. /// / \ c \
  228. /// const1 const2 <----- dataNo1
  229. ComputeGraphPtr BuildGraph4() {
  230. auto builder = ut::GraphBuilder("test");
  231. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  232. auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
  233. auto data1 = builder.AddNode("data1", DataNo, 0, 1);
  234. auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1);
  235. auto shape1 = builder.AddNode("shape1", ShapeNo, 1, 1);
  236. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  237. builder.AddDataEdge(const1, 0, addn1, 0);
  238. builder.AddDataEdge(const2, 0, addn1, 1);
  239. builder.AddControlEdge(data1, const2);
  240. builder.AddControlEdge(data1, addn1);
  241. builder.AddDataEdge(addn1, 0, shape1, 0);
  242. builder.AddDataEdge(shape1, 0, netoutput1, 0);
  243. return builder.GetGraph();
  244. }
  245. /// netoutput1
  246. /// |
  247. /// shapeNo1
  248. /// | c
  249. /// addnYes1 <----- dataNo1
  250. /// / \
  251. /// / \ c
  252. /// const1 const2 <----- dataNo2
  253. ComputeGraphPtr BuildGraph5() {
  254. auto builder = ut::GraphBuilder("test");
  255. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  256. auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
  257. auto data1 = builder.AddNode("data1", DataNo, 0, 1);
  258. auto data2 = builder.AddNode("data2", DataNo, 0, 1);
  259. auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1);
  260. auto shape1 = builder.AddNode("shape1", ShapeNo, 1, 1);
  261. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  262. builder.AddDataEdge(const1, 0, addn1, 0);
  263. builder.AddDataEdge(const2, 0, addn1, 1);
  264. builder.AddControlEdge(data2, const2);
  265. builder.AddControlEdge(data1, addn1);
  266. builder.AddDataEdge(addn1, 0, shape1, 0);
  267. builder.AddDataEdge(shape1, 0, netoutput1, 0);
  268. return builder.GetGraph();
  269. }
  270. /// netoutput1
  271. /// |
  272. /// shapeNo1
  273. /// |
  274. /// addYes1 <---- const3
  275. /// |
  276. /// addnYes1 <-
  277. /// / \ \
  278. /// / \ \
  279. /// const1 const2 const4
  280. ComputeGraphPtr BuildGraph6() {
  281. auto builder = ut::GraphBuilder("test");
  282. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  283. auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
  284. auto const3 = builder.AddNode("const3", CONSTANT, 0, 1);
  285. auto const4 = builder.AddNode("const4", CONSTANT, 0, 1);
  286. auto addn1 = builder.AddNode("addn1", AddNYes, 3, 1);
  287. auto add1 = builder.AddNode("add1", AddYes, 2, 1);
  288. auto shape1 = builder.AddNode("shape1", ShapeNo, 1, 1);
  289. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  290. builder.AddDataEdge(const1, 0, addn1, 0);
  291. builder.AddDataEdge(const2, 0, addn1, 1);
  292. builder.AddDataEdge(const4, 0, addn1, 2);
  293. builder.AddDataEdge(addn1, 0, add1, 0);
  294. builder.AddDataEdge(const3, 0, add1, 1);
  295. builder.AddDataEdge(add1, 0, shape1, 0);
  296. builder.AddDataEdge(shape1, 0, netoutput1, 0);
  297. return builder.GetGraph();
  298. }
  299. /// netoutput1
  300. /// / \
  301. /// shapeNo1 ShpaeNo2
  302. /// \ /
  303. /// huberLoss1
  304. /// / | \
  305. /// / | \
  306. /// const1 const2 const3
  307. ComputeGraphPtr BuildGraph7() {
  308. auto builder = ut::GraphBuilder("test");
  309. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  310. auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
  311. auto const3 = builder.AddNode("const3", CONSTANT, 0, 1);
  312. auto huberLoss1 = builder.AddNode("huberLoss1", HuberLossYes, 3, 2);
  313. auto shape1 = builder.AddNode("shape1", ShapeNo, 1, 1);
  314. auto shape2 = builder.AddNode("shape2", ShapeNo, 1, 1);
  315. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  316. builder.AddDataEdge(const1, 0, huberLoss1, 0);
  317. builder.AddDataEdge(const2, 0, huberLoss1, 1);
  318. builder.AddDataEdge(const3, 0, huberLoss1, 2);
  319. builder.AddDataEdge(huberLoss1, 0, shape1, 0);
  320. builder.AddDataEdge(huberLoss1, 1, shape2, 0);
  321. builder.AddDataEdge(shape1, 0, netoutput1, 0);
  322. builder.AddDataEdge(shape2, 1, netoutput1, 0);
  323. return builder.GetGraph();
  324. }
  325. /// netoutput1
  326. /// |
  327. /// shapeNo1
  328. /// |
  329. /// addnNo1
  330. /// / \
  331. /// / \
  332. /// const1 const2
  333. ComputeGraphPtr BuildGraph8() {
  334. auto builder = ut::GraphBuilder("test");
  335. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  336. auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
  337. auto addn1 = builder.AddNode("addn1", AddNNo, 2, 1);
  338. auto shape1 = builder.AddNode("shape1", ShapeNo, 1, 1);
  339. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  340. builder.AddDataEdge(const1, 0, addn1, 0);
  341. builder.AddDataEdge(const2, 0, addn1, 1);
  342. builder.AddDataEdge(addn1, 0, shape1, 0);
  343. builder.AddDataEdge(shape1, 0, netoutput1, 0);
  344. return builder.GetGraph();
  345. }
  346. /// netoutput1
  347. /// |
  348. /// shapeNo1
  349. /// |
  350. /// addnYes1
  351. /// / \
  352. /// / \
  353. /// const1 data1
  354. ComputeGraphPtr BuildGraph9() {
  355. auto builder = ut::GraphBuilder("test");
  356. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  357. auto data1 = builder.AddNode("data1", DataNo, 0, 1);
  358. auto addn1 = builder.AddNode("addn1", AddNYes, 2, 1);
  359. auto shape1 = builder.AddNode("shape1", ShapeNo, 1, 1);
  360. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  361. builder.AddDataEdge(const1, 0, addn1, 0);
  362. builder.AddDataEdge(data1, 0, addn1, 1);
  363. builder.AddDataEdge(addn1, 0, shape1, 0);
  364. builder.AddDataEdge(shape1, 0, netoutput1, 0);
  365. return builder.GetGraph();
  366. }
  367. /// netoutput1
  368. /// / \
  369. /// addDim sqrt1
  370. /// \ /
  371. /// switch1
  372. /// / \
  373. /// / \
  374. /// const1 const2
  375. ComputeGraphPtr BuildGraph10() {
  376. auto builder = ut::GraphBuilder("test");
  377. auto const1 = builder.AddNode("const1", CONSTANT, 0, 1);
  378. auto const2 = builder.AddNode("const2", CONSTANT, 0, 1);
  379. auto switchNode1 = builder.AddNode("switch1", SWITCH, 2, 2);
  380. auto sqrt1 = builder.AddNode("sqrt1", RSQRT, 1, 1);
  381. auto add1 = builder.AddNode("addDim", AddYesDim, 1, 1);
  382. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  383. builder.AddDataEdge(const1, 0, switchNode1, 0);
  384. builder.AddDataEdge(const2, 0, switchNode1, 1);
  385. builder.AddDataEdge(switchNode1, 0, add1, 0);
  386. builder.AddDataEdge(switchNode1, 1, sqrt1, 0);
  387. builder.AddDataEdge(add1, 0, netoutput1, 0);
  388. builder.AddDataEdge(sqrt1, 0, netoutput1, 1);
  389. return builder.GetGraph();
  390. }
  391. /// netoutput1
  392. /// |
  393. /// FRAMEWORKOP
  394. /// |
  395. /// const1
  396. ComputeGraphPtr BuildWrongGraph1() {
  397. auto builder = ut::GraphBuilder("test");
  398. auto const_op = builder.AddNode("const1", CONSTANT, 0, 1);
  399. auto op = builder.AddNode("fmk_op", FRAMEWORKOP, 1, 1);
  400. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  401. builder.AddDataEdge(const_op, 0, op, 0);
  402. builder.AddDataEdge(op, 0, netoutput1, 0);
  403. return builder.GetGraph();
  404. }
  405. /// netoutput1
  406. /// |
  407. /// WrongYes
  408. /// |
  409. /// const1
  410. ComputeGraphPtr BuildWrongGraph2() {
  411. auto builder = ut::GraphBuilder("test");
  412. auto const_op = builder.AddNode("const1", CONSTANT, 0, 1);
  413. auto op = builder.AddNode("wrong", WrongYes, 1, 1);
  414. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  415. builder.AddDataEdge(const_op, 0, op, 0);
  416. builder.AddDataEdge(op, 0, netoutput1, 0);
  417. return builder.GetGraph();
  418. }
  419. /// netoutput1
  420. /// |
  421. /// WrongYes1
  422. /// |
  423. /// const1
  424. ComputeGraphPtr BuildWrongGraph3() {
  425. auto builder = ut::GraphBuilder("test");
  426. auto const_op = builder.AddNode("const1", CONSTANT, 0, 1);
  427. auto op = builder.AddNode("wrong1", WrongYes1, 1, 1);
  428. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  429. builder.AddDataEdge(const_op, 0, op, 0);
  430. builder.AddDataEdge(op, 0, netoutput1, 0);
  431. return builder.GetGraph();
  432. }
  433. /// netoutput1 WrongYes1
  434. /// | /
  435. /// WrongYes2
  436. /// /
  437. /// const1
  438. ComputeGraphPtr BuildWrongGraph4() {
  439. auto builder = ut::GraphBuilder("test");
  440. auto const_op_1 = builder.AddNode("const1", CONSTANT, 0, 1);
  441. auto op = builder.AddNode("wrong2", WrongYes2, 1, 2);
  442. auto netoutput1 = builder.AddNode("netoutput", NETOUTPUT, 1, 0);
  443. auto wrong_op = builder.AddNode("WrongYes1", WrongYes1, 1, 0);
  444. builder.AddDataEdge(const_op_1, 0, op, 0);
  445. builder.AddDataEdge(op, 0, netoutput1, 0);
  446. builder.AddDataEdge(op, 1, wrong_op, 0);
  447. return builder.GetGraph();
  448. }
  449. /// CONVOLUTION
  450. /// |
  451. /// WrongYes2 WrongYes1
  452. /// /
  453. /// const1
  454. ComputeGraphPtr BuildWrongGraph5() {
  455. auto builder = ut::GraphBuilder("test");
  456. auto const_op_1 = builder.AddNode("const1", CONSTANT, 0, 1);
  457. auto op = builder.AddNode("wrong2", WrongYes2, 1, 1);
  458. auto conv = builder.AddNode("conv", CONVOLUTION, 1, 0);
  459. auto wrong_op = builder.AddNode("WrongYes1", WrongYes1, 1, 0);
  460. builder.AddDataEdge(const_op_1, 0, op, 0);
  461. builder.AddDataEdge(op, 0, conv, 0);
  462. return builder.GetGraph();
  463. }
  464. /// CONVOLUTION
  465. /// |
  466. /// WrongYes3
  467. /// /
  468. /// const1
  469. ComputeGraphPtr BuildWrongGraph6() {
  470. auto builder = ut::GraphBuilder("test");
  471. auto const_op_1 = builder.AddNode("const1", CONSTANT, 0, 1);
  472. auto op = builder.AddNode("wrong3", WrongYes3, 1, 1);
  473. auto conv = builder.AddNode("conv", CONVOLUTION, 1, 0);
  474. builder.AddDataEdge(const_op_1, 0, op, 0);
  475. builder.AddDataEdge(op, 0, conv, 0);
  476. return builder.GetGraph();
  477. }
  478. } // namespace
  479. TEST_F(UtestGraphPassesConstantFoldingPass, folding_addn) {
  480. auto graph = BuildGraph1();
  481. NamesToPass names_to_pass;
  482. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  483. GEPass pass(graph);
  484. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  485. EXPECT_EQ(graph->GetAllNodes().size(), 3);
  486. auto shape1 = graph->FindNode("shape1");
  487. EXPECT_NE(shape1, nullptr);
  488. EXPECT_EQ(shape1->GetInNodes().size(), 1);
  489. auto folded_const = shape1->GetInDataNodes().at(0);
  490. EXPECT_EQ(folded_const->GetType(), CONSTANT);
  491. auto tensor = folded_const->GetOpDesc()->GetOutputDesc(0);
  492. EXPECT_EQ(tensor.GetDataType(), DT_UINT8);
  493. EXPECT_EQ(tensor.GetShape().GetDims(), std::vector<int64_t>({3}));
  494. for (auto &name_to_pass : names_to_pass) {
  495. delete name_to_pass.second;
  496. }
  497. }
  498. TEST_F(UtestGraphPassesConstantFoldingPass, folding_without_one_const) {
  499. auto graph = BuildGraph2();
  500. NamesToPass names_to_pass;
  501. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  502. GEPass pass(graph);
  503. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  504. EXPECT_EQ(graph->GetAllNodes().size(), 5);
  505. EXPECT_EQ(graph->FindNode("addn1"), nullptr);
  506. EXPECT_EQ(graph->FindNode("const1"), nullptr);
  507. auto const2 = graph->FindNode("const2");
  508. EXPECT_NE(const2, nullptr);
  509. EXPECT_EQ(const2->GetOutDataNodes().size(), 1);
  510. EXPECT_EQ(const2->GetOutDataNodes().at(0)->GetName(), "shape2");
  511. auto shape1 = graph->FindNode("shape1");
  512. EXPECT_NE(shape1, nullptr);
  513. EXPECT_EQ(shape1->GetInDataNodes().size(), 1);
  514. EXPECT_EQ(shape1->GetInDataNodes().at(0)->GetType(), CONSTANT);
  515. for (auto &name_to_pass : names_to_pass) {
  516. delete name_to_pass.second;
  517. }
  518. }
  519. TEST_F(UtestGraphPassesConstantFoldingPass, folding_with_const_control_edges) {
  520. auto graph = BuildGraph5();
  521. NamesToPass names_to_pass;
  522. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  523. GEPass pass(graph);
  524. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  525. EXPECT_EQ(graph->GetAllNodes().size(), 5);
  526. auto shape1 = graph->FindNode("shape1");
  527. EXPECT_NE(shape1, nullptr);
  528. EXPECT_EQ(shape1->GetInNodes().size(), 1);
  529. EXPECT_EQ(shape1->GetInControlNodes().size(), 0);
  530. EXPECT_EQ(shape1->GetInDataNodes().at(0)->GetType(), CONSTANT);
  531. std::unordered_set<std::string> node_names;
  532. for (auto node : shape1->GetInControlNodes()) {
  533. node_names.insert(node->GetName());
  534. }
  535. EXPECT_EQ(node_names, std::unordered_set<std::string>());
  536. for (auto &name_to_pass : names_to_pass) {
  537. delete name_to_pass.second;
  538. }
  539. }
  540. TEST_F(UtestGraphPassesConstantFoldingPass, continues_fold) {
  541. auto graph = BuildGraph6();
  542. NamesToPass names_to_pass;
  543. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  544. GEPass pass(graph);
  545. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  546. EXPECT_EQ(graph->GetAllNodes().size(), 3);
  547. auto shape1 = graph->FindNode("shape1");
  548. EXPECT_NE(shape1, nullptr);
  549. EXPECT_EQ(shape1->GetInNodes().size(), 1);
  550. auto folded_const = shape1->GetInDataNodes().at(0);
  551. EXPECT_EQ(folded_const->GetType(), CONSTANT);
  552. auto tensor = folded_const->GetOpDesc()->GetOutputDesc(0);
  553. EXPECT_EQ(tensor.GetDataType(), DT_UINT8);
  554. EXPECT_EQ(tensor.GetShape().GetDims(), std::vector<int64_t>({5}));
  555. for (auto &name_to_pass : names_to_pass) {
  556. delete name_to_pass.second;
  557. }
  558. }
  559. TEST_F(UtestGraphPassesConstantFoldingPass, multiple_output) {
  560. auto graph = BuildGraph7();
  561. NamesToPass names_to_pass;
  562. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  563. GEPass pass(graph);
  564. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  565. EXPECT_EQ(graph->GetAllNodes().size(), 5);
  566. auto shape1 = graph->FindNode("shape1");
  567. EXPECT_NE(shape1, nullptr);
  568. EXPECT_EQ(shape1->GetInNodes().size(), 1);
  569. auto folded_const = shape1->GetInDataNodes().at(0);
  570. EXPECT_EQ(folded_const->GetType(), CONSTANT);
  571. auto tensor = folded_const->GetOpDesc()->GetOutputDesc(0);
  572. EXPECT_EQ(tensor.GetDataType(), DT_UINT8);
  573. EXPECT_EQ(tensor.GetShape().GetDims(), std::vector<int64_t>({5}));
  574. auto shape2 = graph->FindNode("shape2");
  575. EXPECT_NE(shape2, nullptr);
  576. EXPECT_EQ(shape2->GetInNodes().size(), 1);
  577. auto folded_const2 = shape2->GetInDataNodes().at(0);
  578. EXPECT_EQ(folded_const2->GetType(), CONSTANT);
  579. auto tensor2 = folded_const2->GetOpDesc()->GetOutputDesc(0);
  580. EXPECT_EQ(tensor2.GetDataType(), DT_UINT8);
  581. EXPECT_EQ(tensor2.GetShape().GetDims(), std::vector<int64_t>({2, 3}));
  582. for (auto &name_to_pass : names_to_pass) {
  583. delete name_to_pass.second;
  584. }
  585. }
  586. TEST_F(UtestGraphPassesConstantFoldingPass, not_change1) {
  587. auto graph = BuildGraph8();
  588. NamesToPass names_to_pass;
  589. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  590. GEPass pass(graph);
  591. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  592. EXPECT_EQ(graph->GetAllNodes().size(), 5);
  593. for (auto &name_to_pass : names_to_pass) {
  594. delete name_to_pass.second;
  595. }
  596. }
  597. TEST_F(UtestGraphPassesConstantFoldingPass, not_change2) {
  598. auto graph = BuildGraph9();
  599. NamesToPass names_to_pass;
  600. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  601. GEPass pass(graph);
  602. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  603. EXPECT_EQ(graph->GetAllNodes().size(), 5);
  604. for (auto &name_to_pass : names_to_pass) {
  605. delete name_to_pass.second;
  606. }
  607. }
  608. TEST_F(UtestGraphPassesConstantFoldingPass, folding_size) {
  609. auto graph = BuildGraph10();
  610. NamesToPass names_to_pass;
  611. names_to_pass.push_back({"Test", new DimensionComputePass});
  612. GEPass pass(graph);
  613. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  614. EXPECT_EQ(graph->GetAllNodes().size(), 7);
  615. auto switchnode = graph->FindNode("switch1");
  616. EXPECT_NE(switchnode, nullptr);
  617. EXPECT_EQ(switchnode->GetOutDataNodes().size(), 2);
  618. EXPECT_EQ(switchnode->GetOutDataNodes().at(0)->GetName(), "addDim_ctrl_identity_0");
  619. for (auto &name_to_pass : names_to_pass) {
  620. delete name_to_pass.second;
  621. }
  622. }
  623. TEST_F(UtestGraphPassesConstantFoldingPass, unlikely1) {
  624. auto graph = BuildWrongGraph1();
  625. NamesToPass names_to_pass;
  626. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  627. GEPass pass(graph);
  628. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  629. for (auto &name_to_pass : names_to_pass) {
  630. delete name_to_pass.second;
  631. }
  632. }
  633. TEST_F(UtestGraphPassesConstantFoldingPass, unlikely2) {
  634. auto graph = BuildWrongGraph2();
  635. NamesToPass names_to_pass;
  636. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  637. GEPass pass(graph);
  638. EXPECT_EQ(pass.Run(names_to_pass), INTERNAL_ERROR);
  639. for (auto &name_to_pass : names_to_pass) {
  640. delete name_to_pass.second;
  641. }
  642. }
  643. TEST_F(UtestGraphPassesConstantFoldingPass, unlikely3) {
  644. auto graph = BuildWrongGraph3();
  645. NamesToPass names_to_pass;
  646. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  647. GEPass pass(graph);
  648. EXPECT_EQ(pass.Run(names_to_pass), INTERNAL_ERROR);
  649. for (auto &name_to_pass : names_to_pass) {
  650. delete name_to_pass.second;
  651. }
  652. }
  653. TEST_F(UtestGraphPassesConstantFoldingPass, unlikely4) {
  654. auto graph = BuildWrongGraph4();
  655. NamesToPass names_to_pass;
  656. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  657. GEPass pass(graph);
  658. EXPECT_EQ(pass.Run(names_to_pass), INTERNAL_ERROR);
  659. for (auto &name_to_pass : names_to_pass) {
  660. delete name_to_pass.second;
  661. }
  662. }
  663. TEST_F(UtestGraphPassesConstantFoldingPass, unlikely5) {
  664. auto graph = BuildWrongGraph5();
  665. NamesToPass names_to_pass;
  666. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  667. GEPass pass(graph);
  668. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  669. for (auto &name_to_pass : names_to_pass) {
  670. delete name_to_pass.second;
  671. }
  672. }
  673. TEST_F(UtestGraphPassesConstantFoldingPass, unlikely6) {
  674. auto graph = BuildWrongGraph6();
  675. NamesToPass names_to_pass;
  676. names_to_pass.push_back({"Test", new ConstantFoldingPass});
  677. GEPass pass(graph);
  678. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  679. for (auto &name_to_pass : names_to_pass) {
  680. delete name_to_pass.second;
  681. }
  682. }
  683. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户对此并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：