You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

pass_utils_unittest.cc 7.3 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/pass_utils.h"
  17. #include <gtest/gtest.h>
  18. #include <vector>
  19. #include "common/types.h"
  20. #include "graph/types.h"
  21. #include "graph/utils/graph_utils.h"
  22. #include "graph/utils/op_desc_utils.h"
  23. #include "graph_builder_utils.h"
  24. #include "inc/kernel.h"
  25. #include "inc/kernel_factory.h"
  26. using namespace ge;
  27. class UtestGraphPassesPassUtils : public testing::Test {
  28. protected:
  29. void SetUp() {}
  30. void TearDown() {}
  31. };
  32. class NodeBuilder {
  33. public:
  34. NodeBuilder(const std::string &name, const std::string &type) { op_desc_ = std::make_shared<OpDesc>(name, type); }
  35. NodeBuilder &AddInputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
  36. ge::DataType data_type = DT_FLOAT) {
  37. op_desc_->AddInputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
  38. return *this;
  39. }
  40. NodeBuilder &AddOutputDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
  41. ge::DataType data_type = DT_FLOAT) {
  42. op_desc_->AddOutputDesc(CreateTensorDesc(shape, format, data_type)->Clone());
  43. return *this;
  44. }
  45. ge::NodePtr Build(const ge::ComputeGraphPtr &graph) { return graph->AddNode(op_desc_); }
  46. private:
  47. ge::GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, ge::Format format = FORMAT_NCHW,
  48. ge::DataType data_type = DT_FLOAT) {
  49. GeShape ge_shape{std::vector<int64_t>(shape)};
  50. ge::GeTensorDescPtr tensor_desc = std::make_shared<ge::GeTensorDesc>();
  51. tensor_desc->SetShape(ge_shape);
  52. tensor_desc->SetFormat(format);
  53. tensor_desc->SetDataType(data_type);
  54. return tensor_desc;
  55. }
  56. ge::OpDescPtr op_desc_;
  57. };
  58. TEST_F(UtestGraphPassesPassUtils, set_out_node_weight) {
  59. ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  60. // data
  61. ge::NodePtr node_data = NodeBuilder("data", DATA).AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT).Build(graph);
  62. // const
  63. ge::NodePtr node_const =
  64. NodeBuilder("const", CONSTANT).AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT).Build(graph);
  65. // relu
  66. ge::NodePtr node_relu = NodeBuilder("node_relu1", RELU)
  67. .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
  68. .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
  69. .Build(graph);
  70. // sinh
  71. ge::NodePtr node_sinh = NodeBuilder("node_sinh", SINH)
  72. .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
  73. .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
  74. .Build(graph);
  75. // relu
  76. ge::NodePtr node_relu2 = NodeBuilder("node_relu2", RELU)
  77. .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
  78. .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
  79. .Build(graph);
  80. // sinh
  81. ge::NodePtr node_sinh2 = NodeBuilder("node_sinh2", SINH)
  82. .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
  83. .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
  84. .Build(graph);
  85. // add edge
  86. ge::GraphUtils::AddEdge(node_data->GetOutControlAnchor(), node_const->GetInControlAnchor());
  87. ge::GraphUtils::AddEdge(node_const->GetOutDataAnchor(0), node_relu->GetInDataAnchor(0));
  88. ge::GraphUtils::AddEdge(node_relu->GetOutDataAnchor(0), node_sinh->GetInDataAnchor(0));
  89. ge::GraphUtils::AddEdge(node_relu->GetOutDataAnchor(0), node_relu2->GetInControlAnchor());
  90. ge::GraphUtils::AddEdge(node_relu2->GetOutDataAnchor(0), node_sinh2->GetInDataAnchor(0));
  91. for (auto node : graph->GetDirectNode()) {
  92. if (node->GetType() == CONSTANT) {
  93. int32_t weight[] = {1};
  94. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32);
  95. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  96. vector<GeTensorPtr> tensor_vec = {tensor};
  97. OpDescUtils::SetWeights(node, tensor_vec);
  98. }
  99. if (!node->GetOutDataNodes().empty()) {
  100. auto out_data_anchor = node->GetOutDataNodes().at(0)->GetOutDataAnchor(0);
  101. Status status = PassUtils::SetOutNodeWeight(out_data_anchor, node);
  102. EXPECT_EQ(SUCCESS, status);
  103. }
  104. }
  105. }
  // Only some failure cases, for coverage checks.
  107. TEST_F(UtestGraphPassesPassUtils, is_constant_null) {
  108. ge::NodePtr node = nullptr;
  109. bool ret = PassUtils::IsConstant(node);
  110. EXPECT_EQ(false, ret);
  111. }
  112. TEST_F(UtestGraphPassesPassUtils, get_in_data_node_fail) {
  113. ge::NodePtr node = nullptr;
  114. NodePtr in_data_node = PassUtils::GetInDataNode(node, 0);
  115. EXPECT_EQ(nullptr, in_data_node);
  116. ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  117. // relu
  118. ge::NodePtr node_relu = NodeBuilder("relu", RELU)
  119. .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
  120. .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
  121. .Build(graph);
  122. NodePtr data_node = PassUtils::GetInDataNode(node_relu, 1);
  123. EXPECT_EQ(nullptr, data_node);
  124. }
  125. TEST_F(UtestGraphPassesPassUtils, get_unique_in_data_anchor_index_failed) {
  126. int invalid_index = -1;
  127. ge::NodePtr node = nullptr;
  128. int status = PassUtils::GetUniqueInDataAnchorIndex(node);
  129. EXPECT_EQ(invalid_index, status);
  130. ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  131. // relu
  132. ge::NodePtr node_relu = NodeBuilder("relu", RELU)
  133. .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
  134. .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
  135. .Build(graph);
  136. int ret = PassUtils::GetUniqueInDataAnchorIndex(node_relu);
  137. EXPECT_EQ(invalid_index, ret);
  138. }
  139. TEST_F(UtestGraphPassesPassUtils, unlink_node_with_ctrl_copy_fail) {
  140. ge::ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  141. // relu
  142. ge::NodePtr node_relu = NodeBuilder("relu", RELU)
  143. .AddInputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
  144. .AddOutputDesc({2, 2, 2, 2}, FORMAT_NCHW, DT_FLOAT)
  145. .Build(graph);
  146. Status status = PassUtils::UnlinkNodeWithControlCopy(node_relu, 1);
  147. EXPECT_EQ(ge::SUCCESS, status);
  148. Status ret = PassUtils::UnlinkNodeWithControlCopy(node_relu, 0);
  149. EXPECT_EQ(ge::FAILED, ret);
  150. }
  151. TEST_F(UtestGraphPassesPassUtils, null_input) {
  152. std::vector<NodePtr> deleted_nodes;
  153. std::vector<NodePtr> end_nodes;
  154. EXPECT_NE(PassUtils::RemoveInactiveBranchToMerge(nullptr, deleted_nodes, end_nodes), 0);
  155. }

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户对此并无感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：