
multi_batch_pass.cc

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "graph/passes/multi_batch_pass.h"

#include <stack>
#include <unordered_set>

#include "common/ge/ge_util.h"
#include "graph/common/omg_util.h"
#include "graph/utils/type_utils.h"

namespace ge {
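///
/// @brief Entry of the pass: on the root graph, collect the SwitchN nodes that share one pred_value input,
///        check their batch shapes, then either only attach batch/stream labels or replace them with a
///        single StreamSwitchN node and remove the bypassed nodes. Subgraphs are skipped.
/// @param [in] graph: Root ComputeGraph.
/// @return Status
///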
Status MultiBatchPass::Run(ComputeGraphPtr graph) {
  GELOGD("MultiBatchPass Enter");
  if (graph->GetParentGraph() != nullptr) {
    GELOGI("Subgraph %s skip the MultiBatchPass.", graph->GetName().c_str());
    return SUCCESS;
  }

  OutDataAnchorPtr pred_value = nullptr;
  Status ret = FindPredValue(graph, pred_value);
  if (ret == NOT_CHANGED) {
    GELOGD("SwitchN node not exist, graph not changed.");
    return SUCCESS;
  }
  if (ret != SUCCESS) {
    GELOGE(FAILED, "FindPredValue failed.");
    return FAILED;
  }

  if (GetDynamicType() != SUCCESS) {
    GELOGE(FAILED, "Get dynamic type failed.");
    return FAILED;
  }
  if (GetUserDesignateShape() != SUCCESS) {
    GELOGE(FAILED, "Get user designate shape failed.");
    return FAILED;
  }

  std::vector<std::vector<int64_t>> batch_shape;
  std::vector<std::vector<int64_t>> combined_batch;
  if (!CheckSwitchN(batch_shape, combined_batch)) {
    GELOGE(FAILED, "CheckSwitchN failed.");
    return FAILED;
  }

  if (attach_label_only_) {
    return AttachLabelOnly(batch_shape.size());
  }

  if (FindSwitchOutNodes(batch_shape.size()) != SUCCESS) {
    GELOGE(FAILED, "Find SwitchN out nodes failed.");
    return FAILED;
  }

  if (ReplaceSwitchN(graph, pred_value, batch_shape, combined_batch) != SUCCESS) {
    GELOGE(FAILED, "Replace SwitchN nodes failed.");
    return FAILED;
  }

  for (const NodePtr &node : bypass_nodes_) {
    if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) {
      GELOGE(FAILED, "Remove SwitchN nodes %s failed.", node->GetName().c_str());
      return FAILED;
    }
  }

  GELOGD("MultiBatchPass Leave");
  return SUCCESS;
}
///
/// @brief Clear Status
/// @return SUCCESS
///
Status MultiBatchPass::ClearStatus() {
  switch_n_nodes_.clear();
  bypass_nodes_.clear();
  batch_head_nodes_.clear();
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Set batch label for Case mode.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @param [in] const NodePtr &case_node: Case Node.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr &case_node) {
  const auto &func_desc = case_node->GetOpDesc();
  GE_CHECK_NOTNULL(func_desc);
  if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) {
    GELOGD("Graph: %s Not multi-batch, Node: %s", graph->GetName().c_str(), case_node->GetName().c_str());
    return SUCCESS;
  }

  const auto &dynamic_branch_names = func_desc->GetSubgraphInstanceNames();
  for (size_t i = 0; i < dynamic_branch_names.size(); ++i) {
    const auto &subgraph = graph->GetSubgraph(dynamic_branch_names[i]);
    GE_CHECK_NOTNULL(subgraph);

    const std::string batch_label = "Batch_" + std::to_string(i);
    for (const auto &node : subgraph->GetDirectNode()) {
      (void)AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label);
    }
  }

  return SUCCESS;
}

///
/// @brief Find the pred_value input shared by all SwitchN nodes
/// @param [in] graph
/// @param [out] pred_value
/// @return Status
///
Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchorPtr &pred_value) {
  for (const NodePtr &node : graph->GetDirectNode()) {
    if (node->GetType() == CASE) {
      GE_CHK_STATUS_RET(SetCaseLabel(graph, node), "Set batch label failed");
      continue;
    }
    if (node->GetType() != SWITCHN) {
      continue;
    }

    const auto &in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
    if (in_data_anchor == nullptr) {
      GELOGE(FAILED, "FindPredInput failed, in_data_anchor is null, node:%s.", node->GetName().c_str());
      return FAILED;
    }
    const auto &pred_input = in_data_anchor->GetPeerOutAnchor();
    if (pred_input == nullptr) {
      GELOGE(FAILED, "FindPredInput failed, pred_input is null, node:%s.", node->GetName().c_str());
      return FAILED;
    }

    if (pred_value == nullptr) {
      pred_value = pred_input;
    } else if (pred_value != pred_input) {
      GELOGE(FAILED, "Multi pred_value node exist.");
      return FAILED;
    }
    switch_n_nodes_.emplace_back(node);
  }

  if (switch_n_nodes_.empty()) {
    GELOGD("SwitchN node not exist.");
    return NOT_CHANGED;
  }
  if (pred_value == nullptr) {
    GELOGE(FAILED, "FindPredInput failed, pred_value is null.");
    return FAILED;
  }

  GELOGI("Find pred_value %s.", pred_value->GetOwnerNode()->GetName().c_str());
  return SUCCESS;
}
///
/// @brief Get dynamic type: dynamic batch size: 1, dynamic image size: 2, dynamic dims: 3
/// @return Status
///
Status MultiBatchPass::GetDynamicType() {
  for (const auto &switch_n : switch_n_nodes_) {
    int32_t dynamic_type = static_cast<int32_t>(FIXED);
    if (!AttrUtils::GetInt(switch_n->GetOpDesc(), ATTR_DYNAMIC_TYPE, dynamic_type)) {
      GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switch_n->GetName().c_str());
      return FAILED;
    }
    if (dynamic_type == static_cast<int32_t>(FIXED)) {
      GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE shouldn't be 0.");
      return FAILED;
    }
    if (dynamic_type_ != static_cast<int32_t>(FIXED) && dynamic_type_ != dynamic_type) {
      GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switch_n node should be same, while one is %d and another is %d.",
             dynamic_type, dynamic_type_);
      return FAILED;
    }
    dynamic_type_ = dynamic_type;
  }
  if (dynamic_type_ == static_cast<int32_t>(FIXED)) {
    GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE shouldn't be 0.");
    return FAILED;
  }

  return SUCCESS;
}

///
/// @brief Get user designate shape order. eg{"data","label","mask"}
/// @return Status
///
Status MultiBatchPass::GetUserDesignateShape() {
  data_name_order_.clear();
  bool first_check = true;
  for (const auto &switch_n : switch_n_nodes_) {
    std::vector<std::string> cur_data_name_order;
    if (!AttrUtils::GetListStr(switch_n->GetOpDesc(), ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_data_name_order)) {
      GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switch_n->GetName().c_str());
      return FAILED;
    }
    if (first_check) {
      data_name_order_ = cur_data_name_order;
      first_check = false;
    } else {
      if (data_name_order_ != cur_data_name_order) {
        GELOGE(FAILED, "The ATTR_USER_DESIGNEATE_SHAPE_ORDER of switchN must be same: %s failed.",
               switch_n->GetName().c_str());
        return FAILED;
      }
    }
  }
  if (data_name_order_.empty()) {
    GELOGE(FAILED, "user shape order can not be empty");
    return FAILED;
  }

  return SUCCESS;
}
///
/// @brief Check SwitchN nodes
/// @param [out] batch_shape
/// @param [out] combined_batch
/// @return bool
///
bool MultiBatchPass::CheckSwitchN(std::vector<std::vector<int64_t>> &batch_shape,
                                  std::vector<std::vector<int64_t>> &combined_batch) {
  // Check if output_num of different SwitchN is same
  uint32_t batch_num = 0;
  for (const NodePtr &node : switch_n_nodes_) {
    uint32_t tmp_num = node->GetAllOutDataAnchorsSize();
    if (batch_num == 0) {
      batch_num = tmp_num;
    } else if (batch_num != tmp_num) {
      GELOGE(FAILED, "Output size of SwitchN not equal.");
      return false;
    }
  }

  if (!GetBatchInfo(batch_num, batch_shape, combined_batch)) {
    GELOGE(FAILED, "Get batch info failed.");
    return false;
  }

  if (batch_shape.empty()) {
    GELOGE(FAILED, "batch_shape is empty.");
    return false;
  }
  if (combined_batch.empty()) {
    GELOGE(FAILED, "combined_batch is empty.");
    return false;
  }

  size_t dim_num = batch_shape[0].size();
  size_t combined_dim_num = combined_batch[0].size();
  for (uint32_t i = 1; i < batch_num; i++) {
    size_t tmp_dim_num = batch_shape[i].size();
    if (dim_num != tmp_dim_num) {
      GELOGE(FAILED, "Dim num of batch_shape not equal, batch_0:%zu, batch_%u:%zu.", dim_num, i, tmp_dim_num);
      return false;
    }
    size_t tmp_combined_dim_num = combined_batch[i].size();
    if (combined_dim_num != tmp_combined_dim_num) {
      GELOGE(FAILED, "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu.",
             combined_dim_num, i, tmp_combined_dim_num);
      return false;
    }
  }

  return true;
}

///
/// @brief Get batch info from SwitchN output shapes
/// @param [in] batch_num
/// @param [out] batch_shape
/// @param [out] combined_batch
/// @return bool
///
bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, std::vector<std::vector<int64_t>> &batch_shape,
                                  std::vector<std::vector<int64_t>> &combined_batch) {
  // Check if output_shape of different SwitchN is same
  std::vector<std::vector<int64_t>> idx_batch_shape;
  std::vector<std::vector<int64_t>> idx_combined_batch;
  for (uint32_t i = 0; i < batch_num; i++) {
    idx_batch_shape.clear();
    idx_combined_batch.clear();
    for (const NodePtr &node : switch_n_nodes_) {
      OpDescPtr op_desc = node->GetOpDesc();
      if (op_desc == nullptr) {
        GELOGE(FAILED, "CheckDims failed, get op_desc failed, node: %s.", node->GetName().c_str());
        return false;
      }
      std::vector<int64_t> output_dims;
      if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_SWITCHN_PRED_VALUE, output_dims)) {
        GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_SWITCHN_PRED_VALUE failed, batch_index=%u.", i);
        return false;
      }
      idx_batch_shape.emplace_back(output_dims);
      output_dims.clear();
      if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_COMBINED_DYNAMIC_DIMS, output_dims)) {
        GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_COMBINED_DYNAMIC_DIMS failed, batch_index=%u.", i);
        return false;
      }
      idx_combined_batch.emplace_back(output_dims);
    }
    if (!CheckDims(idx_batch_shape)) {
      GELOGE(FAILED, "CheckDims failed, batch_index=%u.", i);
      return false;
    }

    batch_shape.emplace_back(idx_batch_shape[0]);
    combined_batch.emplace_back(idx_combined_batch[0]);
  }

  return true;
}
///
/// @brief Find outputs of SwitchN nodes
/// @param [in] batch_num
/// @return Status
///
Status MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) {
  std::vector<NodePtr> output_nodes;
  for (uint32_t i = 0; i < batch_num; i++) {
    output_nodes.clear();
    for (const NodePtr &node : switch_n_nodes_) {
      // idx is promised to be valid
      OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(i);
      GE_CHECK_NOTNULL(out_data_anchor);
      for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
        auto out_node = peer_in_anchor->GetOwnerNode();
        if (out_node->GetType() != IDENTITY || !out_node->GetOutDataNodes().empty()) {
          output_nodes.emplace_back(out_node);
          continue;
        }
        bypass_nodes_.emplace_back(out_node);
        if (GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor) != GRAPH_SUCCESS) {
          GELOGE(FAILED, "Remove SwitchN out_data_edge failed, %s->%s.", node->GetName().c_str(),
                 out_node->GetName().c_str());
          return FAILED;
        }
        for (auto &identity_out_node : out_node->GetOutControlNodes()) {
          output_nodes.emplace_back(identity_out_node);
          if (GraphUtils::RemoveEdge(out_node->GetOutControlAnchor(), identity_out_node->GetInControlAnchor()) !=
              GRAPH_SUCCESS) {
            GELOGE(FAILED, "Remove SwitchN out_data_edge failed, %s->%s.", node->GetName().c_str(),
                   out_node->GetName().c_str());
            return FAILED;
          }
        }
      }
    }
    batch_head_nodes_.emplace_back(output_nodes);
  }

  return SUCCESS;
}
///
/// @brief Replace & Combine SwitchN nodes
/// @param [in] graph
/// @param [in] pred_value
/// @param [in] batch_shape
/// @param [in] combined_batch
/// @return Status
///
Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value,
                                      const std::vector<std::vector<int64_t>> &batch_shape,
                                      const std::vector<std::vector<int64_t>> &combined_batch) {
  NodePtr pred_value_node = pred_value->GetOwnerNode();
  // Create SwitchCase node
  const std::string &switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN;
  NodePtr switch_case = CreateSwitchCaseNode(graph, switch_case_name, pred_value, batch_shape, combined_batch);
  if (switch_case == nullptr) {
    GELOGE(FAILED, "CreateSwitchCaseNode %s failed.", switch_case_name.c_str());
    return FAILED;
  }

  for (const NodePtr &switch_n_node : switch_n_nodes_) {
    if (BypassSwitchN(switch_n_node, switch_case) != SUCCESS) {
      GELOGE(FAILED, "Bypass SwitchN %s failed.", switch_case_name.c_str());
      return FAILED;
    }
  }

  // Add switchCase input edge
  if (GraphUtils::AddEdge(pred_value, switch_case->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Add SwitchCase in_data_edge failed, %s->%s.", pred_value_node->GetName().c_str(),
           switch_case->GetName().c_str());
    return FAILED;
  }

  if (AttachLabel(switch_case) != SUCCESS) {
    GELOGE(FAILED, "AttachLabel failed.");
    return FAILED;
  }

  return SUCCESS;
}

///
/// @brief Check if output_shape of different SwitchN is same
/// @param [in] output_shape
/// @return bool
///
bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_shape) const {
  if (output_shape.empty()) {
    GELOGE(FAILED, "CheckDims failed: output_shape is empty.");
    return false;
  }

  for (auto iter = output_shape.begin() + 1; iter != output_shape.end(); ++iter) {
    if (output_shape[0] != *iter) {
      return false;
    }
  }

  return true;
}
///
/// @brief Create StreamSwitchN node
/// @param [in] graph
/// @param [in] name
/// @param [in] pred_value
/// @param [in] batch_shape
/// @param [in] combined_batch
/// @return ge::NodePtr
///
NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name,
                                             const OutDataAnchorPtr &pred_value,
                                             const std::vector<std::vector<int64_t>> &batch_shape,
                                             const std::vector<std::vector<int64_t>> &combined_batch) {
  OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMSWITCHN);
  if (op_desc == nullptr) {
    GELOGE(FAILED, "Create op_desc failed, StreamSwitchN:%s.", name.c_str());
    return nullptr;
  }

  GELOGI("Create StreamSwitchN op:%s.", name.c_str());
  OpDescPtr pred_desc = pred_value->GetOwnerNode()->GetOpDesc();
  if (pred_desc == nullptr) {
    GELOGE(FAILED, "Get pred_desc failed, StreamSwitchN:%s.", name.c_str());
    return nullptr;
  }
  if (op_desc->AddInputDesc(pred_desc->GetOutputDesc(pred_value->GetIdx())) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "AddInputDesc failed, StreamSwitchN:%s.", name.c_str());
    return nullptr;
  }

  NodePtr switch_case_node = graph->AddNode(op_desc);
  if (switch_case_node == nullptr) {
    GELOGE(FAILED, "Create node failed, StreamSwitchN:%s.", name.c_str());
    return nullptr;
  }

  uint32_t batch_num = static_cast<uint32_t>(batch_shape.size());
  if (!AttrUtils::SetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) {
    GELOGE(FAILED, "set attr ATTR_NAME_BATCH_NUM failed, StreamSwitchN:%s.", name.c_str());
    return nullptr;
  }
  if (!AttrUtils::SetInt(op_desc, ATTR_DYNAMIC_TYPE, dynamic_type_)) {
    GELOGE(FAILED, "Set attr ATTR_DYNAMIC_TYPE failed, StreamSwitchN:%s.", name.c_str());
    return nullptr;
  }
  if (!AttrUtils::SetListStr(op_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, data_name_order_)) {
    GELOGE(FAILED, "Set attr ATTR_USER_DESIGNEATE_SHAPE_ORDER failed, StreamSwitchN:%s.", name.c_str());
    return nullptr;
  }
  for (uint32_t i = 0; i < batch_num; i++) {
    const std::string &attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i);
    if (!AttrUtils::SetListInt(op_desc, attr_name, batch_shape[i])) {
      GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE failed, StreamSwitchN:%s.", name.c_str());
      return nullptr;
    }
    const std::string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i);
    if (!AttrUtils::SetListInt(op_desc, attr_combined_batch, combined_batch[i])) {
      GELOGE(FAILED, "set attr ATTR_NAME_COMBINED_BATCH failed, StreamSwitchN:%s.", name.c_str());
      return nullptr;
    }
  }

  return switch_case_node;
}
///
/// @brief Bypass SwitchN node
/// @param [in] switch_n_node
/// @param [in] switch_case
/// @return Status
///
Status MultiBatchPass::BypassSwitchN(const NodePtr &switch_n_node, const NodePtr &switch_case) {
  InDataAnchorPtr in_data_anchor = switch_n_node->GetInDataAnchor(SWITCH_DATA_INPUT);
  if (in_data_anchor == nullptr) {
    GELOGE(FAILED, "Check in_data_anchor failed, SwitchN:%s.", switch_n_node->GetName().c_str());
    return FAILED;
  }
  OutDataAnchorPtr peer_data_anchor = in_data_anchor->GetPeerOutAnchor();
  if (peer_data_anchor == nullptr) {
    GELOGE(FAILED, "Check peer_data_anchor failed, SwitchN:%s.", switch_n_node->GetName().c_str());
    return FAILED;
  }
  NodePtr data_input = peer_data_anchor->GetOwnerNode();

  // Remove SwitchN data input
  if (GraphUtils::RemoveEdge(peer_data_anchor, in_data_anchor) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Remove SwitchN in_data_edge failed, %s->%s.", data_input->GetName().c_str(),
           switch_n_node->GetName().c_str());
    return FAILED;
  }
  if (GraphUtils::AddEdge(data_input->GetOutControlAnchor(), switch_case->GetInControlAnchor()) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Add StreamSwitchN in_control_edge failed, %s->%s.", data_input->GetName().c_str(),
           switch_case->GetName().c_str());
    return FAILED;
  }

  // Add SwitchCase control output
  for (const OutDataAnchorPtr &out_data_anchor : switch_n_node->GetAllOutDataAnchors()) {
    for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
      NodePtr data_output = peer_in_anchor->GetOwnerNode();
      if ((GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor) != GRAPH_SUCCESS) ||
          (GraphUtils::AddEdge(peer_data_anchor, peer_in_anchor) != GRAPH_SUCCESS)) {
        GELOGE(FAILED, "Bypass SwitchN data_edge failed, %s->%s->%s.", data_input->GetName().c_str(),
               switch_n_node->GetName().c_str(), data_output->GetName().c_str());
        return FAILED;
      }
      if (GraphUtils::AddEdge(switch_case->GetOutControlAnchor(), data_output->GetInControlAnchor()) != GRAPH_SUCCESS) {
        GELOGE(FAILED, "Add SwitchCase out_control_edge failed, %s->%s.", switch_case->GetName().c_str(),
               data_output->GetName().c_str());
        return FAILED;
      }
    }
  }

  GE_CHK_STATUS_RET(MoveCtrlEdges(switch_n_node, switch_case), "Move ctrl edges failed.");

  bypass_nodes_.emplace_back(switch_n_node);
  GELOGI("Bypass SwitchN node %s success.", switch_n_node->GetName().c_str());
  return SUCCESS;
}
///
/// @brief Attach stream_label & batch_label for batch branch
/// @param [in] switch_case_node
/// @return Status
///
Status MultiBatchPass::AttachLabel(const NodePtr &switch_case_node) {
  std::vector<std::string> stream_label_list;
  for (uint32_t i = 0; i < static_cast<uint32_t>(batch_head_nodes_.size()); i++) {
    if (AttachBatchLabel(i) != SUCCESS) {
      GELOGE(FAILED, "AttachBatchLabel failed, batch_idx=%u", i);
      return FAILED;
    }

    const std::string &stream_label = "stream_label_batch_" + std::to_string(i);
    if (AttachStreamLabel(i, stream_label) != SUCCESS) {
      GELOGE(FAILED, "AttachStreamLabel failed, stream_label=%s", stream_label.c_str());
      return FAILED;
    }
    stream_label_list.emplace_back(stream_label);
  }

  return switch_case_node == nullptr ? SUCCESS : SetActiveLabelList(switch_case_node, stream_label_list);
}

///
/// @brief Attach batch_label for batch branch
/// @param [in] batch_idx
/// @return Status
///
Status MultiBatchPass::AttachBatchLabel(uint32_t batch_idx) {
  std::stack<NodePtr> nodes;
  for (const auto &node : batch_head_nodes_[batch_idx]) {
    nodes.push(node);
  }

  const std::string &batch_label = "Batch_" + std::to_string(batch_idx);
  std::unordered_set<NodePtr> handled_nodes;
  while (!nodes.empty()) {
    NodePtr cur_node = nodes.top();
    nodes.pop();
    if (handled_nodes.count(cur_node) > 0) {
      continue;
    }

    OpDescPtr cur_desc = cur_node->GetOpDesc();
    GE_CHECK_NOTNULL(cur_desc);
    if (cur_desc->HasAttr(ATTR_NAME_BATCH_LABEL)) {
      std::string tmp_label;
      if (!AttrUtils::GetStr(cur_desc, ATTR_NAME_BATCH_LABEL, tmp_label)) {
        GELOGE(FAILED, "get attr ATTR_NAME_BATCH_LABEL failed, node: %s.", cur_desc->GetName().c_str());
        return FAILED;
      }
      if (tmp_label != batch_label) {
        GELOGE(FAILED, "Reach other batch_branch, node:%s, cur_label:%s, batch_label:%s.", cur_desc->GetName().c_str(),
               tmp_label.c_str(), batch_label.c_str());
        return FAILED;
      }
    }

    GELOGD("Attach batch_label %s to node %s.", batch_label.c_str(), cur_desc->GetName().c_str());
    if (!AttrUtils::SetStr(cur_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
      GELOGE(FAILED, "set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", cur_desc->GetName().c_str());
      return FAILED;
    }

    for (const auto &out_node : cur_node->GetOutAllNodes()) {
      OpDescPtr op_desc = out_node->GetOpDesc();
      GE_CHECK_NOTNULL(op_desc);
      const std::string &type = op_desc->GetType();
      if ((type == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) {
        continue;
      }
      if (type == NETOUTPUT) {
        GELOGE(FAILED, "Reach net_output without Merge, cur_node:%s.", cur_node->GetName().c_str());
        return FAILED;
      }
      nodes.push(out_node);
    }
    (void)handled_nodes.insert(cur_node);
  }

  return SUCCESS;
}
///
/// @brief Attach stream_label for batch branch
/// @param [in] batch_idx
/// @param [in] stream_label
/// @return Status
///
Status MultiBatchPass::AttachStreamLabel(uint32_t batch_idx, const std::string &stream_label) {
  std::stack<NodePtr> nodes;
  for (const auto &node : batch_head_nodes_[batch_idx]) {
    nodes.push(node);
  }

  std::unordered_set<NodePtr> handled_nodes;
  while (!nodes.empty()) {
    NodePtr cur_node = nodes.top();
    nodes.pop();
    OpDescPtr cur_desc = cur_node->GetOpDesc();
    GE_CHECK_NOTNULL(cur_desc);
    if ((handled_nodes.count(cur_node) > 0) || (cur_desc->HasAttr(ATTR_NAME_STREAM_LABEL))) {
      continue;
    }

    GELOGD("Attach stream_label %s to node %s.", stream_label.c_str(), cur_desc->GetName().c_str());
    if (SetStreamLabel(cur_node, stream_label) != SUCCESS) {
      GELOGE(FAILED, "Set stream_label failed, node:%s.", cur_node->GetName().c_str());
      return FAILED;
    }

    for (const auto &out_node : cur_node->GetOutAllNodes()) {
      nodes.push(out_node);
    }
    (void)handled_nodes.insert(cur_node);
  }

  return SUCCESS;
}
///
/// @brief Move control edges from old_node to new_node
/// @param [in] old_node
/// @param [in] new_node
/// @return Status
///
Status MultiBatchPass::MoveCtrlEdges(const NodePtr &old_node, const NodePtr &new_node) {
  if (old_node == new_node) {
    return SUCCESS;
  }

  for (const NodePtr &in_ctrl_node : old_node->GetInControlNodes()) {
    GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), old_node->GetInControlAnchor()),
                  "Merge remove in ctrl edge failed.");
    GE_CHK_STATUS(GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), new_node->GetInControlAnchor()),
                  "StreamMerge add in ctrl edge failed.");
  }

  for (const NodePtr &out_ctrl_node : old_node->GetOutControlNodes()) {
    GE_CHK_STATUS(GraphUtils::RemoveEdge(old_node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()),
                  "Merge remove out ctrl edge failed.");
    GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()),
                  "StreamMerge add out ctrl edge failed.");
  }

  return SUCCESS;
}
///
/// @brief Attach stream_label & batch_label without changing the structure of the graph
/// @param [in] batch_num
/// @return Status
///
Status MultiBatchPass::AttachLabelOnly(uint32_t batch_num) {
  std::vector<NodePtr> output_nodes;
  for (uint32_t i = 0; i < batch_num; i++) {
    output_nodes.clear();
    for (const NodePtr &node : switch_n_nodes_) {
      // idx is promised to be valid
      OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(i);
      GE_CHECK_NOTNULL(out_data_anchor);
      for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
        output_nodes.emplace_back(peer_in_anchor->GetOwnerNode());
      }
    }
    batch_head_nodes_.emplace_back(output_nodes);
  }

  return AttachLabel(nullptr);
}
}  // namespace ge

The Graph Engine (GE) is a submodule of MindSpore, implemented in C++. It sits between the front-end module ME and the underlying hardware, acting as the bridge between them. GE takes the graph delivered by ME as input, applies a series of deep graph optimizations, and finally outputs a graph that can run efficiently on the underlying hardware. GE performs optimizations tailored to the hardware architecture of the Ascend AI processor in order to make full use of its compute power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core.
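As a rough illustration of this "graph in, optimized graph out" flow, the sketch below chains graph passes of the kind defined in multi_batch_pass.cc over a ComputeGraph. The RunOptimizationPasses function and the bare std::vector pass list are hypothetical stand-ins for GE's real pass manager and registration mechanism; only MultiBatchPass and its Run(ComputeGraphPtr) interface are taken from the code above.

#include <memory>
#include <vector>

#include "graph/passes/multi_batch_pass.h"  // brings in MultiBatchPass, Status, ComputeGraphPtr

namespace ge {
// Hypothetical sketch: run a list of graph passes over the graph handed down by the frontend.
// Each pass rewrites the graph in place and reports a Status; GE's actual pass manager differs.
Status RunOptimizationPasses(const ComputeGraphPtr &graph) {
  std::vector<std::shared_ptr<MultiBatchPass>> passes;
  passes.emplace_back(std::make_shared<MultiBatchPass>());
  for (const auto &pass : passes) {
    const Status ret = pass->Run(graph);  // e.g. fold per-batch SwitchN nodes into one StreamSwitchN
    if (ret != SUCCESS) {
      return ret;
    }
  }
  return SUCCESS;
}
}  // namespace ge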