You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

variable_accelerate_ctrl_unittest.cc 7.7 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <gtest/gtest.h>
  17. #include "passes/graph_builder_utils.h"
  18. #define private public
  19. #include "graph/manager/util/variable_accelerate_ctrl.h"
  20. #undef private
  21. namespace ge {
  22. class UtestVariableAccelerateCtrl : public testing::Test {
  23. protected:
  24. void SetUp() {}
  25. void TearDown() {}
  26. };
  27. namespace {
  28. /// netoutput1
  29. /// |
  30. /// shapeNo1
  31. /// |
  32. /// addnYes1
  33. /// / \.
  34. /// / \.
  35. /// const1 const2
  36. ComputeGraphPtr BuildGraph1() {
  37. auto builder = ut::GraphBuilder("test");
  38. auto const1 = builder.AddNode("const1", "CONSTANT", 0, 1);
  39. auto const2 = builder.AddNode("const2", "CONSTANT", 0, 1);
  40. auto addn1 = builder.AddNode("addn1", "AddNYes", 2, 1);
  41. auto shape1 = builder.AddNode("shape1", "ShapeNo", 1, 1);
  42. auto netoutput1 = builder.AddNode("netoutput", "NETOUTPUT", 1, 0);
  43. builder.AddDataEdge(const1, 0, addn1, 0);
  44. builder.AddDataEdge(const2, 0, addn1, 1);
  45. builder.AddDataEdge(addn1, 0, shape1, 0);
  46. builder.AddDataEdge(shape1, 0, netoutput1, 0);
  47. return builder.GetGraph();
  48. }
  49. ///
  50. /// netoutput1
  51. /// / \ \.
  52. /// add1 assign1 \.
  53. /// / \ / \ \.
  54. /// var1 var2 const1 var3
  55. ComputeGraphPtr BuildGraph2() {
  56. auto builder = ut::GraphBuilder("test");
  57. auto var1 = builder.AddNode("var1", "Variable", 0, 1);
  58. auto var2 = builder.AddNode("var2", "VariableV2", 0, 1);
  59. auto var3 = builder.AddNode("var3", "VarHandleOp", 0, 1);
  60. auto const1 = builder.AddNode("const1", "Const", 0, 1);
  61. auto add1 = builder.AddNode("add1", "Add", 2, 1);
  62. auto assign1 = builder.AddNode("assign1", "Assign", 2, 1);
  63. auto netoutput1 = builder.AddNode("netoutput1", "Netoutput", 3, 0);
  64. builder.AddDataEdge(var1, 0, add1, 0);
  65. builder.AddDataEdge(var2, 0, add1, 1);
  66. builder.AddDataEdge(var2, 0, assign1, 1);
  67. builder.AddDataEdge(var3, 0, netoutput1, 2);
  68. builder.AddDataEdge(const1, 0, assign1, 0);
  69. builder.AddDataEdge(add1, 0, netoutput1, 0);
  70. builder.AddDataEdge(assign1, 0, netoutput1, 1);
  71. return builder.GetGraph();
  72. }
  73. } // namespace
  74. TEST_F(UtestVariableAccelerateCtrl, add_graph_null_ptr) {
  75. VarAccelerateCtrl c;
  76. c.AddGraph(1, nullptr);
  77. EXPECT_TRUE(c.graph_ids_to_var_names_.empty());
  78. }
  79. TEST_F(UtestVariableAccelerateCtrl, add_graph_no_var) {
  80. VarAccelerateCtrl c;
  81. c.AddGraph(1, BuildGraph1());
  82. EXPECT_TRUE(c.graph_ids_to_var_names_.count(1) > 0);
  83. EXPECT_TRUE(c.graph_ids_to_var_names_[1].empty());
  84. }
  85. TEST_F(UtestVariableAccelerateCtrl, add_graph_vars) {
  86. VarAccelerateCtrl c;
  87. c.AddGraph(1, BuildGraph2());
  88. EXPECT_TRUE(c.graph_ids_to_var_names_.count(1) > 0);
  89. EXPECT_EQ(c.graph_ids_to_var_names_[1].size(), 3);
  90. EXPECT_EQ(c.graph_ids_to_var_names_[1].count("var1"), 1);
  91. EXPECT_EQ(c.graph_ids_to_var_names_[1].count("var2"), 1);
  92. EXPECT_EQ(c.graph_ids_to_var_names_[1].count("var3"), 1);
  93. }
  94. TEST_F(UtestVariableAccelerateCtrl, remove_graph_vars) {
  95. VarAccelerateCtrl c;
  96. c.AddGraph(1, BuildGraph2());
  97. EXPECT_FALSE(c.graph_ids_to_var_names_.empty());
  98. c.RemoveGraph(1);
  99. EXPECT_TRUE(c.graph_ids_to_var_names_.empty());
  100. }
  101. TEST_F(UtestVariableAccelerateCtrl, graph_rebuild) {
  102. VarAccelerateCtrl c;
  103. c.AddGraph(1, BuildGraph2());
  104. EXPECT_FALSE(c.IsGraphNeedRebuild(1));
  105. c.SetVarChanged("var1");
  106. EXPECT_TRUE(c.IsGraphNeedRebuild(1));
  107. }
  108. TEST_F(UtestVariableAccelerateCtrl, graph_rebuild_multi_changed) {
  109. VarAccelerateCtrl c;
  110. c.AddGraph(1, BuildGraph2());
  111. EXPECT_FALSE(c.IsGraphNeedRebuild(1));
  112. c.SetVarChanged("var2");
  113. c.SetVarChanged("var3");
  114. EXPECT_TRUE(c.IsGraphNeedRebuild(1));
  115. }
  116. TEST_F(UtestVariableAccelerateCtrl, graph_rebuild_multi_graph) {
  117. VarAccelerateCtrl c;
  118. c.AddGraph(1, BuildGraph2());
  119. c.AddGraph(2, BuildGraph2());
  120. EXPECT_FALSE(c.IsGraphNeedRebuild(1));
  121. EXPECT_FALSE(c.IsGraphNeedRebuild(2));
  122. c.SetVarChanged("var1");
  123. EXPECT_TRUE(c.IsGraphNeedRebuild(1));
  124. EXPECT_TRUE(c.IsGraphNeedRebuild(2));
  125. }
  126. TEST_F(UtestVariableAccelerateCtrl, graph_rebuild_after_remove_graph) {
  127. VarAccelerateCtrl c;
  128. c.AddGraph(1, BuildGraph2());
  129. c.AddGraph(2, BuildGraph2());
  130. EXPECT_FALSE(c.IsGraphNeedRebuild(1));
  131. EXPECT_FALSE(c.IsGraphNeedRebuild(2));
  132. c.SetVarChanged("var1");
  133. EXPECT_TRUE(c.IsGraphNeedRebuild(1));
  134. EXPECT_TRUE(c.IsGraphNeedRebuild(2));
  135. c.RemoveGraph(2);
  136. EXPECT_TRUE(c.IsGraphNeedRebuild(1));
  137. EXPECT_FALSE(c.IsGraphNeedRebuild(2));
  138. }
  139. TEST_F(UtestVariableAccelerateCtrl, graph_rebuild_after_build_end) {
  140. VarAccelerateCtrl c;
  141. c.AddGraph(1, BuildGraph2());
  142. c.AddGraph(2, BuildGraph2());
  143. EXPECT_FALSE(c.IsGraphNeedRebuild(1));
  144. EXPECT_FALSE(c.IsGraphNeedRebuild(2));
  145. c.SetVarChanged("var1");
  146. EXPECT_TRUE(c.IsGraphNeedRebuild(1));
  147. EXPECT_TRUE(c.IsGraphNeedRebuild(2));
  148. c.SetGraphBuildEnd(2);
  149. EXPECT_TRUE(c.IsGraphNeedRebuild(1));
  150. EXPECT_FALSE(c.IsGraphNeedRebuild(2));
  151. }
  152. TEST_F(UtestVariableAccelerateCtrl, var_permit_to_change) {
  153. VarAccelerateCtrl c;
  154. c.AddGraph(1, BuildGraph2());
  155. EXPECT_TRUE(c.IsVarPermitToChangeFormats("var1"));
  156. EXPECT_TRUE(c.IsVarPermitToChangeFormats("var2"));
  157. EXPECT_TRUE(c.IsVarPermitToChangeFormats("var3"));
  158. c.SetVarChanged("var1");
  159. EXPECT_FALSE(c.IsVarPermitToChangeFormats("var1"));
  160. EXPECT_TRUE(c.IsVarPermitToChangeFormats("var2"));
  161. EXPECT_TRUE(c.IsVarPermitToChangeFormats("var3"));
  162. }
  163. TEST_F(UtestVariableAccelerateCtrl, var_permit_to_change_remove_graph_not_change) {
  164. VarAccelerateCtrl c;
  165. c.AddGraph(1, BuildGraph2());
  166. EXPECT_TRUE(c.IsVarPermitToChangeFormats("var1"));
  167. EXPECT_TRUE(c.IsVarPermitToChangeFormats("var2"));
  168. EXPECT_TRUE(c.IsVarPermitToChangeFormats("var3"));
  169. c.SetVarChanged("var1");
  170. EXPECT_FALSE(c.IsVarPermitToChangeFormats("var1"));
  171. EXPECT_TRUE(c.IsVarPermitToChangeFormats("var2"));
  172. EXPECT_TRUE(c.IsVarPermitToChangeFormats("var3"));
  173. c.RemoveGraph(1);
  174. EXPECT_FALSE(c.IsVarPermitToChangeFormats("var1"));
  175. EXPECT_TRUE(c.IsVarPermitToChangeFormats("var2"));
  176. EXPECT_TRUE(c.IsVarPermitToChangeFormats("var3"));
  177. }
  178. TEST_F(UtestVariableAccelerateCtrl, var_permit_to_change_excceds_the_max_num) {
  179. VarAccelerateCtrl c;
  180. c.AddGraph(1, BuildGraph2());
  181. EXPECT_TRUE(c.IsVarPermitToChangeFormats("var1"));
  182. EXPECT_TRUE(c.IsVarPermitToChangeFormats("var2"));
  183. EXPECT_TRUE(c.IsVarPermitToChangeFormats("var3"));
  184. c.SetVarChanged("var1");
  185. c.SetVarChanged("var1");
  186. c.SetVarChanged("var1");
  187. c.SetVarChanged("var1");
  188. c.SetVarChanged("var1");
  189. c.SetVarChanged("var1");
  190. EXPECT_FALSE(c.IsVarPermitToChangeFormats("var1"));
  191. EXPECT_TRUE(c.IsVarPermitToChangeFormats("var2"));
  192. EXPECT_TRUE(c.IsVarPermitToChangeFormats("var3"));
  193. }
  194. TEST_F(UtestVariableAccelerateCtrl, var_changed_before_add_graph) {
  195. VarAccelerateCtrl c;
  196. EXPECT_TRUE(c.IsVarPermitToChangeFormats("var1"));
  197. EXPECT_TRUE(c.IsVarPermitToChangeFormats("var2"));
  198. EXPECT_TRUE(c.IsVarPermitToChangeFormats("var3"));
  199. c.SetVarChanged("var1");
  200. EXPECT_FALSE(c.IsVarPermitToChangeFormats("var1"));
  201. c.AddGraph(1, BuildGraph2());
  202. EXPECT_FALSE(c.IsVarPermitToChangeFormats("var1"));
  203. // graph no need to set again
  204. EXPECT_FALSE(c.IsGraphNeedRebuild(1));
  205. }
  206. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：