You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

variable_prepare_pass_unittest.cc 8.9 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/variable_prepare_op_pass.h"
  17. #include <gtest/gtest.h>
  18. #include <string>
  19. using namespace ge;
  20. class UtestGraphPassesVariablePreparePass : public testing::Test {
  21. protected:
  22. void SetUp() {}
  23. void TearDown() {}
  24. };
  25. class NodeBuilder {
  26. public:
  27. NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared<OpDesc>(name, type); }
  28. NodeBuilder &AddInputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
  29. ge::DataType data_type = DT_FLOAT) {
  30. op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
  31. return *this;
  32. }
  33. NodeBuilder &AddOutputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
  34. ge::DataType data_type = DT_FLOAT) {
  35. op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
  36. return *this;
  37. }
  38. ge::NodePtr Build(const ge::ComputeGraphPtr &graph) { return graph->AddNode(op_desc_); }
  39. private:
  40. ge::GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
  41. ge::DataType data_type = DT_FLOAT) {
  42. GeShape ge_shape{std::vector<int64_t>(shape)};
  43. ge::GeTensorDescPtr tensor_desc = std::make_shared<ge::GeTensorDesc>();
  44. tensor_desc->SetShape(ge_shape);
  45. tensor_desc->SetFormat(format);
  46. tensor_desc->SetDataType(data_type);
  47. return tensor_desc;
  48. }
  49. ge::OpDescPtr op_desc_;
  50. };
  51. /// variable -- const
  52. /// \ /
  53. /// \ /
  54. /// assign
  55. TEST_F(UtestGraphPassesVariablePreparePass, variable_prepare_pass_succ1) {
  56. ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  57. ge::NodePtr variable_node = NodeBuilder("variable", VARIABLE)
  58. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  59. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  60. .Build(graph);
  61. ge::NodePtr const_node = NodeBuilder("const", CONSTANT)
  62. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  63. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  64. .Build(graph);
  65. ge::NodePtr apply_assign_node = NodeBuilder("assign", ASSIGN)
  66. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  67. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  68. .Build(graph);
  69. ge::GraphUtils::AddEdge(variable_node->GetOutDataAnchor(0), apply_assign_node->GetInDataAnchor(0));
  70. ge::GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), apply_assign_node->GetInDataAnchor(1));
  71. ge::VariablePrepareOpPass pass_;
  72. ge::Status status = pass_.Run(graph);
  73. EXPECT_EQ(apply_assign_node->GetOutDataNodes().size(), 0);
  74. EXPECT_EQ(SUCCESS, status);
  75. }
  76. /// variable -- applyMoment
  77. TEST_F(UtestGraphPassesVariablePreparePass, variable_prepare_pass_succ2) {
  78. ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  79. ge::NodePtr variable_node = NodeBuilder("variable", VARIABLE)
  80. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  81. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  82. .Build(graph);
  83. ge::NodePtr apply_monetum_node = NodeBuilder("apply_monetum", APPLYMOMENTUM)
  84. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  85. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  86. .Build(graph);
  87. ge::NodePtr sinh_node = NodeBuilder("sinh", SINH)
  88. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  89. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  90. .Build(graph);
  91. ge::GraphUtils::AddEdge(variable_node->GetOutDataAnchor(0), apply_monetum_node->GetInDataAnchor(0));
  92. ge::GraphUtils::AddEdge(apply_monetum_node->GetOutControlAnchor(), sinh_node->GetInControlAnchor());
  93. ge::VariablePrepareOpPass pass_;
  94. ge::Status status = pass_.Run(graph);
  95. EXPECT_EQ(apply_monetum_node->GetOutDataNodes().size(), 0);
  96. EXPECT_EQ(SUCCESS, status);
  97. }
  98. /// variable -- const1
  99. /// \ /
  100. /// \ /
  101. /// assign_add1 -- const2
  102. /// \ /
  103. /// \ /
  104. /// assign_sub -- const3
  105. /// \ /
  106. /// \ /
  107. /// assign_add2 -- const4
  108. /// \ /
  109. /// \ /
  110. /// assign_add3
  111. TEST_F(UtestGraphPassesVariablePreparePass, variable_prepare_pass_succ3) {
  112. ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  113. ge::NodePtr variable_node = NodeBuilder("variable", VARIABLE)
  114. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  115. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  116. .Build(graph);
  117. ge::NodePtr const_node1 = NodeBuilder("const1", CONSTANT)
  118. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  119. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  120. .Build(graph);
  121. ge::NodePtr const_node2 = NodeBuilder("const2", CONSTANT)
  122. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  123. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  124. .Build(graph);
  125. ge::NodePtr const_node3 = NodeBuilder("const3", CONSTANT)
  126. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  127. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  128. .Build(graph);
  129. ge::NodePtr const_node4 = NodeBuilder("const4", CONSTANT)
  130. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  131. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  132. .Build(graph);
  133. ge::NodePtr assign_add1 = NodeBuilder("assign_add1", ASSIGNADD)
  134. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  135. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  136. .Build(graph);
  137. ge::NodePtr assign_sub = NodeBuilder("assign_sub", ASSIGNSUB)
  138. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  139. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  140. .Build(graph);
  141. ge::NodePtr assign_add2 = NodeBuilder("assign_add2", ASSIGNADD)
  142. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  143. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  144. .Build(graph);
  145. ge::NodePtr assign_add3 = NodeBuilder("assign_add3", ASSIGNADD)
  146. .AddInputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  147. .AddOutputDesc({2, 16, 2, 2}, FORMAT_NHWC, DT_FLOAT)
  148. .Build(graph);
  149. ge::GraphUtils::AddEdge(variable_node->GetOutDataAnchor(0), assign_add1->GetInDataAnchor(0));
  150. ge::GraphUtils::AddEdge(const_node1->GetOutDataAnchor(0), assign_add1->GetInDataAnchor(1));
  151. ge::GraphUtils::AddEdge(assign_add1->GetOutDataAnchor(0), assign_sub->GetInDataAnchor(0));
  152. ge::GraphUtils::AddEdge(const_node2->GetOutDataAnchor(0), assign_sub->GetInDataAnchor(1));
  153. ge::GraphUtils::AddEdge(assign_sub->GetOutDataAnchor(0), assign_add2->GetInDataAnchor(0));
  154. ge::GraphUtils::AddEdge(const_node3->GetOutDataAnchor(0), assign_add2->GetInDataAnchor(1));
  155. ge::GraphUtils::AddEdge(assign_add2->GetOutDataAnchor(0), assign_add3->GetInDataAnchor(0));
  156. ge::GraphUtils::AddEdge(const_node4->GetOutDataAnchor(0), assign_add3->GetInDataAnchor(1));
  157. ge::VariablePrepareOpPass pass_;
  158. ge::Status status = pass_.Run(graph);
  159. EXPECT_EQ(assign_add3->GetOutDataNodes().size(), 0);
  160. EXPECT_EQ(SUCCESS, status);
  161. }

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：