
guarantee_const_pass_unittest.cc

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <gtest/gtest.h>

#include <unordered_map>
#include <vector>

#define protected public
#define private public
#include "graph/passes/guarantee_const_pass.h"

#include "../ops_stub.h"
#include "common/ge_inner_error_codes.h"
#include "common/types.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
#include "inc/pass_manager.h"
#undef protected
#undef private

using namespace testing;
using namespace ge;
using namespace std;

// To check whether the shape of output is correct or not
#define TEST_OPERATOR(op_, input_shapes, output_shapes)                                                  \
  {                                                                                                      \
    auto op = op_;                                                                                       \
    for (auto input_pair : input_shapes) SetInputShape(op, input_pair.first, input_pair.second);         \
    op.InferShapeAndType();                                                                              \
    for (auto output_pair : output_shapes) CheckOutputShape(op, output_pair.first, output_pair.second);  \
  }

#define LOOP_VEC(v) for (unsigned i = 0; i < v.size(); i++)

class UtestGraphPassesGuaranteeConstPass : public testing::Test {
 protected:
  void SetUp() { init(); }

  void TearDown() { destory(); }

 private:
  void init() { guarantee_const_op_remove_pass_ = new ::ge::GuaranteeConstPass(); }

  void destory() {
    delete guarantee_const_op_remove_pass_;
    guarantee_const_op_remove_pass_ = NULL;
  }

 protected:
  ge::GuaranteeConstPass *guarantee_const_op_remove_pass_;

  void SetInputShape(Operator op, string name, vector<int64_t> shape) {
    TensorDesc td = op.GetInputDesc(name);
    td.SetShape(ge::Shape(shape));
    op.UpdateInputDesc(name, td);
  }

  void CheckOutputShape(Operator op, string name, vector<int64_t> shape) {
    ge::Shape s = op.GetOutputDesc(name).GetShape();
    EXPECT_EQ(s.GetDims().size(), shape.size());
    LOOP_VEC(shape) EXPECT_EQ(s.GetDim(i), shape[i]);
  }

  /// Initializes the node that will be run through the pass; isMultiInput indicates whether the node
  /// uses more than one input data anchor.
  NodePtr init_node(ComputeGraphPtr graph, vector<int64_t> dims_vec, vector<int32_t> data_vec, bool isMultiInput,
                    string type) {
    // middle
    OpDescPtr op_def = std::make_shared<OpDesc>("op_def", type);
    OpDescPtr in_op_def = std::make_shared<OpDesc>("op_def_in", "test");
    OpDescPtr out_op_def = std::make_shared<OpDesc>("op_def_out", "test");
    OpDescPtr another_in_op_def = std::make_shared<OpDesc>("another_op_def_in", "test");

    // whether using another input data anchor or not
    if (isMultiInput) {
      vector<bool> is_input_const_vec = {true, true};
      op_def->SetIsInputConst(is_input_const_vec);
      AttrUtils::SetInt(op_def, ge::ATTR_NAME_T, (int64_t)DT_INT32);
    }

    // input tensor
    GeTensorDesc tensor_desc(GeShape(dims_vec), FORMAT_NCHW, DT_INT32);
    ge::ConstGeTensorPtr const_tensor =
        std::make_shared<GeTensor>(tensor_desc, (uint8_t *)&data_vec[0], data_vec.size() * sizeof(int32_t));
    ge::AttrUtils::SetTensor(in_op_def, ge::ATTR_NAME_WEIGHTS, const_tensor);
    op_def->AddInputDesc(tensor_desc);

    // whether using another input data anchor or not
    if (isMultiInput) {
      vector<int64_t> dims_vec_another = {6};
      vector<int32_t> data_vec_another = {1, 2, 3, 4, 5, 6};
      GeTensorDesc another_tensor_desc(GeShape(dims_vec_another), FORMAT_NCHW, DT_INT32);
      ge::ConstGeTensorPtr const_tensor_another = std::make_shared<GeTensor>(
          another_tensor_desc, (uint8_t *)&data_vec_another[0], data_vec_another.size() * sizeof(int32_t));
      ge::AttrUtils::SetTensor(another_in_op_def, ge::ATTR_NAME_WEIGHTS, const_tensor_another);
      op_def->AddInputDesc(another_tensor_desc);
      another_in_op_def->AddOutputDesc(another_tensor_desc);
      out_op_def->AddInputDesc(another_tensor_desc);
    }

    GeTensorDesc tensor_desc_out(GeShape(dims_vec), FORMAT_NCHW, DT_INT32);
    op_def->AddOutputDesc(tensor_desc_out);
    in_op_def->AddOutputDesc(tensor_desc);

    // add attr of out_node
    vector<bool> is_output_const(3, false);
    is_output_const[0] = true;
    out_op_def->SetIsInputConst(is_output_const);
    out_op_def->AddInputDesc(tensor_desc);

    // Add node
    NodePtr in_node = graph->AddNode(in_op_def);
    NodePtr node = graph->AddNode(op_def);
    NodePtr out_node = graph->AddNode(out_op_def);

    // Add edge
    GraphUtils::AddEdge(in_node->GetOutDataAnchor(0), node->GetInDataAnchor(0));
    GraphUtils::AddEdge(node->GetOutDataAnchor(0), out_node->GetInDataAnchor(0));

    // when multiple input nodes are needed (used to verify the isolate-node handling)
    if (isMultiInput) {
      NodePtr another_in_node = graph->AddNode(another_in_op_def);
      GraphUtils::AddEdge(another_in_node->GetOutDataAnchor(0), node->GetInDataAnchor(1));
    }

    return node;
  }
};

TEST_F(UtestGraphPassesGuaranteeConstPass, not_changed) {
  // the original type of op is not guarantee_const
  string type = SIZE;

  // input tensor
  vector<int64_t> dims_vec = {6};
  vector<int32_t> data_vec = {1, 2, 3, 4, 5, 6};

  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  NodePtr node = init_node(graph, dims_vec, data_vec, false, type);
  ge::Status ret = guarantee_const_op_remove_pass_->Run(node);
  EXPECT_EQ(SUCCESS, ret);
}

TEST_F(UtestGraphPassesGuaranteeConstPass, get_origenal_type_fail) {
  string type = GUARANTEECONST;

  // input tensor
  vector<int64_t> dims_vec = {6};
  vector<int32_t> data_vec = {1, 2, 3, 4, 5, 6};

  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  NodePtr node = init_node(graph, dims_vec, data_vec, false, type);

  // change the type
  string type2 = "FrameworkOp";
  node->GetOpDesc()->SetType(type2);
  ge::Status ret = guarantee_const_op_remove_pass_->Run(node);
}

TEST_F(UtestGraphPassesGuaranteeConstPass, int32_success_6) {
  // input tensor
  string type = GUARANTEECONST;
  vector<int64_t> dims_vec = {6};
  vector<int32_t> data_vec = {1, 2, 3, 4, 5, 6};

  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  NodePtr node = init_node(graph, dims_vec, data_vec, false, type);

  // when input tensor is [1, 2, 3, 4, 5, 6], return success
  ge::Status output = guarantee_const_op_remove_pass_->Run(node);
  EXPECT_EQ(ge::SUCCESS, output);
}

TEST_F(UtestGraphPassesGuaranteeConstPass, int32_success_2_3) {
  // input tensor
  string type = GUARANTEECONST;
  vector<int64_t> dims_vec = {2, 3};
  vector<int32_t> data_vec = {1, 2, 3, 4, 5, 6};

  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  NodePtr node = init_node(graph, dims_vec, data_vec, false, type);

  // when input tensor is [[1, 2, 3], [4, 5, 6]], return success
  ge::Status output = guarantee_const_op_remove_pass_->Run(node);
  EXPECT_EQ(ge::SUCCESS, output);
}

TEST_F(UtestGraphPassesGuaranteeConstPass, isolate_node_failed) {
  // input tensor
  string type = GUARANTEECONST;
  vector<int64_t> dims_vec = {2, 3};
  vector<int32_t> data_vec = {1, 2, 3, 4, 5, 6};

  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");

  // add another input node
  NodePtr node = init_node(graph, dims_vec, data_vec, true, type);

  // when there is more than one input anchor, the pass returns PARAM_INVALID
  ge::Status output = guarantee_const_op_remove_pass_->Run(node);
  EXPECT_EQ(ge::PARAM_INVALID, output);
}

// IR test: the shape and data type of input should be equal to the shape and data type of output
TEST_F(UtestGraphPassesGuaranteeConstPass, ir_infer_shape) {
  auto input = unordered_map<string, vector<int64_t>>({
      {"x", {3, 5, 3, 4}},
  });
  auto output = unordered_map<string, vector<int64_t>>({
      {"y", {3, 5, 3, 4}},
  });

  auto guaranteeConst = op::GuaranteeConst("guaranteeconst");
  TEST_OPERATOR(guaranteeConst, input, output);
}

The Graph Engine (GE) module is a submodule of MindSpore. It is implemented in C++ and sits between the frontend module ME and the underlying hardware, acting as the bridge between the two. GE takes the graph delivered by ME as input, applies a series of deep graph optimizations, and outputs a graph that can run efficiently on the underlying hardware. GE performs optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its computing power. During model training and inference, GE is invoked automatically and is transparent to the user. GE mainly consists of two parts, GE API and GE Core; the detailed architecture is shown in the diagram below.
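
The unit test above exercises one of these graph optimization passes in isolation. As a minimal sketch of the same flow, the fragment below builds a tiny ComputeGraph, wires a GuaranteeConst node between a producer and a consumer, and runs GuaranteeConstPass on that node. It uses only calls that already appear in the test file; the function name, graph name, and node names ("demo_*") are illustrative, and it assumes the same headers and stubs are available. Note that the test makes protected and private members public via the #define trick before including the pass header, and this sketch assumes Run() is reachable in the same way.

// Minimal sketch, assuming the same includes and stubs as the unit test above.
ge::Status RunGuaranteeConstPassOnce() {
  // Build producer -> GuaranteeConst -> consumer, mirroring init_node() in the fixture.
  ge::ComputeGraphPtr graph = std::make_shared<ge::ComputeGraph>("demo_graph");  // illustrative name

  ge::GeTensorDesc tensor_desc(ge::GeShape({6}), ge::FORMAT_NCHW, ge::DT_INT32);

  ge::OpDescPtr in_op = std::make_shared<ge::OpDesc>("demo_in", "test");
  in_op->AddOutputDesc(tensor_desc);
  ge::OpDescPtr gc_op = std::make_shared<ge::OpDesc>("demo_gc", GUARANTEECONST);
  gc_op->AddInputDesc(tensor_desc);
  gc_op->AddOutputDesc(tensor_desc);
  ge::OpDescPtr out_op = std::make_shared<ge::OpDesc>("demo_out", "test");
  out_op->AddInputDesc(tensor_desc);

  ge::NodePtr in_node = graph->AddNode(in_op);
  ge::NodePtr gc_node = graph->AddNode(gc_op);
  ge::NodePtr out_node = graph->AddNode(out_op);
  ge::GraphUtils::AddEdge(in_node->GetOutDataAnchor(0), gc_node->GetInDataAnchor(0));
  ge::GraphUtils::AddEdge(gc_node->GetOutDataAnchor(0), out_node->GetInDataAnchor(0));

  // Run the pass on the GuaranteeConst node; the tests above expect SUCCESS for a
  // single-input node and PARAM_INVALID when the node has more than one data anchor.
  ge::GuaranteeConstPass pass;
  return pass.Run(gc_node);
}

In production GE code, node passes of this kind are typically driven through a pass manager (the test includes inc/pass_manager.h), but the fixture calls Run() directly on a single node, and the sketch does the same.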