
multi_batch_pass.cc 26 kB

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/passes/multi_batch_pass.h"

#include <stack>
#include <unordered_set>

#include "common/ge/ge_util.h"
#include "graph/common/omg_util.h"
#include "graph/utils/type_utils.h"

using std::string;
using std::vector;

namespace ge {
Status MultiBatchPass::Run(ComputeGraphPtr graph) {
  GELOGD("MultiBatchPass Enter");

  if (graph->GetParentGraph() != nullptr) {
    GELOGI("Subgraph %s skip the MultiBatchPass.", graph->GetName().c_str());
    return SUCCESS;
  }

  OutDataAnchorPtr pred_value = nullptr;
  Status ret = FindPredValue(graph, pred_value);
  if (ret == NOT_CHANGED) {
    GELOGD("SwitchN node not exist, graph not changed.");
    return SUCCESS;
  }
  if (ret != SUCCESS) {
    GELOGE(FAILED, "FindPredValue failed.");
    return FAILED;
  }

  if (GetDynamicType() != SUCCESS) {
    GELOGE(FAILED, "Get dynamic type failed.");
    return FAILED;
  }
  if (GetUserDesignateShape() != SUCCESS) {
    GELOGE(FAILED, "Get user designate shape failed.");
    return FAILED;
  }

  std::vector<std::vector<int64_t>> batch_shape;
  vector<vector<int64_t>> combined_batch;
  if (!CheckSwitchN(batch_shape, combined_batch)) {
    GELOGE(FAILED, "CheckSwitchN failed.");
    return FAILED;
  }

  if (attach_label_only_) {
    return AttachLabelOnly(batch_shape.size());
  }

  if (FindSwitchOutNodes(batch_shape.size()) != SUCCESS) {
    GELOGE(FAILED, "Find SwitchN out nodes failed.");
    return FAILED;
  }

  if (ReplaceSwitchN(graph, pred_value, batch_shape, combined_batch) != SUCCESS) {
    GELOGE(FAILED, "Replace SwitchN nodes failed.");
    return FAILED;
  }

  for (const NodePtr &node : bypass_nodes_) {
    if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) {
      GELOGE(FAILED, "Remove SwitchN nodes %s failed.", node->GetName().c_str());
      return FAILED;
    }
  }

  GELOGD("MultiBatchPass Leave");
  return SUCCESS;
}
///
/// @brief Clear Status
/// @return
///
Status MultiBatchPass::ClearStatus() {
  switch_n_nodes_.clear();
  bypass_nodes_.clear();
  batch_head_nodes_.clear();
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Set batch label for Case mode.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @param [in] const NodePtr &case_node: Case Node.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr &case_node) {
  const auto &func_desc = case_node->GetOpDesc();
  if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) {
    GELOGD("Graph: %s Not multi-batch, Node: %s", graph->GetName().c_str(), case_node->GetName().c_str());
    return SUCCESS;
  }

  const auto &dynamic_branch_names = func_desc->GetSubgraphInstanceNames();
  for (size_t i = 0; i < dynamic_branch_names.size(); ++i) {
    const auto &subgraph = graph->GetSubgraph(dynamic_branch_names[i]);
    GE_CHECK_NOTNULL(subgraph);

    const string batch_label = "Batch_" + std::to_string(i);
    for (const auto &node : subgraph->GetDirectNode()) {
      (void)AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label);
    }
  }

  return SUCCESS;
}
///
/// @brief Find the common pred input of all SwitchN nodes
/// @param [in] graph
/// @param [out] pred_value
/// @return Status
///
Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchorPtr &pred_value) {
  for (const NodePtr &node : graph->GetDirectNode()) {
    if (node->GetType() == CASE) {
      GE_CHK_STATUS_RET(SetCaseLabel(graph, node), "Set batch label failed");
      continue;
    }
    if (node->GetType() != SWITCHN) {
      continue;
    }

    InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
    if (in_data_anchor == nullptr) {
      GELOGE(FAILED, "FindPredInput failed, in_data_anchor is null, node:%s.", node->GetName().c_str());
      return FAILED;
    }
    OutDataAnchorPtr pred_input = in_data_anchor->GetPeerOutAnchor();
    if (pred_input == nullptr) {
      GELOGE(FAILED, "FindPredInput failed, pred_input is null, node:%s.", node->GetName().c_str());
      return FAILED;
    }

    if (pred_value == nullptr) {
      pred_value = pred_input;
    } else if (pred_value != pred_input) {
      GELOGE(FAILED, "Multi pred_value node exist.");
      return FAILED;
    }
    switch_n_nodes_.emplace_back(node);
  }

  if (switch_n_nodes_.empty()) {
    GELOGD("SwitchN node not exist.");
    return NOT_CHANGED;
  }
  if (pred_value == nullptr) {
    GELOGE(FAILED, "FindPredInput failed, pred_value is null.");
    return FAILED;
  }

  GELOGI("Find pred_value %s.", pred_value->GetOwnerNode()->GetName().c_str());
  return SUCCESS;
}
///
/// @brief Get dynamic type: dynamic batch size: 1, dynamic image size: 2, dynamic dims: 3
/// @return Status
///
Status MultiBatchPass::GetDynamicType() {
  for (const auto &switchn : switch_n_nodes_) {
    auto switchn_desc = switchn->GetOpDesc();
    GE_CHECK_NOTNULL(switchn_desc);
    int32_t dynamic_type = static_cast<int32_t>(FIXED);
    if (!AttrUtils::GetInt(switchn_desc, ATTR_DYNAMIC_TYPE, dynamic_type)) {
      GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switchn->GetName().c_str());
      return FAILED;
    }
    if (dynamic_type == static_cast<int32_t>(FIXED)) {
      GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE shouldn't be 0.");
      return FAILED;
    }
    if (dynamic_type_ != static_cast<int32_t>(FIXED) && dynamic_type_ != dynamic_type) {
      GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switchn node should be same, while one is %d and another is %d.",
             dynamic_type, dynamic_type_);
      return FAILED;
    }
    dynamic_type_ = dynamic_type;
  }
  if (dynamic_type_ == static_cast<int32_t>(FIXED)) {
    GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE shouldn't be 0.");
    return FAILED;
  }

  return SUCCESS;
}
///
/// @brief Get user designate shape order, e.g. {"data", "label", "mask"}
/// @return Status
///
Status MultiBatchPass::GetUserDesignateShape() {
  data_name_order_.clear();
  bool first_check = true;
  for (const auto &switchn : switch_n_nodes_) {
    auto switchn_desc = switchn->GetOpDesc();
    GE_CHECK_NOTNULL(switchn_desc);
    vector<string> cur_switchn_data_name_order;
    if (!AttrUtils::GetListStr(switchn_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_switchn_data_name_order)) {
      GELOGE(FAILED, "Get attr ATTR_USER_DESIGNEATE_SHAPE_ORDER of node: %s failed.", switchn->GetName().c_str());
      return FAILED;
    }
    if (first_check) {
      data_name_order_ = cur_switchn_data_name_order;
      first_check = false;
    } else {
      if (data_name_order_ != cur_switchn_data_name_order) {
        GELOGE(FAILED, "The ATTR_USER_DESIGNEATE_SHAPE_ORDER of all SwitchN nodes must be the same, node: %s.",
               switchn->GetName().c_str());
        return FAILED;
      }
    }
  }
  if (data_name_order_.empty()) {
    GELOGE(FAILED, "User designate shape order can not be empty.");
    return FAILED;
  }

  return SUCCESS;
}
///
/// @brief Check SwitchN nodes
/// @param [out] batch_shape
/// @param [out] combined_batch
/// @return bool
///
bool MultiBatchPass::CheckSwitchN(vector<vector<int64_t>> &batch_shape, vector<vector<int64_t>> &combined_batch) {
  // Check if output_num of different SwitchN is same
  uint32_t batch_num = 0;
  for (const NodePtr &node : switch_n_nodes_) {
    uint32_t tmp_num = node->GetAllOutDataAnchorsSize();
    if (batch_num == 0) {
      batch_num = tmp_num;
    } else if (batch_num != tmp_num) {
      GELOGE(FAILED, "Output size of SwitchN not equal.");
      return false;
    }
  }

  if (!GetBatchInfo(batch_num, batch_shape, combined_batch)) {
    GELOGE(FAILED, "Get batch info failed.");
    return false;
  }
  if (batch_shape.empty()) {
    GELOGE(FAILED, "batch_shape is empty.");
    return false;
  }
  if (combined_batch.empty()) {
    GELOGE(FAILED, "combined_batch is empty.");
    return false;
  }

  size_t dim_num = batch_shape[0].size();
  size_t combined_dim_num = combined_batch[0].size();
  for (uint32_t i = 1; i < batch_num; i++) {
    size_t tmp_dim_num = batch_shape[i].size();
    if (dim_num != tmp_dim_num) {
      GELOGE(FAILED, "Dim num of batch_shape not equal, batch_0:%zu, batch_%u:%zu.", dim_num, i, tmp_dim_num);
      return false;
    }
    size_t tmp_combined_dim_num = combined_batch[i].size();
    if (combined_dim_num != tmp_combined_dim_num) {
      GELOGE(FAILED, "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu.", combined_dim_num, i,
             tmp_combined_dim_num);
      return false;
    }
  }

  return true;
}
///
/// @brief Get batch info of SwitchN nodes
/// @param [in] batch_num
/// @param [out] batch_shape
/// @param [out] combined_batch
/// @return bool
///
bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, vector<vector<int64_t>> &batch_shape,
                                  vector<vector<int64_t>> &combined_batch) {
  // Check if output_shape of different SwitchN is same
  vector<vector<int64_t>> idx_batch_shape;
  vector<vector<int64_t>> idx_combined_batch;
  for (uint32_t i = 0; i < batch_num; i++) {
    idx_batch_shape.clear();
    idx_combined_batch.clear();
    for (const NodePtr &node : switch_n_nodes_) {
      OpDescPtr op_desc = node->GetOpDesc();
      if (op_desc == nullptr) {
        GELOGE(FAILED, "CheckDims failed, get op_desc failed, node: %s.", node->GetName().c_str());
        return false;
      }
      vector<int64_t> output_dims;
      if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_SWITCHN_PRED_VALUE, output_dims)) {
        GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_SWITCHN_PRED_VALUE failed, batch_index=%u.", i);
        return false;
      }
      idx_batch_shape.emplace_back(output_dims);
      output_dims.clear();
      if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_COMBINED_DYNAMIC_DIMS, output_dims)) {
        GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_COMBINED_DYNAMIC_DIMS failed, batch_index=%u.", i);
        return false;
      }
      idx_combined_batch.emplace_back(output_dims);
    }
    if (!CheckDims(idx_batch_shape)) {
      GELOGE(FAILED, "CheckDims failed, batch_index=%u.", i);
      return false;
    }

    batch_shape.emplace_back(idx_batch_shape[0]);
    combined_batch.emplace_back(idx_combined_batch[0]);
  }

  return true;
}
///
/// @brief Find outputs of SwitchN nodes
/// @param [in] batch_num
/// @return Status
///
Status MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) {
  std::vector<NodePtr> output_nodes;
  for (uint32_t i = 0; i < batch_num; i++) {
    output_nodes.clear();
    for (const NodePtr &node : switch_n_nodes_) {
      // idx is promised to be valid
      OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(i);
      GE_CHECK_NOTNULL(out_data_anchor);
      for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
        auto out_node = peer_in_anchor->GetOwnerNode();
        if (out_node->GetType() != IDENTITY || !out_node->GetOutDataNodes().empty()) {
          output_nodes.emplace_back(out_node);
          continue;
        }
        bypass_nodes_.emplace_back(out_node);
        if (GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor) != GRAPH_SUCCESS) {
          GELOGE(FAILED, "Remove SwitchN out_data_edge failed, %s->%s.", node->GetName().c_str(),
                 out_node->GetName().c_str());
          return FAILED;
        }
        for (auto &identity_out_node : out_node->GetOutControlNodes()) {
          output_nodes.emplace_back(identity_out_node);
          if (GraphUtils::RemoveEdge(out_node->GetOutControlAnchor(), identity_out_node->GetInControlAnchor()) !=
              GRAPH_SUCCESS) {
            GELOGE(FAILED, "Remove Identity out_control_edge failed, %s->%s.", out_node->GetName().c_str(),
                   identity_out_node->GetName().c_str());
            return FAILED;
          }
        }
      }
    }
    batch_head_nodes_.emplace_back(output_nodes);
  }

  return SUCCESS;
}
///
/// @brief Replace & Combine SwitchN nodes
/// @param [in] graph
/// @param [in] pred_value
/// @param [in] batch_shape
/// @param [in] combined_batch
/// @return Status
///
Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value,
                                      const vector<vector<int64_t>> &batch_shape,
                                      const vector<vector<int64_t>> &combined_batch) {
  NodePtr pred_value_node = pred_value->GetOwnerNode();
  // Create SwitchCase node
  const std::string &switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN;
  NodePtr switch_case = CreateSwitchCaseNode(graph, switch_case_name, pred_value, batch_shape, combined_batch);
  if (switch_case == nullptr) {
    GELOGE(FAILED, "CreateSwitchCaseNode %s failed.", switch_case_name.c_str());
    return FAILED;
  }

  for (const NodePtr &switch_n_node : switch_n_nodes_) {
    if (BypassSwitchN(switch_n_node, switch_case) != SUCCESS) {
      GELOGE(FAILED, "Bypass SwitchN %s failed.", switch_n_node->GetName().c_str());
      return FAILED;
    }
  }

  // Add SwitchCase input edge
  if (GraphUtils::AddEdge(pred_value, switch_case->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Add SwitchCase in_data_edge failed, %s->%s.", pred_value_node->GetName().c_str(),
           switch_case->GetName().c_str());
    return FAILED;
  }

  if (AttachLabel(switch_case) != SUCCESS) {
    GELOGE(FAILED, "AttachLabel failed.");
    return FAILED;
  }

  return SUCCESS;
}
///
/// @brief Check if output_shape of different SwitchN is same
/// @param [in] output_shape
/// @return bool
///
bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_shape) const {
  if (output_shape.empty()) {
    GELOGE(FAILED, "CheckDims failed: output_shape is empty.");
    return false;
  }

  size_t num = output_shape.size();
  size_t dim_num = output_shape[0].size();
  for (size_t i = 1; i < num; i++) {
    size_t tmp_dim_num = output_shape[i].size();
    if (dim_num != tmp_dim_num) {
      GELOGE(FAILED, "CheckDims failed: dim_num not equal, output_0:%zu, output_%zu:%zu.", dim_num, i, tmp_dim_num);
      return false;
    }
  }

  if (dim_num == 0) {
    return true;
  }

  for (size_t i = 0; i < dim_num; i++) {
    int64_t dim_value = output_shape[0][i];
    for (size_t j = 1; j < num; j++) {
      int64_t tmp_dim_value = output_shape[j][i];
      if (dim_value != tmp_dim_value) {
        GELOGE(FAILED, "CheckDims failed: dim_value not equal, dim_index=%zu, dim_value_0:%ld, dim_value_%zu:%ld.", i,
               dim_value, j, tmp_dim_value);
        return false;
      }
    }
  }

  return true;
}
///
/// @brief Create StreamSwitchN node
/// @param [in] graph
/// @param [in] name
/// @param [in] pred_value
/// @param [in] batch_shape
/// @param [in] combined_batch
/// @return ge::NodePtr
///
NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name,
                                             const OutDataAnchorPtr &pred_value,
                                             const vector<vector<int64_t>> &batch_shape,
                                             const vector<vector<int64_t>> &combined_batch) {
  OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMSWITCHN);
  if (op_desc == nullptr) {
    GELOGE(FAILED, "Create op_desc failed, StreamSwitchN:%s.", name.c_str());
    return nullptr;
  }

  GELOGI("Create StreamSwitchN op:%s.", name.c_str());
  OpDescPtr pred_desc = pred_value->GetOwnerNode()->GetOpDesc();
  if (pred_desc == nullptr) {
    GELOGE(FAILED, "Get pred_desc failed, StreamSwitchN:%s.", name.c_str());
    return nullptr;
  }
  if (op_desc->AddInputDesc(pred_desc->GetOutputDesc(pred_value->GetIdx())) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "AddInputDesc failed, StreamSwitchN:%s.", name.c_str());
    return nullptr;
  }

  NodePtr switch_case_node = graph->AddNode(op_desc);
  if (switch_case_node == nullptr) {
    GELOGE(FAILED, "Create node failed, StreamSwitchN:%s.", name.c_str());
    return nullptr;
  }

  uint32_t batch_num = static_cast<uint32_t>(batch_shape.size());
  if (!AttrUtils::SetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) {
    GELOGE(FAILED, "Set attr ATTR_NAME_BATCH_NUM failed, StreamSwitchN:%s.", name.c_str());
    return nullptr;
  }
  if (!AttrUtils::SetInt(op_desc, ATTR_DYNAMIC_TYPE, dynamic_type_)) {
    GELOGE(FAILED, "Set attr ATTR_DYNAMIC_TYPE failed, StreamSwitchN:%s.", name.c_str());
    return nullptr;
  }
  if (!AttrUtils::SetListStr(op_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, data_name_order_)) {
    GELOGE(FAILED, "Set attr ATTR_USER_DESIGNEATE_SHAPE_ORDER failed, StreamSwitchN:%s.", name.c_str());
    return nullptr;
  }
  for (uint32_t i = 0; i < batch_num; i++) {
    const std::string &attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i);
    if (!AttrUtils::SetListInt(op_desc, attr_name, batch_shape[i])) {
      GELOGE(FAILED, "Set attr ATTR_NAME_PRED_VALUE failed, StreamSwitchN:%s.", name.c_str());
      return nullptr;
    }
    const string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i);
    if (!AttrUtils::SetListInt(op_desc, attr_combined_batch, combined_batch[i])) {
      GELOGE(FAILED, "Set attr ATTR_NAME_COMBINED_BATCH failed, StreamSwitchN:%s.", name.c_str());
      return nullptr;
    }
  }

  return switch_case_node;
}
///
/// @brief Bypass SwitchN node
/// @param [in] switch_n_node
/// @param [in] switch_case
/// @return Status
///
Status MultiBatchPass::BypassSwitchN(const NodePtr &switch_n_node, const NodePtr &switch_case) {
  InDataAnchorPtr in_data_anchor = switch_n_node->GetInDataAnchor(SWITCH_DATA_INPUT);
  if (in_data_anchor == nullptr) {
    GELOGE(FAILED, "Check in_data_anchor failed, SwitchN:%s.", switch_n_node->GetName().c_str());
    return FAILED;
  }
  OutDataAnchorPtr peer_data_anchor = in_data_anchor->GetPeerOutAnchor();
  if (peer_data_anchor == nullptr) {
    GELOGE(FAILED, "Check peer_data_anchor failed, SwitchN:%s.", switch_n_node->GetName().c_str());
    return FAILED;
  }
  NodePtr data_input = peer_data_anchor->GetOwnerNode();

  // Remove SwitchN data input
  if (GraphUtils::RemoveEdge(peer_data_anchor, in_data_anchor) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Remove SwitchN in_data_edge failed, %s->%s.", data_input->GetName().c_str(),
           switch_n_node->GetName().c_str());
    return FAILED;
  }
  if (GraphUtils::AddEdge(data_input->GetOutControlAnchor(), switch_case->GetInControlAnchor()) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Add StreamSwitchN in_control_edge failed, %s->%s.", data_input->GetName().c_str(),
           switch_case->GetName().c_str());
    return FAILED;
  }

  // Add SwitchCase control output
  for (const OutDataAnchorPtr &out_data_anchor : switch_n_node->GetAllOutDataAnchors()) {
    for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
      NodePtr data_output = peer_in_anchor->GetOwnerNode();
      if ((GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor) != GRAPH_SUCCESS) ||
          (GraphUtils::AddEdge(peer_data_anchor, peer_in_anchor) != GRAPH_SUCCESS)) {
        GELOGE(FAILED, "Bypass SwitchN data_edge failed, %s->%s->%s.", data_input->GetName().c_str(),
               switch_n_node->GetName().c_str(), data_output->GetName().c_str());
        return FAILED;
      }
      if (GraphUtils::AddEdge(switch_case->GetOutControlAnchor(), data_output->GetInControlAnchor()) != GRAPH_SUCCESS) {
        GELOGE(FAILED, "Add SwitchCase out_control_edge failed, %s->%s.", switch_case->GetName().c_str(),
               data_output->GetName().c_str());
        return FAILED;
      }
    }
  }

  GE_CHK_STATUS_RET(MoveCtrlEdges(switch_n_node, switch_case), "Move ctrl edges failed.");

  bypass_nodes_.emplace_back(switch_n_node);
  GELOGI("Bypass SwitchN node %s success.", switch_n_node->GetName().c_str());
  return SUCCESS;
}
///
/// @brief Attach stream_label & batch_label for batch branch
/// @param [in] switch_case_node
/// @return Status
///
Status MultiBatchPass::AttachLabel(const NodePtr &switch_case_node) {
  std::vector<std::string> stream_label_list;
  for (uint32_t i = 0; i < static_cast<uint32_t>(batch_head_nodes_.size()); i++) {
    if (AttachBatchLabel(i) != SUCCESS) {
      GELOGE(FAILED, "AttachBatchLabel failed, batch_idx=%u", i);
      return FAILED;
    }

    const std::string &stream_label = "stream_label_batch_" + std::to_string(i);
    if (AttachStreamLabel(i, stream_label) != SUCCESS) {
      GELOGE(FAILED, "AttachStreamLabel failed, stream_label=%s", stream_label.c_str());
      return FAILED;
    }
    stream_label_list.emplace_back(stream_label);
  }

  return switch_case_node == nullptr ? SUCCESS : SetActiveLabelList(switch_case_node, stream_label_list);
}

///
/// @brief Attach batch_label for batch branch
/// @param [in] batch_idx
/// @return Status
///
Status MultiBatchPass::AttachBatchLabel(uint32_t batch_idx) {
  std::stack<NodePtr> nodes;
  for (const auto &node : batch_head_nodes_[batch_idx]) {
    nodes.push(node);
  }

  const std::string &batch_label = "Batch_" + std::to_string(batch_idx);
  std::unordered_set<NodePtr> handled_nodes;
  while (!nodes.empty()) {
    NodePtr cur_node = nodes.top();
    nodes.pop();
    if (handled_nodes.count(cur_node) > 0) {
      continue;
    }

    OpDescPtr cur_desc = cur_node->GetOpDesc();
    GE_CHECK_NOTNULL(cur_desc);
    if (cur_desc->HasAttr(ATTR_NAME_BATCH_LABEL)) {
      std::string tmp_label;
      if (!AttrUtils::GetStr(cur_desc, ATTR_NAME_BATCH_LABEL, tmp_label)) {
        GELOGE(FAILED, "Get attr ATTR_NAME_BATCH_LABEL failed, node: %s.", cur_desc->GetName().c_str());
        return FAILED;
      }
      if (tmp_label != batch_label) {
        GELOGE(FAILED, "Reach other batch_branch, node:%s, cur_label:%s, batch_label:%s.", cur_desc->GetName().c_str(),
               tmp_label.c_str(), batch_label.c_str());
        return FAILED;
      }
    }

    GELOGD("Attach batch_label %s to node %s.", batch_label.c_str(), cur_desc->GetName().c_str());
    if (!AttrUtils::SetStr(cur_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
      GELOGE(FAILED, "Set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", cur_desc->GetName().c_str());
      return FAILED;
    }

    for (const auto &out_node : cur_node->GetOutAllNodes()) {
      OpDescPtr op_desc = out_node->GetOpDesc();
      GE_CHECK_NOTNULL(op_desc);
      const std::string &type = op_desc->GetType();
      if ((type == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) {
        continue;
      }
      if (type == NETOUTPUT) {
        GELOGE(FAILED, "Reach net_output without Merge, cur_node:%s.", cur_node->GetName().c_str());
        return FAILED;
      }
      nodes.push(out_node);
    }
    (void)handled_nodes.insert(cur_node);
  }

  return SUCCESS;
}
///
/// @brief Attach stream_label for batch branch
/// @param [in] batch_idx
/// @param [in] stream_label
/// @return Status
///
Status MultiBatchPass::AttachStreamLabel(uint32_t batch_idx, const std::string &stream_label) {
  std::stack<NodePtr> nodes;
  for (const auto &node : batch_head_nodes_[batch_idx]) {
    nodes.push(node);
  }

  std::unordered_set<NodePtr> handled_nodes;
  while (!nodes.empty()) {
    NodePtr cur_node = nodes.top();
    nodes.pop();
    OpDescPtr cur_desc = cur_node->GetOpDesc();
    GE_CHECK_NOTNULL(cur_desc);
    if ((handled_nodes.count(cur_node) > 0) || (cur_desc->HasAttr(ATTR_NAME_STREAM_LABEL))) {
      continue;
    }

    GELOGD("Attach stream_label %s to node %s.", stream_label.c_str(), cur_desc->GetName().c_str());
    if (SetStreamLabel(cur_node, stream_label) != SUCCESS) {
      GELOGE(FAILED, "Set stream_label failed, node:%s.", cur_node->GetName().c_str());
      return FAILED;
    }

    for (const auto &out_node : cur_node->GetOutAllNodes()) {
      nodes.push(out_node);
    }
    (void)handled_nodes.insert(cur_node);
  }

  return SUCCESS;
}
///
/// @brief Move edges from old_node to new_node
/// @param [in] old_node
/// @param [in] new_node
/// @return Status
///
Status MultiBatchPass::MoveCtrlEdges(const NodePtr &old_node, const NodePtr &new_node) {
  if (old_node == new_node) {
    return SUCCESS;
  }

  for (const NodePtr &in_ctrl_node : old_node->GetInControlNodes()) {
    GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), old_node->GetInControlAnchor()),
                  "Merge remove in ctrl edge failed.");
    GE_CHK_STATUS(GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), new_node->GetInControlAnchor()),
                  "StreamMerge add in ctrl edge failed.");
  }

  for (const NodePtr &out_ctrl_node : old_node->GetOutControlNodes()) {
    GE_CHK_STATUS(GraphUtils::RemoveEdge(old_node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()),
                  "Merge remove out ctrl edge failed.");
    GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()),
                  "StreamMerge add out ctrl edge failed.");
  }

  return SUCCESS;
}
///
/// @brief Attach stream_label & batch_label without changing the structure of the graph
/// @param [in] batch_num
/// @return Status
///
Status MultiBatchPass::AttachLabelOnly(uint32_t batch_num) {
  std::vector<NodePtr> output_nodes;
  for (uint32_t i = 0; i < batch_num; i++) {
    output_nodes.clear();
    for (const NodePtr &node : switch_n_nodes_) {
      // idx is promised to be valid
      OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(i);
      GE_CHECK_NOTNULL(out_data_anchor);
      for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
        output_nodes.emplace_back(peer_in_anchor->GetOwnerNode());
      }
    }
    batch_head_nodes_.emplace_back(output_nodes);
  }

  return AttachLabel(nullptr);
}
}  // namespace ge

The Graph Engine (GE) module is a submodule of MindSpore implemented in C++. It sits between the front-end module ME and the underlying hardware, acting as the bridge between them. GE takes the graph delivered by ME as input, applies a series of deep graph optimizations, and finally outputs a graph that can run efficiently on the underlying hardware. GE performs optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training/inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture diagram is shown below.
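As a rough illustration of how a GE graph-optimization pass such as the MultiBatchPass above is driven, the following minimal sketch applies the pass directly to a ComputeGraph. This is a simplified assumption for illustration only: inside GE Core, passes are normally scheduled by a pass manager rather than called by hand, the helper name RunMultiBatchPassOnce is hypothetical, and the "graph/compute_graph.h" include path and public visibility of ClearStatus() are assumptions.

#include "graph/compute_graph.h"              // assumed header for ComputeGraphPtr
#include "graph/passes/multi_batch_pass.h"

namespace ge {
// Hypothetical driver showing the call sequence a pass manager would perform.
Status RunMultiBatchPassOnce(const ComputeGraphPtr &graph) {
  if (graph == nullptr) {
    return FAILED;
  }
  MultiBatchPass pass;
  // Run() returns SUCCESS without changing anything for subgraphs or for
  // graphs that contain no SwitchN node (see MultiBatchPass::Run above).
  Status ret = pass.Run(graph);
  if (ret != SUCCESS) {
    return ret;
  }
  // Reset internal state so the same pass object could be reused on another graph.
  return pass.ClearStatus();
}
}  // namespace ge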