
stop_gradient_pass_unittest.cc
/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include <gtest/gtest.h>
#include <unordered_map>

#define protected public
#define private public
#include "graph/passes/stop_gradient_pass.h"

#include "common/debug/log.h"
#include "common/debug/memory_dumper.h"
#include "common/types.h"
#include "external/graph/operator_reg.h"
#include "framework/common/ge_inner_error_codes.h"
#include "graph/debug/ge_attr_define.h"
#include "graph/operator.h"
#include "graph/utils/attr_utils.h"
#include "graph/utils/graph_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
#include "inc/kernel_factory.h"
#undef protected
#undef private

using namespace testing;
using namespace ge;

// for ir
REG_OP(StopGradient)
    .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_UINT32,
                          DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, DT_UINT8, DT_INT32, DT_INT64, DT_UINT32,
                           DT_UINT64, DT_BOOL, DT_DOUBLE}))
    .OP_END_FACTORY_REG(StopGradient)

IMPLEMT_INFERFUNC(StopGradient, StopGradientInfer) {
  TensorDesc input_desc = op.GetInputDesc("x");
  (void)op.UpdateOutputDesc("y", input_desc);
  return GRAPH_SUCCESS;
}
INFER_FUNC_REG(StopGradient, StopGradientInfer);

// Runs InferShapeAndType on op_ after setting the given input shapes, then checks the output shapes.
#define TEST_OPERATOR(op_, input_shapes, output_shapes)                                                  \
  {                                                                                                      \
    auto op = op_;                                                                                       \
    for (auto input_pair : input_shapes) SetInputShape(op, input_pair.first, input_pair.second);         \
    op.InferShapeAndType();                                                                              \
    for (auto output_pair : output_shapes) CheckOutputShape(op, output_pair.first, output_pair.second);  \
  }

#define LOOP_VEC(v) for (unsigned i = 0; i < v.size(); i++)

class UtestGraphPassesStopGradientPass : public testing::Test {
 protected:
  void SetUp() { init(); }

  void TearDown() { destroy(); }

 private:
  void init() {
    pass_ = new ::ge::StopGradientPass();
    graph_ = std::make_shared<ge::ComputeGraph>("default");
    op_desc_ptr_ = std::make_shared<OpDesc>("stop_gradient", STOPGRADIENT);
    node_ = std::make_shared<Node>(op_desc_ptr_, graph_);
    kernel_ = KernelFactory::Instance().Create(STOPGRADIENT);
  }

  void destroy() {
    delete pass_;
    pass_ = nullptr;
  }

 protected:
  ge::StopGradientPass *pass_;
  ge::ComputeGraphPtr graph_;
  OpDescPtr op_desc_ptr_;
  NodePtr node_;
  shared_ptr<Kernel> kernel_;

  void SetInputShape(Operator op, string name, vector<int64_t> shape) {
    TensorDesc td = op.GetInputDesc(name);
    td.SetShape(ge::Shape(shape));
    op.UpdateInputDesc(name, td);
  }

  void CheckOutputShape(Operator op, string name, vector<int64_t> shape) {
    ge::Shape s = op.GetOutputDesc(name).GetShape();
    EXPECT_EQ(s.GetDims().size(), shape.size());
    LOOP_VEC(shape) EXPECT_EQ(s.GetDim(i), shape[i]);
  }

  // Builds a three-node chain (Constant -> <type> -> test) and returns the middle node.
  NodePtr init_node(ComputeGraphPtr graph, string &type) {
    // middle
    OpDescPtr op_def = std::make_shared<OpDesc>("op_def", type);
    OpDescPtr in_op_def_0 = std::make_shared<OpDesc>("op_def_in", "test");
    OpDescPtr out_op_def = std::make_shared<OpDesc>("op_def_out", "test");

    // in_op_def_0
    vector<int64_t> dims_vec_0 = {2, 1, 4, 1, 2};
    vector<int32_t> data_vec_0 = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
    GeTensorDesc tensor_desc_0(GeShape(dims_vec_0), FORMAT_NCHW, DT_INT32);
    (void)TensorUtils::SetRealDimCnt(tensor_desc_0, dims_vec_0.size());
    ge::ConstGeTensorPtr constTensor_0 =
        std::make_shared<GeTensor>(tensor_desc_0, (uint8_t *)&data_vec_0[0], data_vec_0.size() * sizeof(int32_t));
    ge::AttrUtils::SetTensor(in_op_def_0, ge::ATTR_NAME_WEIGHTS, constTensor_0);
    vector<int64_t> dims = {2, 2, 4, 3, 2};
    ge::GeShape shape_desc(dims);
    GeTensorDesc tensor_desc(shape_desc);
    in_op_def_0->AddOutputDesc(tensor_desc);
    in_op_def_0->SetType("Constant");

    // op_def
    GeTensorDesc tensor_desc_out(GeShape(), FORMAT_NCHW, DT_INT32);
    op_def->AddInputDesc(tensor_desc_0);
    op_def->AddOutputDesc(tensor_desc_out);
    vector<bool> is_input_const_vec = {
        true,
    };
    op_def->SetIsInputConst(is_input_const_vec);
    AttrUtils::SetInt(op_def, ge::ATTR_NAME_T, (int64_t)DT_INT32);

    // add attr of out_node
    vector<bool> is_input_const(1);
    is_input_const[0] = true;
    out_op_def->SetIsInputConst(is_input_const);
    out_op_def->AddInputDesc(tensor_desc_0);

    // Add node
    NodePtr in_node_0 = graph->AddNode(in_op_def_0);
    NodePtr node = graph->AddNode(op_def);
    NodePtr out_node = graph->AddNode(out_op_def);

    // Add edge
    GraphUtils::AddEdge(in_node_0->GetOutDataAnchor(0), node->GetInDataAnchor(0));
    GraphUtils::AddEdge(node->GetOutDataAnchor(0), out_node->GetInDataAnchor(0));

    return node;
  }
};

TEST_F(UtestGraphPassesStopGradientPass, success) {
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  string type = STOPGRADIENT;
  NodePtr node = init_node(graph, type);
  ge::Status ret = pass_->Run(node);
  EXPECT_EQ(ge::SUCCESS, ret);
}

TEST_F(UtestGraphPassesStopGradientPass, not_changed) {
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  string type = SIZE;
  NodePtr node = init_node(graph, type);
  ge::Status ret = pass_->Run(node);
  EXPECT_EQ(ge::SUCCESS, ret);
}

TEST_F(UtestGraphPassesStopGradientPass, get_original_type_fail) {
  ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test");
  string type = STOPGRADIENT;
  NodePtr node = init_node(graph, type);
  string type2 = "FrameworkOp";
  node->GetOpDesc()->SetType(type2);
  ge::Status ret = pass_->Run(node);
}

TEST_F(UtestGraphPassesStopGradientPass, size_check_fail) {
  vector<int64_t> dims_vec_0 = {8, 2};
  GeTensorDesc tensor_desc_0(GeShape(dims_vec_0), FORMAT_NCHW, DT_FLOAT);
  op_desc_ptr_->AddInputDesc(tensor_desc_0);
  vector<int64_t> dims_vec_1 = {3, 4, 5};
  GeTensorDesc tensor_desc_1(GeShape(dims_vec_1), FORMAT_NCHW, DT_FLOAT);
  op_desc_ptr_->AddInputDesc(tensor_desc_1);
  GeTensorDesc tensor_desc_out(GeShape(), FORMAT_NCHW, DT_INT64);
  op_desc_ptr_->AddOutputDesc(tensor_desc_out);
  ge::Status ret = pass_->Run(node_);
  EXPECT_EQ(ge::FAILED, ret);
}

TEST_F(UtestGraphPassesStopGradientPass, ir_infer_shape) {
  auto i = std::unordered_map<string, vector<int64_t>>({
      {"x", {2, 1, 5, 3}},
  });
  auto o = std::unordered_map<string, vector<int64_t>>({
      {"y", {2, 1, 5, 3}},
  });
  auto test_op = op::StopGradient("test_op");
  TEST_OPERATOR(test_op, i, o);
}

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module (ME) and the underlying hardware and bridges the two. GE takes the graph delivered by ME as input, applies a series of deep graph optimizations, and finally outputs a graph that can run efficiently on the target hardware. GE performs optimizations tailored to the hardware architecture of the Ascend AI processor so as to fully exploit its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core.
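To make the flow concrete, here is a minimal sketch of how a graph built by the front end reaches GE Core, where passes such as the StopGradientPass tested above are applied. The header path and the GEInitialize / Session::AddGraph / Session::RunGraph / GEFinalize calls are assumptions about the public GE client API and are not taken from the test file; in normal MindSpore use the framework issues these calls internally, so users never write this code themselves.

// Sketch of the assumed GE client API flow; names and signatures are assumptions,
// not part of the unit test above.
#include <map>
#include <string>
#include <vector>
#include "ge/ge_api.h"    // assumed client header: GEInitialize, Session, GEFinalize
#include "graph/graph.h"  // ge::Graph, as produced by the front end (ME)

int RunOnce(const ge::Graph &graph, const std::vector<ge::Tensor> &inputs) {
  std::map<std::string, std::string> options;  // device/build options, left empty here
  if (ge::GEInitialize(options) != ge::SUCCESS) {
    return -1;
  }
  {
    ge::Session session(options);
    const uint32_t graph_id = 1;
    // AddGraph hands the front-end graph to GE Core, which runs its
    // optimization passes before compiling the graph for the device.
    if (session.AddGraph(graph_id, graph) != ge::SUCCESS) {
      (void)ge::GEFinalize();
      return -1;
    }
    std::vector<ge::Tensor> outputs;
    // RunGraph executes the optimized, hardware-specific graph.
    if (session.RunGraph(graph_id, inputs, outputs) != ge::SUCCESS) {
      (void)ge::GEFinalize();
      return -1;
    }
  }  // the Session is released before GE is finalized
  return ge::GEFinalize() == ge::SUCCESS ? 0 : -1;
}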