You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

switch_pass_unittest.cc 16 kB

5 years ago

  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include <gtest/gtest.h>
  17. #include <cstdint>
  18. #include <string>
  19. #define private public
  20. #include "graph/passes/switch_pass.h"
  21. #include "common/ge_inner_error_codes.h"
  22. #include "inc/pass_manager.h"
  23. #include "utils/graph_utils.h"
  24. #undef private
  25. namespace ge {
  26. namespace {
  27. class UtestGraphPassesSwitchPass : public testing::Test {
  28. protected:
  29. UtestGraphPassesSwitchPass() {
  30. graph_ = std::make_shared<ComputeGraph>("test");
  31. vector<int64_t> shape_vec{1, 1, 1, 1};
  32. GeShape shape = GeShape(shape_vec);
  33. default_tensor_desc_ = std::make_shared<GeTensorDesc>();
  34. default_tensor_desc_->SetShape(shape);
  35. default_tensor_desc_->SetFormat(FORMAT_NCHW);
  36. default_tensor_desc_->SetDataType(DT_FLOAT);
  37. }
  38. NodePtr NewNode(const std::string &name, const std::string &type, int input_cnt, int output_cnt) {
  39. OpDescPtr op_desc = std::make_shared<OpDesc>(name, type);
  40. for (int i = 0; i < input_cnt; ++i) {
  41. op_desc->AddInputDesc(default_tensor_desc_->Clone());
  42. }
  43. for (int i = 0; i < output_cnt; ++i) {
  44. op_desc->AddOutputDesc(default_tensor_desc_->Clone());
  45. }
  46. NodePtr node = graph_->AddNode(op_desc);
  47. (void)node->SetOwnerComputeGraph(graph_);
  48. return node;
  49. }
  50. void BuildDefaultGraph(bool is_input_const, const bool *pred_value = nullptr) {
  51. /// input pred
  52. /// \ /
  53. /// Switch
  54. /// | |
  55. /// F T
  56. /// | |
  57. /// Merge
  58. ///
  59. bool is_pred_const = pred_value != nullptr;
  60. if (is_pred_const) {
  61. pred_node_ = NewNode("pred", CONSTANT, 0, 1);
  62. int32_t weight[] = {static_cast<int32_t>(*pred_value)};
  63. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32);
  64. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  65. OpDescUtils::SetWeights(pred_node_, {tensor});
  66. } else {
  67. pred_node_ = NewNode("pred", GREATER, 2, 1);
  68. }
  69. if (is_input_const) {
  70. int32_t weight[] = {1};
  71. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT32);
  72. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  73. input_node_ = NewNode("input", CONSTANT, 0, 1);
  74. OpDescUtils::SetWeights(input_node_, {tensor});
  75. } else {
  76. input_node_ = NewNode("input", RELU, 0, 1);
  77. }
  78. switch_node_ = NewNode("switch", SWITCH, 2, 2);
  79. output_false_node_ = NewNode("false_output", RELU, 1, 1);
  80. output_true_node_ = NewNode("true_output", RELU, 1, 1);
  81. merge_node_ = NewNode("merge", MERGE, 2, 1);
  82. switch_node_->GetOpDesc()->SetIsInputConst({false, is_pred_const});
  83. GraphUtils::AddEdge(input_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(0));
  84. GraphUtils::AddEdge(pred_node_->GetOutDataAnchor(0), switch_node_->GetInDataAnchor(1));
  85. GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
  86. GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node_->GetInDataAnchor(0));
  87. GraphUtils::AddEdge(output_false_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(0));
  88. GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
  89. output_false_node_->GetOpDesc()->SetIsInputConst({false});
  90. output_true_node_->GetOpDesc()->SetIsInputConst({false});
  91. }
  92. void TestPickOutput(bool expect_output) {
  93. auto ret = pass_.Run(switch_node_);
  94. EXPECT_EQ(ret, SUCCESS);
  95. EXPECT_EQ(graph_->GetDirectNodesSize(), 5); // has two isolate nodes
  96. EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1);
  97. if (expect_output) {
  98. EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor().get(), nullptr);
  99. EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor(), output_true_node_->GetOutDataAnchor(0));
  100. EXPECT_EQ(output_true_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), input_node_->GetOutDataAnchor(0));
  101. } else {
  102. EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), output_false_node_->GetOutDataAnchor(0));
  103. EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor().get(), nullptr);
  104. EXPECT_EQ(output_false_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), input_node_->GetOutDataAnchor(0));
  105. }
  106. }
  107. ComputeGraphPtr graph_;
  108. GeTensorDescPtr default_tensor_desc_;
  109. SwitchPass pass_;
  110. NodePtr pred_node_;
  111. NodePtr input_node_;
  112. NodePtr switch_node_;
  113. NodePtr output_false_node_;
  114. NodePtr output_true_node_;
  115. NodePtr merge_node_;
  116. };
  117. } // namespace
  118. TEST_F(UtestGraphPassesSwitchPass, null_input) {
  119. NodePtr node = nullptr;
  120. auto ret = pass_.Run(node);
  121. EXPECT_EQ(ret, PARAM_INVALID);
  122. }
  123. TEST_F(UtestGraphPassesSwitchPass, null_pred) {
  124. BuildDefaultGraph(false);
  125. switch_node_->GetInDataAnchor(1)->UnlinkAll();
  126. auto ret = pass_.Run(switch_node_);
  127. EXPECT_EQ(ret, SUCCESS);
  128. }
  129. TEST_F(UtestGraphPassesSwitchPass, null_data) {
  130. BuildDefaultGraph(false);
  131. switch_node_->GetInDataAnchor(0)->UnlinkAll();
  132. auto ret = pass_.Run(switch_node_);
  133. EXPECT_EQ(ret, SUCCESS);
  134. }
  135. TEST_F(UtestGraphPassesSwitchPass, unsupported_node_type) {
  136. auto node = NewNode("Op1", CONSTANT, 0, 1);
  137. auto ret = pass_.Run(node);
  138. EXPECT_EQ(ret, SUCCESS);
  139. }
  140. TEST_F(UtestGraphPassesSwitchPass, empty_output) {
  141. BuildDefaultGraph(false);
  142. switch_node_->GetOutDataAnchor(0)->UnlinkAll();
  143. switch_node_->GetOutDataAnchor(1)->UnlinkAll();
  144. auto ret = pass_.Run(switch_node_);
  145. EXPECT_EQ(ret, SUCCESS);
  146. }
  147. TEST_F(UtestGraphPassesSwitchPass, non_const_pred) {
  148. BuildDefaultGraph(false);
  149. auto ret = pass_.Run(switch_node_);
  150. EXPECT_EQ(ret, SUCCESS);
  151. }
  152. TEST_F(UtestGraphPassesSwitchPass, pick_output_false) {
  153. bool pred_value = false;
  154. BuildDefaultGraph(false, &pred_value);
  155. TestPickOutput(false);
  156. }
  157. TEST_F(UtestGraphPassesSwitchPass, pick_output_false_float) {
  158. bool pred_value = false;
  159. BuildDefaultGraph(false, &pred_value);
  160. float weight[] = {0.0f};
  161. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_FLOAT);
  162. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  163. OpDescUtils::SetWeights(pred_node_, {tensor});
  164. TestPickOutput(false);
  165. }
  166. TEST_F(UtestGraphPassesSwitchPass, pick_output_false_bool) {
  167. bool pred_value = false;
  168. BuildDefaultGraph(false, &pred_value);
  169. bool weight[] = {false};
  170. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_BOOL);
  171. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  172. OpDescUtils::SetWeights(pred_node_, {tensor});
  173. TestPickOutput(false);
  174. }
  175. TEST_F(UtestGraphPassesSwitchPass, pick_output_false_u16) {
  176. bool pred_value = false;
  177. BuildDefaultGraph(false, &pred_value);
  178. uint16_t weight[] = {0};
  179. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_UINT16);
  180. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  181. OpDescUtils::SetWeights(pred_node_, {tensor});
  182. TestPickOutput(false);
  183. }
  184. TEST_F(UtestGraphPassesSwitchPass, pick_output_true) {
  185. bool pred_value = true;
  186. BuildDefaultGraph(false, &pred_value);
  187. TestPickOutput(true);
  188. }
  189. TEST_F(UtestGraphPassesSwitchPass, pick_output_true_double) {
  190. bool pred_value = true;
  191. BuildDefaultGraph(false, &pred_value);
  192. double weight[] = {1.0};
  193. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_DOUBLE);
  194. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  195. OpDescUtils::SetWeights(pred_node_, {tensor});
  196. TestPickOutput(true);
  197. }
  198. TEST_F(UtestGraphPassesSwitchPass, pick_output_true_int64) {
  199. bool pred_value = true;
  200. BuildDefaultGraph(false, &pred_value);
  201. int64_t weight[] = {1L};
  202. GeTensorDesc weight_desc(GeShape({1}), FORMAT_NHWC, DT_INT64);
  203. GeTensorPtr tensor = std::make_shared<GeTensor>(weight_desc, (uint8_t *)weight, sizeof(weight));
  204. OpDescUtils::SetWeights(pred_node_, {tensor});
  205. TestPickOutput(true);
  206. }
  207. TEST_F(UtestGraphPassesSwitchPass, inactive_output_not_exists) {
  208. /// input pred(false)
  209. /// \ /
  210. /// Switch
  211. /// |
  212. /// F
  213. /// |
  214. /// Merge
  215. bool pred_value = false;
  216. BuildDefaultGraph(false, &pred_value);
  217. output_true_node_->GetOutDataAnchor(0)->UnlinkAll();
  218. GraphUtils::RemoveNodeWithoutRelink(graph_, output_true_node_);
  219. switch_node_->GetOutDataAnchor(1)->UnlinkAll();
  220. /// input
  221. /// |
  222. /// F
  223. /// |
  224. /// Merge
  225. auto ret = pass_.Run(switch_node_);
  226. EXPECT_EQ(ret, SUCCESS);
  227. EXPECT_EQ(graph_->GetDirectNodesSize(), 4);
  228. EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1);
  229. EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), output_false_node_->GetOutDataAnchor(0));
  230. EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor().get(), nullptr);
  231. EXPECT_EQ(output_false_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), input_node_->GetOutDataAnchor(0));
  232. }
  233. TEST_F(UtestGraphPassesSwitchPass, const_input_pick_output_true) {
  234. /// const pred(true)
  235. /// \ /
  236. /// Switch
  237. /// | | \
  238. /// F T1 T2
  239. /// | | |
  240. /// | | /
  241. /// | T3
  242. /// | |
  243. /// Merge
  244. bool pred_value = true;
  245. BuildDefaultGraph(true, &pred_value);
  246. auto output_true_node2 = NewNode("true_output2", RELU, 1, 1);
  247. auto output_true_node3 = NewNode("true_output3", ADD, 2, 1);
  248. GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node2->GetInDataAnchor(0));
  249. GraphUtils::RemoveEdge(output_true_node_->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
  250. GraphUtils::AddEdge(output_true_node_->GetOutDataAnchor(0), output_true_node3->GetInDataAnchor(0));
  251. GraphUtils::AddEdge(output_true_node2->GetOutDataAnchor(0), output_true_node3->GetInDataAnchor(1));
  252. GraphUtils::AddEdge(output_true_node3->GetOutDataAnchor(0), merge_node_->GetInDataAnchor(1));
  253. /// pred C
  254. /// | | |
  255. /// F T1 T2
  256. /// | /
  257. /// T3
  258. /// |
  259. /// Merge
  260. auto ret = pass_.Run(switch_node_);
  261. EXPECT_EQ(ret, SUCCESS);
  262. EXPECT_EQ(graph_->GetDirectNodesSize(), 7);
  263. EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1);
  264. EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor().get(), nullptr);
  265. EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor(), output_true_node3->GetOutDataAnchor(0));
  266. EXPECT_EQ(output_true_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), input_node_->GetOutDataAnchor(0));
  267. EXPECT_NE(output_true_node2->GetInDataAnchor(0)->GetPeerOutAnchor(),
  268. output_true_node3->GetInDataAnchor(0)->GetPeerOutAnchor());
  269. }
  270. TEST_F(UtestGraphPassesSwitchPass, after_switch_const_take_false_branch) {
  271. /// C pred(false)
  272. /// \ /
  273. /// Switch
  274. /// . .
  275. /// . .
  276. /// C_1 -> F T <- C_2
  277. /// | |
  278. /// Merge
  279. bool pred_value = false;
  280. BuildDefaultGraph(true, &pred_value);
  281. switch_node_->GetOutDataAnchor(0)->UnlinkAll();
  282. switch_node_->GetOutDataAnchor(1)->UnlinkAll();
  283. NodePtr const_node_1 = NewNode("const_1", CONSTANT, 0, 1);
  284. NodePtr const_node_2 = NewNode("const_2", CONSTANT, 0, 1);
  285. GraphUtils::AddEdge(const_node_1->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
  286. GraphUtils::AddEdge(const_node_2->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0));
  287. GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(0), output_false_node_->GetInControlAnchor());
  288. GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node_->GetInControlAnchor());
  289. /// C pred(false)
  290. ///
  291. /// C_1 C_2
  292. /// | |
  293. /// F T
  294. /// |
  295. /// Merge
  296. auto ret = pass_.Run(switch_node_);
  297. EXPECT_EQ(ret, SUCCESS);
  298. EXPECT_EQ(graph_->GetDirectNodesSize(), 7);
  299. EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1);
  300. EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), output_false_node_->GetOutDataAnchor(0));
  301. EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor().get(), nullptr);
  302. EXPECT_EQ(output_false_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), const_node_1->GetOutDataAnchor(0));
  303. }
  304. TEST_F(UtestGraphPassesSwitchPass, after_switch_const_take_true_branch) {
  305. /// C pred(true)
  306. /// \ /
  307. /// Switch
  308. /// . .
  309. /// . .
  310. /// C_1 -> F T <- C_2
  311. /// | |
  312. /// Merge
  313. bool pred_value = true;
  314. BuildDefaultGraph(true, &pred_value);
  315. switch_node_->GetOutDataAnchor(0)->UnlinkAll();
  316. switch_node_->GetOutDataAnchor(1)->UnlinkAll();
  317. NodePtr const_node_1 = NewNode("const_1", CONSTANT, 0, 1);
  318. NodePtr const_node_2 = NewNode("const_2", CONSTANT, 0, 1);
  319. GraphUtils::AddEdge(const_node_1->GetOutDataAnchor(0), output_false_node_->GetInDataAnchor(0));
  320. GraphUtils::AddEdge(const_node_2->GetOutDataAnchor(0), output_true_node_->GetInDataAnchor(0));
  321. GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(0), output_false_node_->GetInControlAnchor());
  322. GraphUtils::AddEdge(switch_node_->GetOutDataAnchor(1), output_true_node_->GetInControlAnchor());
  323. /// C_1 C_2
  324. /// | |
  325. /// F T
  326. /// |
  327. /// Merge
  328. auto ret = pass_.Run(switch_node_);
  329. EXPECT_EQ(ret, SUCCESS);
  330. EXPECT_EQ(graph_->GetDirectNodesSize(), 7);
  331. EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1);
  332. EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor().get(), nullptr);
  333. EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor(), output_true_node_->GetOutDataAnchor(0));
  334. EXPECT_EQ(output_true_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), const_node_2->GetOutDataAnchor(0));
  335. }
  336. TEST_F(UtestGraphPassesSwitchPass, dead_output_connected_to_merge) {
  337. /// input pred(true)
  338. /// \ /
  339. /// Switch
  340. /// | |
  341. /// | T
  342. /// | |
  343. /// Merge
  344. bool pred_value = true;
  345. BuildDefaultGraph(false, &pred_value);
  346. output_false_node_->GetOutDataAnchor(0)->UnlinkAll();
  347. GraphUtils::RemoveNodeWithoutRelink(graph_, output_false_node_);
  348. switch_node_->GetOutDataAnchor(0)->UnlinkAll();
  349. /// input pred(true)
  350. /// \ /
  351. /// Switch
  352. /// |
  353. /// T
  354. /// |
  355. /// Merge
  356. auto ret = pass_.Run(switch_node_);
  357. EXPECT_EQ(ret, SUCCESS);
  358. /// input
  359. /// |
  360. /// T
  361. /// |
  362. /// Merge
  363. EXPECT_EQ(graph_->GetDirectNodesSize(), 4);
  364. EXPECT_EQ(merge_node_->GetInDataNodes().size(), 1);
  365. EXPECT_EQ(merge_node_->GetInDataAnchor(0)->GetPeerOutAnchor().get(), nullptr);
  366. EXPECT_EQ(merge_node_->GetInDataAnchor(1)->GetPeerOutAnchor(), output_true_node_->GetOutDataAnchor(0));
  367. EXPECT_EQ(output_true_node_->GetInDataAnchor(0)->GetPeerOutAnchor(), input_node_->GetOutDataAnchor(0));
  368. }
  369. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：