You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

switch_pass_unittest.cc 16 kB

5 years ago

  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <gtest/gtest.h>
  17. #include <cstdint>
  18. #include <string>
  19. #define private public
  20. #include "common/ge_inner_error_codes.h"
  21. #include "inc/pass_manager.h"
  22. #include "utils/graph_utils.h"
  23. #undef private
  24. namespace ge {
  25. namespace {
  26. class UtestGraphPassesSwitchPass : public testing::Test {
  27. protected:
  28. UtestGraphPassesSwitchPass() {
  29. graph_ = std::make_shared<ComputeGraph>("test");
  30. vector<int64_t> shape_vec{1, 1, 1, 1};
  31. GeShape shape = GeShape(shape_vec);
  32. default_tensor_desc_ = std::make_shared<GeTensorDesc>();
  33. default_tensor_desc_->SetShape(shape);
  34. default_tensor_desc_->SetFormat(FORMAT_NCHW);
  35. default_tensor_desc_->SetDataType(DT_FLOAT);
  36. }
  37. NodePtr NewNode(const std::string &name, const std::string &type, int input_cnt, int output_cnt) {
  38. OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
  39. for (int i = 0; i < input_cnt; ++i) {
  40. op_desc->AddInputDesc(default_tensor_desc_->Clone());
  41. }
  42. for (int i = 0; i < output_cnt; ++i) {
  43. op_desc->AddOutputDesc(default_tensor_desc_->Clone());
  44. }
  45. NodePtr node = graph_->AddNode(op_desc);
  46. (void)node->SetOwnerComputeGraph(graph_);
  47. return node;
  48. }
  49. void BuildDefaultGraph(bool is_input_const, const bool *pred_value = nullptr) {
  50. /// input pred
  51. /// \ /
  52. /// Switch
  53. /// | |
  54. /// F T
  55. /// | |
  56. /// Merge
  57. ///
  58. bool is_pred_const = pred_value != nullptr;
  59. if (is_pred_const) {
  60. pred_node_ = NewNode("pred", CONSTANT, 0, 1);
  61. int32_t weight[] = {static_cast<int32_t>(*pred_value)};
  62. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32);
  63. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  64. OpDescUtils::SetWeights(pred_node_, {tensor});
  65. } else {
  66. pred_node_ = NewNode("pred", GREATER, 2, 1);
  67. }
  68. if (is_input_const) {
  69. int32_t weight[] = {1};
  70. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32);
  71. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  72. input_node_ = NewNode("input", CONSTANT, 0, 1);
  73. OpDescUtils::SetWeights(input_node_, {tensor});
  74. } else {
  75. input_node_ = NewNode("input", RELU, 0, 1);
  76. }
  77. switch_node_ = NewNode("switch", SWITCH, 2, 2);
  78. output_false_node_ = NewNode("false_output", RELU, 1, 1);
  79. output_true_node_ = NewNode("true_output", RELU, 1, 1);
  80. merge_node_ = NewNode("merge", MERGE, 2, 1);
  81. switch_node_->GetOpDesc()->SetIsInputConst({false, is_pred_const});
  82. GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(0));
  83. GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(1));
  84. GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
  85. GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node_->GetInDataAnchor(0));
  86. GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0));
  87. GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
  88. output_false_node_->GetOpDesc()->SetIsInputConst({false});
  89. output_true_node_->GetOpDesc()->SetIsInputConst({false});
  90. }
  91. void TestPickOutput(bool expect_output) {
  92. auto ret = pass_.Run(switch_node_);
  93. EXPECT_EQ(ret, SUCCESS);
  94. EXPECT_EQ(graph_->GetDirectNodesSize(), 5); // has two isolate nodes
  95. EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1);
  96. if (expect_output) {
  97. EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor().get(), nullptr);
  98. EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor(), output_true_node_->GetOutDataAnchor(0));
  99. EXPECT_EQ(output_true_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), input_node_->GetOutDataAnchor(0));
  100. } else {
  101. EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), output_false_node_->GetOutDataAnchor(0));
  102. EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor().get(), nullptr);
  103. EXPECT_EQ(output_false_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), input_node_->GetOutDataAnchor(0));
  104. }
  105. }
  106. ComputeGraphPtr graph_;
  107. GeTensorDescPtr default_tensor_desc_;
  108. SwitchPass pass_;
  109. NodePtr pred_node_;
  110. NodePtr input_node_;
  111. NodePtr switch_node_;
  112. NodePtr output_false_node_;
  113. NodePtr output_true_node_;
  114. NodePtr merge_node_;
  115. };
  116. } // namespace
  117. TEST_F(UtestGraphPassesSwitchPass, null_input) {
  118. NodePtr node = nullptr;
  119. auto ret = pass_.Run(node);
  120. EXPECT_EQ(ret, PARAM_INVALID);
  121. }
  122. TEST_F(UtestGraphPassesSwitchPass, null_pred) {
  123. BuildDefaultGraph(false);
  124. switch_node_->GetInDataAnchor(1)->UnlinkAll();
  125. auto ret = pass_.Run(switch_node_);
  126. EXPECT_EQ(ret, SUCCESS);
  127. }
  128. TEST_F(UtestGraphPassesSwitchPass, null_data) {
  129. BuildDefaultGraph(false);
  130. switch_node_->GetInDataAnchor(0)->UnlinkAll();
  131. auto ret = pass_.Run(switch_node_);
  132. EXPECT_EQ(ret, SUCCESS);
  133. }
  134. TEST_F(UtestGraphPassesSwitchPass, unsupported_node_type) {
  135. auto node = NewNode("Op1", CONSTANT, 0, 1);
  136. auto ret = pass_.Run(node);
  137. EXPECT_EQ(ret, SUCCESS);
  138. }
  139. TEST_F(UtestGraphPassesSwitchPass, empty_output) {
  140. BuildDefaultGraph(false);
  141. switch_node_->GetOutDataAnchor(0)->UnlinkAll();
  142. switch_node_->GetOutDataAnchor(1)->UnlinkAll();
  143. auto ret = pass_.Run(switch_node_);
  144. EXPECT_EQ(ret, SUCCESS);
  145. }
  146. TEST_F(UtestGraphPassesSwitchPass, non_const_pred) {
  147. BuildDefaultGraph(false);
  148. auto ret = pass_.Run(switch_node_);
  149. EXPECT_EQ(ret, SUCCESS);
  150. }
  151. TEST_F(UtestGraphPassesSwitchPass, pick_output_false) {
  152. bool pred_value = false;
  153. BuildDefaultGraph(false, &pred_value);
  154. TestPickOutput(false);
  155. }
  156. TEST_F(UtestGraphPassesSwitchPass, pick_output_false_float) {
  157. bool pred_value = false;
  158. BuildDefaultGraph(false, &pred_value);
  159. float weight[] = {0.0f};
  160. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_FLOAT);
  161. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  162. OpDescUtils::SetWeights(pred_node_, {tensor});
  163. TestPickOutput(false);
  164. }
  165. TEST_F(UtestGraphPassesSwitchPass, pick_output_false_bool) {
  166. bool pred_value = false;
  167. BuildDefaultGraph(false, &pred_value);
  168. bool weight[] = {false};
  169. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_BOOL);
  170. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  171. OpDescUtils::SetWeights(pred_node_, {tensor});
  172. TestPickOutput(false);
  173. }
  174. TEST_F(UtestGraphPassesSwitchPass, pick_output_false_u16) {
  175. bool pred_value = false;
  176. BuildDefaultGraph(false, &pred_value);
  177. uint16_t weight[] = {0};
  178. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_UINT16);
  179. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  180. OpDescUtils::SetWeights(pred_node_, {tensor});
  181. TestPickOutput(false);
  182. }
  183. TEST_F(UtestGraphPassesSwitchPass, pick_output_true) {
  184. bool pred_value = true;
  185. BuildDefaultGraph(false, &pred_value);
  186. TestPickOutput(true);
  187. }
  188. TEST_F(UtestGraphPassesSwitchPass, pick_output_true_double) {
  189. bool pred_value = true;
  190. BuildDefaultGraph(false, &pred_value);
  191. double weight[] = {1.0};
  192. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_DOUBLE);
  193. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  194. OpDescUtils::SetWeights(pred_node_, {tensor});
  195. TestPickOutput(true);
  196. }
  197. TEST_F(UtestGraphPassesSwitchPass, pick_output_true_int64) {
  198. bool pred_value = true;
  199. BuildDefaultGraph(false, &pred_value);
  200. int64_t weight[] = {1L};
  201. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT64);
  202. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  203. OpDescUtils::SetWeights(pred_node_, {tensor});
  204. TestPickOutput(true);
  205. }
  206. TEST_F(UtestGraphPassesSwitchPass, inactive_output_not_exists) {
  207. /// input pred(false)
  208. /// \ /
  209. /// Switch
  210. /// |
  211. /// F
  212. /// |
  213. /// Merge
  214. bool pred_value = false;
  215. BuildDefaultGraph(false, &pred_value);
  216. output_true_node_->GetOutDataAnchor(0)->UnlinkAll();
  217. GraphUtils::RemoveNodeWithoutRelink(graph_, output_true_node_);
  218. switch_node_->GetOutDataAnchor(1)->UnlinkAll();
  219. /// input
  220. /// |
  221. /// F
  222. /// |
  223. /// Merge
  224. auto ret = pass_.Run(switch_node_);
  225. EXPECT_EQ(ret, SUCCESS);
  226. EXPECT_EQ(graph_->GetDirectNodesSize(), 4);
  227. EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1);
  228. EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), output_false_node_->GetOutDataAnchor(0));
  229. EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor().get(), nullptr);
  230. EXPECT_EQ(output_false_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), input_node_->GetOutDataAnchor(0));
  231. }
  232. TEST_F(UtestGraphPassesSwitchPass, const_input_pick_output_true) {
  233. /// const pred(true)
  234. /// \ /
  235. /// Switch
  236. /// | | \
  237. /// F T1 T2
  238. /// | | |
  239. /// | | /
  240. /// | T3
  241. /// | |
  242. /// Merge
  243. bool pred_value = true;
  244. BuildDefaultGraph(true, &pred_value);
  245. auto output_true_node2 = NewNode("true_output2", RELU, 1, 1);
  246. auto output_true_node3 = NewNode("true_output3", ADD, 2, 1);
  247. GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node2->GetInDataAnchor(0));
  248. GraphUtils::RemoveEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
  249. GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), output_true_node3->GetInDataAnchor(0));
  250. GraphUtils::AddEdge(output_true_node2->GetOutDataAnchor(0), output_true_node3->GetInDataAnchor(1));
  251. GraphUtils::AddEdge(output_true_node3->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
  252. /// pred C
  253. /// | | |
  254. /// F T1 T2
  255. /// | /
  256. /// T3
  257. /// |
  258. /// Merge
  259. auto ret = pass_.Run(switch_node_);
  260. EXPECT_EQ(ret, SUCCESS);
  261. EXPECT_EQ(graph_->GetDirectNodesSize(), 7);
  262. EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1);
  263. EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor().get(), nullptr);
  264. EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor(), output_true_node3->GetOutDataAnchor(0));
  265. EXPECT_EQ(output_true_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), input_node_->GetOutDataAnchor(0));
  266. EXPECT_NE(output_true_node2->GetInDataAnchor(0)->GetPeerOutAnchor(),
  267. output_true_node3->GetInDataAnchor(0)->GetPeerOutAnchor());
  268. }
  269. TEST_F(UtestGraphPassesSwitchPass, after_switch_const_take_false_branch) {
  270. /// C pred(false)
  271. /// \ /
  272. /// Switch
  273. /// . .
  274. /// . .
  275. /// C_1 -> F T <- C_2
  276. /// | |
  277. /// Merge
  278. bool pred_value = false;
  279. BuildDefaultGraph(true, &pred_value);
  280. switch_node_->GetOutDataAnchor(0)->UnlinkAll();
  281. switch_node_->GetOutDataAnchor(1)->UnlinkAll();
  282. NodePtr const_node_1 = NewNode("const_1", CONSTANT, 0, 1);
  283. NodePtr const_node_2 = NewNode("const_2", CONSTANT, 0, 1);
  284. GraphUtils::AddEdge(const_node_1->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
  285. GraphUtils::AddEdge(const_node_2->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0));
  286. GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(0), output_false_node_->GetInControlAnchor());
  287. GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node_->GetInControlAnchor());
  288. /// C pred(false)
  289. ///
  290. /// C_1 C_2
  291. /// | |
  292. /// F T
  293. /// |
  294. /// Merge
  295. auto ret = pass_.Run(switch_node_);
  296. EXPECT_EQ(ret, SUCCESS);
  297. EXPECT_EQ(graph_->GetDirectNodesSize(), 7);
  298. EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1);
  299. EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), output_false_node_->GetOutDataAnchor(0));
  300. EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor().get(), nullptr);
  301. EXPECT_EQ(output_false_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), const_node_1->GetOutDataAnchor(0));
  302. }
  303. TEST_F(UtestGraphPassesSwitchPass, after_switch_const_take_true_branch) {
  304. /// C pred(true)
  305. /// \ /
  306. /// Switch
  307. /// . .
  308. /// . .
  309. /// C_1 -> F T <- C_2
  310. /// | |
  311. /// Merge
  312. bool pred_value = true;
  313. BuildDefaultGraph(true, &pred_value);
  314. switch_node_->GetOutDataAnchor(0)->UnlinkAll();
  315. switch_node_->GetOutDataAnchor(1)->UnlinkAll();
  316. NodePtr const_node_1 = NewNode("const_1", CONSTANT, 0, 1);
  317. NodePtr const_node_2 = NewNode("const_2", CONSTANT, 0, 1);
  318. GraphUtils::AddEdge(const_node_1->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
  319. GraphUtils::AddEdge(const_node_2->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0));
  320. GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(0), output_false_node_->GetInControlAnchor());
  321. GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node_->GetInControlAnchor());
  322. /// C_1 C_2
  323. /// | |
  324. /// F T
  325. /// |
  326. /// Merge
  327. auto ret = pass_.Run(switch_node_);
  328. EXPECT_EQ(ret, SUCCESS);
  329. EXPECT_EQ(graph_->GetDirectNodesSize(), 7);
  330. EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1);
  331. EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor().get(), nullptr);
  332. EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor(), output_true_node_->GetOutDataAnchor(0));
  333. EXPECT_EQ(output_true_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), const_node_2->GetOutDataAnchor(0));
  334. }
  335. TEST_F(UtestGraphPassesSwitchPass, dead_output_connected_to_merge) {
  336. /// input pred(true)
  337. /// \ /
  338. /// Switch
  339. /// | |
  340. /// | T
  341. /// | |
  342. /// Merge
  343. bool pred_value = true;
  344. BuildDefaultGraph(false, &pred_value);
  345. output_false_node_->GetOutDataAnchor(0)->UnlinkAll();
  346. GraphUtils::RemoveNodeWithoutRelink(graph_, output_false_node_);
  347. switch_node_->GetOutDataAnchor(0)->UnlinkAll();
  348. /// input pred(true)
  349. /// \ /
  350. /// Switch
  351. /// |
  352. /// T
  353. /// |
  354. /// Merge
  355. auto ret = pass_.Run(switch_node_);
  356. EXPECT_EQ(ret, SUCCESS);
  357. /// input
  358. /// |
  359. /// T
  360. /// |
  361. /// Merge
  362. EXPECT_EQ(graph_->GetDirectNodesSize(), 4);
  363. EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1);
  364. EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor().get(), nullptr);
  365. EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor(), output_true_node_->GetOutDataAnchor(0));
  366. EXPECT_EQ(output_true_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), input_node_->GetOutDataAnchor(0));
  367. }
  368. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：