You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

addn_pass_unittest.cc 8.2 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include <cstdint>
#include <initializer_list>
#include <memory>
#include <string>
#include <vector>

#include <gtest/gtest.h>

#include "common/ge_inner_error_codes.h"
#include "graph/passes/addn_pass.h"
  21. namespace ge {
  22. namespace {
  23. GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, Format format = FORMAT_NCHW,
  24. DataType data_type = DT_FLOAT) {
  25. GeShape ge_shape{vector<int64_t>(shape)};
  26. GeTensorDescPtr tensor_desc = std::make_shared<GeTensorDesc>();
  27. tensor_desc->SetShape(ge_shape);
  28. tensor_desc->SetFormat(format);
  29. tensor_desc->SetDataType(data_type);
  30. return tensor_desc;
  31. }
  32. class NodeBuilder {
  33. public:
  34. NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared<OpDesc>(name, type); }
  35. NodeBuilder &AddInputDesc(std::initializer_list<int64_t> shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW,
  36. DataType data_type = DT_FLOAT) {
  37. op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
  38. return *this;
  39. }
  40. NodeBuilder &AddOutputDesc(std::initializer_list<int64_t> shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW,
  41. DataType data_type = DT_FLOAT) {
  42. op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
  43. return *this;
  44. }
  45. NodeBuilder &AddOutputDesc(GeTensorDescPtr tensor_desc) {
  46. op_desc_->AddOutputDesc(tensor_desc->Clone());
  47. return *this;
  48. }
  49. NodePtr Build(const ComputeGraphPtr &graph) {
  50. NodePtr node = graph->AddNode(op_desc_);
  51. return node;
  52. }
  53. private:
  54. OpDescPtr op_desc_;
  55. };
  56. } // namespace
  57. TEST(UtestGraphPassesAddnPass, null_pass) {
  58. ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  59. GEPass pass(graph);
  60. AddNPass *addn_pass = nullptr;
  61. NamesToPass names_to_pass;
  62. names_to_pass.emplace_back("Test", addn_pass);
  63. EXPECT_EQ(pass.Run(names_to_pass), INTERNAL_ERROR);
  64. }
  65. TEST(UtestGraphPassesAddnPass, null_graph) {
  66. ComputeGraphPtr graph = nullptr;
  67. GEPass pass(graph);
  68. AddNPass addn_pass;
  69. NamesToPass names_to_pass;
  70. names_to_pass.emplace_back("Test", nullptr);
  71. EXPECT_EQ(pass.Run(names_to_pass), INTERNAL_ERROR);
  72. }
  73. TEST(UtestGraphPassesAddnPass, empty_pass) {
  74. ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  75. GEPass pass(graph);
  76. AddNPass addn_pass;
  77. NamesToPass names_to_pass;
  78. EXPECT_EQ(pass.Run(names_to_pass), INTERNAL_ERROR);
  79. }
  80. /// |
  81. /// AddN
  82. /// |
  83. TEST(UtestGraphPassesAddnPass, single_addn_node) {
  84. ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  85. GeTensorDescPtr general_ge_tensor_desc = std::make_shared<GeTensorDesc>();
  86. NodePtr add_n_node = NodeBuilder("add_n_node", ADDN).Build(graph);
  87. GEPass pass(graph);
  88. AddNPass addn_pass;
  89. NamesToPass names_to_pass;
  90. names_to_pass.emplace_back("Test", &addn_pass);
  91. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  92. EXPECT_EQ(graph->GetDirectNodesSize(), 1);
  93. EXPECT_TRUE(add_n_node->GetInDataNodes().empty());
  94. EXPECT_TRUE(add_n_node->GetOutDataNodes().empty());
  95. }
  96. /// Op1
  97. /// |
  98. /// AddN
  99. /// |
  100. TEST(UtestGraphPassesAddnPass, no_output) {
  101. ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  102. GeTensorDescPtr general_ge_tensor_desc = std::make_shared<GeTensorDesc>();
  103. NodePtr node = NodeBuilder("node", RELU).AddInputDesc({1, 1, 224, 224}).AddOutputDesc({1, 1, 224, 224}).Build(graph);
  104. NodePtr add_n_node = NodeBuilder("add_n_node", ADDN).AddInputDesc({1, 1, 224, 224}).Build(graph);
  105. GraphUtils::AddEdge(node->GetOutDataAnchor(0), add_n_node->GetInDataAnchor(0));
  106. GEPass pass(graph);
  107. AddNPass addn_pass;
  108. NamesToPass names_to_pass;
  109. names_to_pass.emplace_back("Test", &addn_pass);
  110. EXPECT_NE(pass.Run(names_to_pass), SUCCESS);
  111. EXPECT_FALSE(add_n_node->GetInDataNodes().empty());
  112. EXPECT_TRUE(add_n_node->GetOutDataNodes().empty());
  113. EXPECT_FALSE(node->GetOutDataNodes().empty());
  114. }
  115. /// |
  116. /// AddN
  117. /// |
  118. /// Op
  119. TEST(UtestGraphPassesAddnPass, no_input) {
  120. ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  121. GeTensorDescPtr general_ge_tensor_desc = std::make_shared<GeTensorDesc>();
  122. NodePtr add_n_node = NodeBuilder("add_n_node", ADDN).AddOutputDesc({1, 1, 224, 224}).Build(graph);
  123. NodePtr node = NodeBuilder("node2", RELU).AddInputDesc({1, 1, 224, 224}).AddOutputDesc({1, 1, 224, 224}).Build(graph);
  124. GraphUtils::AddEdge(add_n_node->GetOutDataAnchor(0), node->GetInDataAnchor(0));
  125. GEPass pass(graph);
  126. AddNPass addn_pass;
  127. NamesToPass names_to_pass;
  128. names_to_pass.emplace_back("Test", &addn_pass);
  129. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  130. EXPECT_EQ(graph->GetDirectNodesSize(), 2);
  131. EXPECT_TRUE(add_n_node->GetInDataNodes().empty());
  132. EXPECT_EQ(node->GetInDataNodes().at(0)->GetName(), add_n_node->GetName());
  133. }
  134. /// Op1
  135. /// |
  136. /// AddN
  137. /// |
  138. /// Op2
  139. TEST(UtestGraphPassesAddnPass, single_input_remove_addn_success) {
  140. ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  141. GeTensorDescPtr general_ge_tensor_desc = std::make_shared<GeTensorDesc>();
  142. NodePtr node1 =
  143. NodeBuilder("node1", CONSTANTOP).AddInputDesc({1, 1, 224, 224}).AddOutputDesc({1, 1, 224, 224}).Build(graph);
  144. NodePtr add_n_node =
  145. NodeBuilder("add_n_node", ADDN).AddInputDesc({1, 1, 224, 224}).AddOutputDesc({1, 1, 224, 224}).Build(graph);
  146. NodePtr node2 =
  147. NodeBuilder("node2", RELU).AddInputDesc({1, 1, 224, 224}).AddOutputDesc({1, 1, 224, 224}).Build(graph);
  148. GraphUtils::AddEdge(node1->GetOutDataAnchor(0), add_n_node->GetInDataAnchor(0));
  149. GraphUtils::AddEdge(add_n_node->GetOutDataAnchor(0), node2->GetInDataAnchor(0));
  150. EXPECT_EQ(graph->GetDirectNodesSize(), 3);
  151. GEPass pass(graph);
  152. AddNPass addn_pass;
  153. NamesToPass names_to_pass;
  154. names_to_pass.emplace_back("Test", &addn_pass);
  155. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  156. EXPECT_EQ(node1->GetOutDataNodes().at(0)->GetName(), node2->GetName());
  157. EXPECT_EQ(node2->GetInDataNodes().at(0)->GetName(), node1->GetName());
  158. EXPECT_TRUE(add_n_node->GetOutDataNodes().empty());
  159. EXPECT_TRUE(add_n_node->GetInDataNodes().empty());
  160. }
  161. /// Op1 Op2
  162. /// \ /
  163. /// AddN
  164. /// |
  165. /// Op3
  166. TEST(UtestGraphPassesAddnPass, multiple_inputs_do_not_remove) {
  167. ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  168. GeTensorDescPtr general_ge_tensor_desc = std::make_shared<GeTensorDesc>();
  169. NodePtr node1 =
  170. NodeBuilder("node1", CONSTANTOP).AddInputDesc({1, 1, 224, 224}).AddOutputDesc({1, 1, 224, 224}).Build(graph);
  171. NodePtr node2 =
  172. NodeBuilder("node2", CONSTANTOP).AddInputDesc({1, 1, 224, 224}).AddOutputDesc({1, 1, 224, 224}).Build(graph);
  173. NodePtr add_n_node = NodeBuilder("add_n_node", ADDN)
  174. .AddInputDesc({1, 1, 224, 224})
  175. .AddInputDesc({1, 1, 224, 224})
  176. .AddOutputDesc({1, 1, 224, 224})
  177. .Build(graph);
  178. NodePtr node3 =
  179. NodeBuilder("node3", RELU).AddInputDesc({1, 1, 224, 224}).AddOutputDesc({1, 1, 224, 224}).Build(graph);
  180. GraphUtils::AddEdge(node1->GetOutDataAnchor(0), add_n_node->GetInDataAnchor(0));
  181. GraphUtils::AddEdge(node2->GetOutDataAnchor(0), add_n_node->GetInDataAnchor(1));
  182. GraphUtils::AddEdge(add_n_node->GetOutDataAnchor(0), node3->GetInDataAnchor(0));
  183. EXPECT_EQ(graph->GetDirectNodesSize(), 4);
  184. GEPass pass(graph);
  185. AddNPass addn_pass;
  186. NamesToPass names_to_pass;
  187. names_to_pass.emplace_back("Test", &addn_pass);
  188. EXPECT_EQ(pass.Run(names_to_pass), SUCCESS);
  189. EXPECT_EQ(graph->GetDirectNodesSize(), 4);
  190. }
  191. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户对此并无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。