You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

multi_batch_pass.cc 40 kB

5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/multi_batch_pass.h"
  17. #include <stack>
  18. #include <unordered_set>
  19. #include "common/ge/ge_util.h"
  20. #include "common/omg_util.h"
  21. #include "graph/utils/type_utils.h"
  22. #include "common/formats/utils/formats_trans_utils.h"
  23. namespace ge {
  24. Status MultiBatchPass::Run(ComputeGraphPtr graph) {
  25. GELOGD("MultiBatchPass Enter");
  26. if (graph->GetParentGraph() != nullptr) {
  27. GELOGI("Subgraph %s skip the MultiBatchPass.", graph->GetName().c_str());
  28. return SUCCESS;
  29. }
  30. OutDataAnchorPtr pred_value = nullptr;
  31. Status ret = FindPredValue(graph, pred_value);
  32. if (ret == NOT_CHANGED) {
  33. GELOGD("SwitchN node not exist, graph not changed.");
  34. return SUCCESS;
  35. }
  36. if (ret != SUCCESS) {
  37. GELOGE(FAILED, "[Find][PredValue] in graph:%s failed.", graph->GetName().c_str());
  38. return FAILED;
  39. }
  40. if (GetDynamicType() != SUCCESS) {
  41. GELOGE(FAILED, "[Get][DynamicType] in graph:%s failed.", graph->GetName().c_str());
  42. return FAILED;
  43. }
  44. if (GetUserDesignateShape() != SUCCESS) {
  45. GELOGE(FAILED, "[Get][UserDesignateShape] in graph:%s failed.", graph->GetName().c_str());
  46. return FAILED;
  47. }
  48. std::vector<std::vector<int64_t>> batch_shape;
  49. std::vector<std::vector<int64_t>> combined_batch;
  50. if (!CheckSwitchN(batch_shape, combined_batch)) {
  51. GELOGE(FAILED, "[Check][SwitchN] in graph:%s failed.", graph->GetName().c_str());
  52. return FAILED;
  53. }
  54. if (attach_label_only_) {
  55. return AttachLabelOnly(batch_shape.size());
  56. }
  57. if (FindSwitchOutNodes(batch_shape.size()) != SUCCESS) {
  58. GELOGE(FAILED, "[Find][SwitchOutNodes] in graph:%s failed, batch_num:%zu.",
  59. graph->GetName().c_str(), batch_shape.size());
  60. return FAILED;
  61. }
  62. if (ReplaceSwitchN(graph, pred_value, batch_shape, combined_batch) != SUCCESS) {
  63. GELOGE(FAILED, "[Replace][SwitchN] in graph:%s failed.", graph->GetName().c_str());
  64. return FAILED;
  65. }
  66. for (const NodePtr &node : bypass_nodes_) {
  67. if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) {
  68. REPORT_CALL_ERROR("E19999", "Remove node:%s(%s) without relink in graph:%s failed",
  69. node->GetName().c_str(), node->GetType().c_str(), graph->GetName().c_str());
  70. GELOGE(FAILED, "[Remove][Node] %s(%s) without relink in graph:%s failed",
  71. node->GetName().c_str(), node->GetType().c_str(), graph->GetName().c_str());
  72. return FAILED;
  73. }
  74. }
  75. GELOGD("MultiBatchPass Leave");
  76. return SUCCESS;
  77. }
  78. ///
  79. /// @brief Clear Status
  80. /// @return
  81. ///
  82. Status MultiBatchPass::ClearStatus() {
  83. switch_n_nodes_.clear();
  84. bypass_nodes_.clear();
  85. batch_head_nodes_.clear();
  86. return SUCCESS;
  87. }
  88. ///
  89. /// @ingroup ge
  90. /// @brief Set batch label for Case mode.
  91. /// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
  92. /// @param [in] const NodePtr &case_node: Case Node.
  93. /// @return 0: SUCCESS / others: FAILED
  94. ///
  95. Status MultiBatchPass::SetCaseLabel(const ComputeGraphPtr &graph, const NodePtr &case_node) {
  96. const auto &func_desc = case_node->GetOpDesc();
  97. GE_CHECK_NOTNULL(func_desc);
  98. if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) {
  99. GELOGD("Graph: %s Not multi-batch, Node: %s", graph->GetName().c_str(), case_node->GetName().c_str());
  100. return SUCCESS;
  101. }
  102. const auto &dynamic_branch_names = func_desc->GetSubgraphInstanceNames();
  103. for (size_t i = 0; i < dynamic_branch_names.size(); ++i) {
  104. const auto &subgraph = graph->GetSubgraph(dynamic_branch_names[i]);
  105. GE_CHECK_NOTNULL(subgraph);
  106. const std::string batch_label = "Batch_" + std::to_string(i);
  107. for (const auto &node : subgraph->GetDirectNode()) {
  108. (void)AttrUtils::SetStr(node->GetOpDesc(), ATTR_NAME_BATCH_LABEL, batch_label);
  109. }
  110. }
  111. return SUCCESS;
  112. }
  113. ///
  114. /// @brief Replace & Combine SwitchN nodes
  115. /// @param [in] graph
  116. /// @param [out] pred_value
  117. /// @return Status
  118. ///
  119. Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchorPtr &pred_value) {
  120. for (const NodePtr &node : graph->GetDirectNode()) {
  121. if (node->GetType() == CASE) {
  122. GE_CHK_STATUS_RET(SetCaseLabel(graph, node),
  123. "[Set][CaseLabel] for node:%s(%s) in graph:%s failed",
  124. node->GetName().c_str(), node->GetType().c_str(), graph->GetName().c_str());
  125. continue;
  126. }
  127. if (node->GetType() != SWITCHN) {
  128. continue;
  129. }
  130. const auto &in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
  131. if (in_data_anchor == nullptr) {
  132. REPORT_INNER_ERROR("E19999", "Index:%u data anchor of node:%s(%s) is nullptr, check invalid",
  133. SWITCH_PRED_INPUT, node->GetName().c_str(), node->GetType().c_str());
  134. GELOGE(FAILED, "[Get][InDataAnchor] failed, Index:%u data anchor of node:%s(%s) is nullptr.",
  135. SWITCH_PRED_INPUT, node->GetName().c_str(), node->GetType().c_str());
  136. return FAILED;
  137. }
  138. const auto &pred_input = in_data_anchor->GetPeerOutAnchor();
  139. if (pred_input == nullptr) {
  140. REPORT_INNER_ERROR("E19999", "Index:%u data anchor of node:%s(%s), its peer anchor is nullptr, check invalid",
  141. SWITCH_PRED_INPUT, node->GetName().c_str(), node->GetType().c_str());
  142. GELOGE(FAILED, "[Get][PeerOutAnchor] failed, Index:%u data anchor of node:%s(%s), its peer anchor is nullptr.",
  143. SWITCH_PRED_INPUT, node->GetName().c_str(), node->GetType().c_str());
  144. return FAILED;
  145. }
  146. if (pred_value == nullptr) {
  147. pred_value = pred_input;
  148. } else if (pred_value != pred_input) {
  149. REPORT_INNER_ERROR("E19999", "Multi pred_value of case node exist in graph:%s, check invalid",
  150. graph->GetName().c_str());
  151. GELOGE(FAILED, "[Check][Param] Multi pred_value of case node exist in graph:%s.", graph->GetName().c_str());
  152. return FAILED;
  153. }
  154. switch_n_nodes_.emplace_back(node);
  155. }
  156. if (switch_n_nodes_.empty()) {
  157. GELOGD("SwitchN node not exist.");
  158. return NOT_CHANGED;
  159. }
  160. if (pred_value == nullptr) {
  161. REPORT_INNER_ERROR("E19999", "Find Pred Input of case node in graph:%s failed", graph->GetName().c_str());
  162. GELOGE(FAILED, "[Check][Param] FindPredInput in graph:%s failed, pred_value is null.", graph->GetName().c_str());
  163. return FAILED;
  164. }
  165. GELOGI("Find pred_value %s.", pred_value->GetOwnerNode()->GetName().c_str());
  166. return SUCCESS;
  167. }
  168. ///
  169. /// @brief Get dynamic type: dynamic batch size: 1, dynamic image size: 2, dynamic dims: 3
  170. /// @return Status
  171. ///
  172. Status MultiBatchPass::GetDynamicType() {
  173. for (const auto &switch_n : switch_n_nodes_) {
  174. int32_t dynamic_type = static_cast<int32_t>(FIXED);
  175. if (!AttrUtils::GetInt(switch_n->GetOpDesc(), ATTR_DYNAMIC_TYPE, dynamic_type)) {
  176. REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_DYNAMIC_TYPE.c_str(),
  177. switch_n->GetName().c_str(), switch_n->GetType().c_str());
  178. GELOGE(FAILED, "[Get][Attr] %s from op:%s(%s) failed", ATTR_DYNAMIC_TYPE.c_str(),
  179. switch_n->GetName().c_str(), switch_n->GetType().c_str());
  180. return FAILED;
  181. }
  182. if (dynamic_type == static_cast<int32_t>(FIXED)) {
  183. REPORT_INNER_ERROR("E19999", "Attr:%s in op:%s(%s), value:%d check invalid", ATTR_DYNAMIC_TYPE.c_str(),
  184. switch_n->GetName().c_str(), switch_n->GetType().c_str(), dynamic_type);
  185. GELOGE(FAILED, "[Check][Param] Attr:%s in op:%s(%s), value:%d is invalid", ATTR_DYNAMIC_TYPE.c_str(),
  186. switch_n->GetName().c_str(), switch_n->GetType().c_str(), dynamic_type);
  187. return FAILED;
  188. }
  189. if (dynamic_type_ != static_cast<int32_t>(FIXED) && dynamic_type_ != dynamic_type) {
  190. REPORT_INNER_ERROR("E19999", "Attr:%s in op:%s(%s), value:%d not same as attr value:%d in node before, "
  191. "check invalid",
  192. ATTR_DYNAMIC_TYPE.c_str(), switch_n->GetName().c_str(), switch_n->GetType().c_str(),
  193. dynamic_type, dynamic_type_);
  194. GELOGE(FAILED, "[Check][Param] Attr:%s in op:%s(%s), value:%d not same as attr value:%d in node before",
  195. ATTR_DYNAMIC_TYPE.c_str(), switch_n->GetName().c_str(), switch_n->GetType().c_str(),
  196. dynamic_type, dynamic_type_);
  197. return FAILED;
  198. }
  199. dynamic_type_ = dynamic_type;
  200. }
  201. if (dynamic_type_ == static_cast<int32_t>(FIXED)) {
  202. REPORT_INNER_ERROR("E19999", "Find Attr:%s in all switcnn node failed", ATTR_DYNAMIC_TYPE.c_str());
  203. GELOGE(FAILED, "[Check][Param] Find Attr:%s in all switcnn node failed", ATTR_DYNAMIC_TYPE.c_str());
  204. return FAILED;
  205. }
  206. return SUCCESS;
  207. }
  208. ///
  209. /// @brief Get user designate shape order. eg{"data","label","mask"}
  210. /// @return Status
  211. ///
  212. Status MultiBatchPass::GetUserDesignateShape() {
  213. data_name_order_.clear();
  214. bool first_check = true;
  215. for (const auto &switch_n : switch_n_nodes_) {
  216. std::vector<std::string> cur_data_name_order;
  217. if (!AttrUtils::GetListStr(switch_n->GetOpDesc(), ATTR_USER_DESIGNEATE_SHAPE_ORDER, cur_data_name_order)) {
  218. REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_USER_DESIGNEATE_SHAPE_ORDER.c_str(),
  219. switch_n->GetName().c_str(), switch_n->GetType().c_str());
  220. GELOGE(FAILED, "[Get][Attr] %s from op:%s(%s) failed", ATTR_USER_DESIGNEATE_SHAPE_ORDER.c_str(),
  221. switch_n->GetName().c_str(), switch_n->GetType().c_str());
  222. return FAILED;
  223. }
  224. if (first_check) {
  225. data_name_order_ = cur_data_name_order;
  226. first_check = false;
  227. } else {
  228. if (data_name_order_ != cur_data_name_order) {
  229. REPORT_INNER_ERROR("E19999", "Attr:%s in op:%s(%s), value:%s not same as attr value:%s in node before, "
  230. "check invalid", ATTR_USER_DESIGNEATE_SHAPE_ORDER.c_str(),
  231. switch_n->GetName().c_str(), switch_n->GetType().c_str(),
  232. formats::JoinToString(cur_data_name_order).c_str(),
  233. formats::JoinToString(data_name_order_).c_str());
  234. GELOGE(FAILED, "[Check][Param] Attr:%s in op:%s(%s), value:%s not same as attr value:%s in node before.",
  235. ATTR_USER_DESIGNEATE_SHAPE_ORDER.c_str(), switch_n->GetName().c_str(), switch_n->GetType().c_str(),
  236. formats::JoinToString(cur_data_name_order).c_str(), formats::JoinToString(data_name_order_).c_str());
  237. return FAILED;
  238. }
  239. }
  240. }
  241. if (data_name_order_.empty()) {
  242. REPORT_INNER_ERROR("E19999", "Find Attr:%s in all switcnn node failed", ATTR_USER_DESIGNEATE_SHAPE_ORDER.c_str());
  243. GELOGE(FAILED, "[Check][Param] Find Attr:%s in all switcnn node failed", ATTR_USER_DESIGNEATE_SHAPE_ORDER.c_str());
  244. return FAILED;
  245. }
  246. return SUCCESS;
  247. }
  248. ///
  249. /// @brief Check SwitchN nodes
  250. /// @param [out] batch_shape
  251. /// @param [out] combined_batch
  252. /// @return bool
  253. ///
  254. bool MultiBatchPass::CheckSwitchN(std::vector<std::vector<int64_t>> &batch_shape,
  255. std::vector<std::vector<int64_t>> &combined_batch) {
  256. // Check if output_num of different SwitchN is same
  257. uint32_t batch_num = 0;
  258. for (const NodePtr &node : switch_n_nodes_) {
  259. uint32_t tmp_num = node->GetAllOutDataAnchorsSize();
  260. if (batch_num == 0) {
  261. batch_num = tmp_num;
  262. } else if (batch_num != tmp_num) {
  263. REPORT_INNER_ERROR("E19999", "Ouput size num:%u of node:%s(%s) not same as output size num:%d of node before, "
  264. "check invalid", tmp_num, node->GetName().c_str(), node->GetType().c_str(), batch_num);
  265. GELOGE(FAILED, "[Check][Param] Ouput size num:%u of node:%s(%s) not same as output size num:%d of node before",
  266. tmp_num, node->GetName().c_str(), node->GetType().c_str(), batch_num);
  267. return false;
  268. }
  269. }
  270. if (!GetBatchInfo(batch_num, batch_shape, combined_batch)) {
  271. GELOGE(FAILED, "[Get][BatchInfo] failed, batch_num:%u.", batch_num);
  272. return false;
  273. }
  274. if (batch_shape.empty()) {
  275. REPORT_INNER_ERROR("E19999", "batch_shape size is empty after GetBatchInfo, check invalid");
  276. GELOGE(FAILED, "[Check][Param] batch_shape is empty after GetBatchInfo.");
  277. return false;
  278. }
  279. if (combined_batch.empty()) {
  280. REPORT_INNER_ERROR("E19999", "combined_batch size is empty after GetBatchInfo, check invalid");
  281. GELOGE(FAILED, "[Check][Param] combined_batch is empty after GetBatchInfo.");
  282. return false;
  283. }
  284. size_t dim_num = batch_shape[0].size();
  285. size_t combined_dim_num = combined_batch[0].size();
  286. for (uint32_t i = 1; i < batch_num; i++) {
  287. size_t tmp_dim_num = batch_shape[i].size();
  288. if (dim_num != tmp_dim_num) {
  289. REPORT_INNER_ERROR("E19999", "Dim num of batch_shape not equal, batch_0:%zu, batch_%u:%zu, check invalid",
  290. dim_num, i, tmp_dim_num);
  291. GELOGE(FAILED, "[Check][Param] Dim num of batch_shape not equal, batch_0:%zu, batch_%u:%zu.",
  292. dim_num, i, tmp_dim_num);
  293. return false;
  294. }
  295. size_t tmp_combined_dim_num = combined_batch[i].size();
  296. if (combined_dim_num != tmp_combined_dim_num) {
  297. REPORT_INNER_ERROR("E19999", "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu, check invalid",
  298. combined_dim_num, i, tmp_combined_dim_num);
  299. GELOGE(FAILED, "[Check][Param] Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu.",
  300. combined_dim_num, i, tmp_combined_dim_num);
  301. return false;
  302. }
  303. }
  304. return true;
  305. }
  306. ///
  307. /// @brief Check SwitchN nodes
  308. /// @param [in] batch_num
  309. /// @param [out] batch_shape
  310. /// @param [out] combined_batch
  311. /// @return bool
  312. ///
  313. bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, std::vector<std::vector<int64_t>> &batch_shape,
  314. std::vector<std::vector<int64_t>> &combined_batch) {
  315. // Check if output_shape of different SwitchN is same
  316. std::vector<std::vector<int64_t>> idx_batch_shape;
  317. std::vector<std::vector<int64_t>> idx_combined_batch;
  318. for (uint32_t i = 0; i < batch_num; i++) {
  319. idx_batch_shape.clear();
  320. idx_combined_batch.clear();
  321. for (const NodePtr &node : switch_n_nodes_) {
  322. OpDescPtr op_desc = node->GetOpDesc();
  323. if (op_desc == nullptr) {
  324. REPORT_INNER_ERROR("E19999", "OpDesc in node is nullptr, check invalid");
  325. GELOGE(FAILED, "[Get][OpDesc] failed, OpDesc in node is nullptr.");
  326. return false;
  327. }
  328. std::vector<int64_t> output_dims;
  329. if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_SWITCHN_PRED_VALUE, output_dims)) {
  330. REPORT_CALL_ERROR("E19999", "Get Attr:%s from output:%u tensor of op:%s(%s) failed",
  331. ATTR_NAME_SWITCHN_PRED_VALUE.c_str(), i,
  332. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  333. GELOGE(FAILED, "[Get][Attr] %s from output:%u tensor of op:%s(%s) failed",
  334. ATTR_NAME_SWITCHN_PRED_VALUE.c_str(), i, op_desc->GetName().c_str(), op_desc->GetType().c_str());
  335. return false;
  336. }
  337. idx_batch_shape.emplace_back(output_dims);
  338. output_dims.clear();
  339. if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_COMBINED_DYNAMIC_DIMS, output_dims)) {
  340. REPORT_CALL_ERROR("E19999", "Get Attr:%s from output:%u tensor of op:%s(%s) failed",
  341. ATTR_NAME_COMBINED_DYNAMIC_DIMS.c_str(), i,
  342. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  343. GELOGE(FAILED, "[Get][Attr] %s from output:%u tensor of op:%s(%s) failed",
  344. ATTR_NAME_COMBINED_DYNAMIC_DIMS.c_str(), i, op_desc->GetName().c_str(), op_desc->GetType().c_str());
  345. return false;
  346. }
  347. idx_combined_batch.emplace_back(output_dims);
  348. }
  349. if (!CheckDims(idx_batch_shape)) {
  350. REPORT_INNER_ERROR("E19999", "Attr:%s of all output:%u tensor in switcnn node not equal, or not exist, "
  351. "check invalid", ATTR_NAME_SWITCHN_PRED_VALUE.c_str(), i);
  352. GELOGE(FAILED, "[Check][Dims] failed, Attr:%s of all output:%u tensor in switcnn node not equal, or not exist.",
  353. ATTR_NAME_SWITCHN_PRED_VALUE.c_str(), i);
  354. return false;
  355. }
  356. batch_shape.emplace_back(idx_batch_shape[0]);
  357. combined_batch.emplace_back(idx_combined_batch[0]);
  358. }
  359. return true;
  360. }
  361. ///
  362. /// @brief Find outputs of SwitchN nodes
  363. /// @param [in] batch_num
  364. /// @return void
  365. ///
  366. Status MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) {
  367. std::vector<NodePtr> output_nodes;
  368. for (uint32_t i = 0; i < batch_num; i++) {
  369. output_nodes.clear();
  370. for (const NodePtr &node : switch_n_nodes_) {
  371. // idx is promised to be valid
  372. OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(i);
  373. GE_CHECK_NOTNULL(out_data_anchor);
  374. for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  375. auto out_node = peer_in_anchor->GetOwnerNode();
  376. if (out_node->GetType() != IDENTITY || !out_node->GetOutDataNodes().empty()) {
  377. output_nodes.emplace_back(out_node);
  378. continue;
  379. }
  380. bypass_nodes_.emplace_back(out_node);
  381. if (GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor) != GRAPH_SUCCESS) {
  382. REPORT_CALL_ERROR("E19999", "Remove edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
  383. node->GetName().c_str(), node->GetType().c_str(), i,
  384. out_node->GetName().c_str(), out_node->GetType().c_str(), peer_in_anchor->GetIdx());
  385. GELOGE(FAILED, "[Remove][Edge] between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
  386. node->GetName().c_str(), node->GetType().c_str(), i,
  387. out_node->GetName().c_str(), out_node->GetType().c_str(), peer_in_anchor->GetIdx());
  388. return FAILED;
  389. }
  390. for (auto &identity_out_node : out_node->GetOutControlNodes()) {
  391. output_nodes.emplace_back(identity_out_node);
  392. if (GraphUtils::RemoveEdge(out_node->GetOutControlAnchor(), identity_out_node->GetInControlAnchor()) !=
  393. GRAPH_SUCCESS) {
  394. REPORT_CALL_ERROR("E19999", "Remove control edge between op:%s(%s) and op:%s(%s) failed",
  395. out_node->GetName().c_str(), out_node->GetType().c_str(),
  396. identity_out_node->GetName().c_str(), identity_out_node->GetType().c_str());
  397. GELOGE(FAILED, "[Remove][ControlEdge] between op:%s(%s) and op:%s(%s) failed",
  398. out_node->GetName().c_str(), out_node->GetType().c_str(),
  399. identity_out_node->GetName().c_str(), identity_out_node->GetType().c_str());
  400. return FAILED;
  401. }
  402. }
  403. }
  404. }
  405. batch_head_nodes_.emplace_back(output_nodes);
  406. }
  407. return SUCCESS;
  408. }
  409. ///
  410. /// @brief Replace & Combine SwitchN nodes
  411. /// @param [in] graph
  412. /// @param [in] pred_value
  413. /// @param [in] batch_shape
  414. /// @param [in] combined_batch
  415. /// @return Status
  416. ///
  417. Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value,
  418. const std::vector<std::vector<int64_t>> &batch_shape,
  419. const std::vector<std::vector<int64_t>> &combined_batch) {
  420. NodePtr pred_value_node = pred_value->GetOwnerNode();
  421. // Create SwitchCase node
  422. const std::string &switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN;
  423. NodePtr switch_case = CreateSwitchCaseNode(graph, switch_case_name, pred_value, batch_shape, combined_batch);
  424. if (switch_case == nullptr) {
  425. GELOGE(FAILED, "[Create][SwitchCaseNode] %s failed.", switch_case_name.c_str());
  426. return FAILED;
  427. }
  428. for (const NodePtr &switch_n_node : switch_n_nodes_) {
  429. if (BypassSwitchN(switch_n_node, switch_case) != SUCCESS) {
  430. GELOGE(FAILED, "[Call][BypassSwitchN] for %s failed.", switch_case_name.c_str());
  431. return FAILED;
  432. }
  433. }
  434. // Add switchCase input edge
  435. if (GraphUtils::AddEdge(pred_value, switch_case->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
  436. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed",
  437. pred_value_node->GetName().c_str(), pred_value_node->GetType().c_str(), pred_value->GetIdx(),
  438. switch_case->GetName().c_str(), switch_case->GetType().c_str());
  439. GELOGE(FAILED, "[Add][Edge] between op:%s(%s)(index:%d) and op:%s(%s)(index:0) failed",
  440. pred_value_node->GetName().c_str(), pred_value_node->GetType().c_str(), pred_value->GetIdx(),
  441. switch_case->GetName().c_str(), switch_case->GetType().c_str());
  442. return FAILED;
  443. }
  444. if (AttachLabel(switch_case) != SUCCESS) {
  445. GELOGE(FAILED, "[Attach][Label] for node:%s(%s) failed.",
  446. switch_case->GetName().c_str(), switch_case->GetType().c_str());
  447. return FAILED;
  448. }
  449. return SUCCESS;
  450. }
  451. ///
  452. /// @brief Check if output_shape of different SwitchN is same
  453. /// @param [in] output_shape
  454. /// @return bool
  455. ///
  456. bool MultiBatchPass::CheckDims(const std::vector<std::vector<int64_t>> &output_shape) const {
  457. if (output_shape.empty()) {
  458. GELOGE(FAILED, "[Check][Param] output_shape is empty.");
  459. return false;
  460. }
  461. for (auto iter = output_shape.begin() + 1; iter != output_shape.end(); ++iter) {
  462. if (output_shape[0] != *iter) {
  463. return false;
  464. }
  465. }
  466. return true;
  467. }
  468. ///
  469. /// @brief Create StreamSwitchN node
  470. /// @param [in] graph
  471. /// @param [in] name
  472. /// @param [in] pred_value
  473. /// @param [in] batch_shape
  474. /// @param [in] combined_batch
  475. /// @return ge::NodePtr
  476. ///
  477. NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name,
  478. const OutDataAnchorPtr &pred_value,
  479. const std::vector<std::vector<int64_t>> &batch_shape,
  480. const std::vector<std::vector<int64_t>> &combined_batch) {
  481. OpDescPtr op_desc = MakeShared<OpDesc>(name, STREAMSWITCHN);
  482. if (op_desc == nullptr) {
  483. REPORT_CALL_ERROR("E19999", "New OpDesc failed");
  484. GELOGE(FAILED, "[New][OpDesc] failed.");
  485. return nullptr;
  486. }
  487. GELOGI("Create StreamSwitchN op:%s.", name.c_str());
  488. OpDescPtr pred_desc = pred_value->GetOwnerNode()->GetOpDesc();
  489. if (pred_desc == nullptr) {
  490. REPORT_INNER_ERROR("E19999", "OpDesc in node is nullptr, check invalid");
  491. GELOGE(FAILED, "[Get][OpDesc] failed, OpDesc in node is nullptr.");
  492. return nullptr;
  493. }
  494. if (op_desc->AddInputDesc(pred_desc->GetOutputDesc(pred_value->GetIdx())) != GRAPH_SUCCESS) {
  495. REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
  496. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  497. GELOGE(FAILED, "[Add][InputDesc] to op:%s(%s) failed",
  498. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  499. return nullptr;
  500. }
  501. NodePtr switch_case_node = graph->AddNode(op_desc);
  502. if (switch_case_node == nullptr) {
  503. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
  504. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  505. GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed",
  506. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  507. return nullptr;
  508. }
  509. uint32_t batch_num = static_cast<uint32_t>(batch_shape.size());
  510. if (!AttrUtils::SetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) {
  511. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_BATCH_NUM.c_str(),
  512. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  513. GELOGE(FAILED, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NAME_BATCH_NUM.c_str(),
  514. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  515. return nullptr;
  516. }
  517. if (!AttrUtils::SetInt(op_desc, ATTR_DYNAMIC_TYPE, dynamic_type_)) {
  518. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_DYNAMIC_TYPE.c_str(),
  519. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  520. GELOGE(FAILED, "[Set][Attr] %s to op:%s(%s) failed", ATTR_DYNAMIC_TYPE.c_str(),
  521. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  522. return nullptr;
  523. }
  524. if (!AttrUtils::SetListStr(op_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, data_name_order_)) {
  525. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_USER_DESIGNEATE_SHAPE_ORDER.c_str(),
  526. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  527. GELOGE(FAILED, "[Set][Attr] %s to op:%s(%s) failed", ATTR_USER_DESIGNEATE_SHAPE_ORDER.c_str(),
  528. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  529. return nullptr;
  530. }
  531. for (uint32_t i = 0; i < batch_num; i++) {
  532. const std::string &attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i);
  533. if (!AttrUtils::SetListInt(op_desc, attr_name, batch_shape[i])) {
  534. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", attr_name.c_str(),
  535. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  536. GELOGE(FAILED, "[Set][Attr] %s to op:%s(%s) failed", attr_name.c_str(),
  537. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  538. return nullptr;
  539. }
  540. const std::string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i);
  541. if (!AttrUtils::SetListInt(op_desc, attr_combined_batch, combined_batch[i])) {
  542. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", attr_combined_batch.c_str(),
  543. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  544. GELOGE(FAILED, "[Set][Attr] %s to op:%s(%s) failed", attr_combined_batch.c_str(),
  545. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  546. return nullptr;
  547. }
  548. }
  549. return switch_case_node;
  550. }
  551. ///
  552. /// @brief Bypass SwitchN node
  553. /// @param [in] switch_n_node
  554. /// @param [in] switch_case
  555. /// @return Status
  556. ///
  557. Status MultiBatchPass::BypassSwitchN(const NodePtr &switch_n_node, const NodePtr &switch_case) {
  558. InDataAnchorPtr in_data_anchor = switch_n_node->GetInDataAnchor(SWITCH_DATA_INPUT);
  559. if (in_data_anchor == nullptr) {
  560. REPORT_INNER_ERROR("E19999", "Index:%u in data anchor of node:%s(%s) is nullptr, check invalid",
  561. SWITCH_DATA_INPUT, switch_n_node->GetName().c_str(), switch_n_node->GetType().c_str());
  562. GELOGE(FAILED, "[Get][InDataAnchor] failed, Index:%u in data anchor of node:%s(%s) is nullptr",
  563. SWITCH_DATA_INPUT, switch_n_node->GetName().c_str(), switch_n_node->GetType().c_str());
  564. return FAILED;
  565. }
  566. OutDataAnchorPtr peer_data_anchor = in_data_anchor->GetPeerOutAnchor();
  567. if (peer_data_anchor == nullptr) {
  568. REPORT_INNER_ERROR("E19999", "Index:%u in data anchor of node:%s(%s), its peer ahcnhor is nullptr, check invalid",
  569. SWITCH_DATA_INPUT, switch_n_node->GetName().c_str(), switch_n_node->GetType().c_str());
  570. GELOGE(FAILED, "[Get][PeerOutAnchor] failed, Index:%u in data anchor of node:%s(%s), its peer ahcnhor is nullptr",
  571. SWITCH_DATA_INPUT, switch_n_node->GetName().c_str(), switch_n_node->GetType().c_str());
  572. return FAILED;
  573. }
  574. NodePtr data_input = peer_data_anchor->GetOwnerNode();
  575. // Remove SwitchN data input
  576. if (GraphUtils::RemoveEdge(peer_data_anchor, in_data_anchor) != GRAPH_SUCCESS) {
  577. REPORT_CALL_ERROR("E19999", "Remove edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%u) failed",
  578. data_input->GetName().c_str(), data_input->GetType().c_str(), peer_data_anchor->GetIdx(),
  579. switch_n_node->GetName().c_str(), switch_n_node->GetType().c_str(), SWITCH_DATA_INPUT);
  580. GELOGE(FAILED, "[Remove][Edge] between op:%s(%s)(index:%d) and op:%s(%s)(index:%u) failed",
  581. data_input->GetName().c_str(), data_input->GetType().c_str(), peer_data_anchor->GetIdx(),
  582. switch_n_node->GetName().c_str(), switch_n_node->GetType().c_str(), SWITCH_DATA_INPUT);
  583. return FAILED;
  584. }
  585. if (GraphUtils::AddEdge(data_input->GetOutControlAnchor(), switch_case->GetInControlAnchor()) != GRAPH_SUCCESS) {
  586. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  587. data_input->GetName().c_str(), data_input->GetType().c_str(),
  588. switch_case->GetName().c_str(), switch_case->GetType().c_str());
  589. GELOGE(FAILED, "[Add][ControlEdge] between op:%s(%s) and op:%s(%s) failed",
  590. data_input->GetName().c_str(), data_input->GetType().c_str(),
  591. switch_case->GetName().c_str(), switch_case->GetType().c_str());
  592. return FAILED;
  593. }
  594. // Add SwitchCase control output
  595. for (const OutDataAnchorPtr &out_data_anchor : switch_n_node->GetAllOutDataAnchors()) {
  596. for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  597. NodePtr data_output = peer_in_anchor->GetOwnerNode();
  598. if ((GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor) != GRAPH_SUCCESS) ||
  599. (GraphUtils::AddEdge(peer_data_anchor, peer_in_anchor) != GRAPH_SUCCESS)) {
  600. REPORT_CALL_ERROR("E19999", "Remove edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) or "
  601. "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
  602. switch_n_node->GetName().c_str(), switch_n_node->GetType().c_str(), out_data_anchor->GetIdx(),
  603. data_output->GetName().c_str(), data_output->GetType().c_str(), peer_in_anchor->GetIdx(),
  604. data_input->GetName().c_str(), data_input->GetType().c_str(), peer_data_anchor->GetIdx(),
  605. data_output->GetName().c_str(), data_output->GetType().c_str(), peer_in_anchor->GetIdx());
  606. GELOGE(FAILED, "[Replace][Edge] failed, Remove edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) or "
  607. "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%d) failed",
  608. switch_n_node->GetName().c_str(), switch_n_node->GetType().c_str(), out_data_anchor->GetIdx(),
  609. data_output->GetName().c_str(), data_output->GetType().c_str(), peer_in_anchor->GetIdx(),
  610. data_input->GetName().c_str(), data_input->GetType().c_str(), peer_data_anchor->GetIdx(),
  611. data_output->GetName().c_str(), data_output->GetType().c_str(), peer_in_anchor->GetIdx());
  612. return FAILED;
  613. }
  614. if (GraphUtils::AddEdge(switch_case->GetOutControlAnchor(), data_output->GetInControlAnchor()) != GRAPH_SUCCESS) {
  615. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  616. switch_case->GetName().c_str(), switch_case->GetType().c_str(),
  617. data_output->GetName().c_str(), data_output->GetType().c_str());
  618. GELOGE(FAILED, "[Add][ControlEdge] between op:%s(%s) and op:%s(%s) failed",
  619. switch_case->GetName().c_str(), switch_case->GetType().c_str(),
  620. data_output->GetName().c_str(), data_output->GetType().c_str());
  621. return FAILED;
  622. }
  623. }
  624. }
  625. GE_CHK_STATUS_RET(MoveCtrlEdges(switch_n_node, switch_case),
  626. "[Move][CtrlEdges] from %s to %s failed.", switch_n_node->GetName().c_str(),
  627. switch_case->GetName().c_str());
  628. bypass_nodes_.emplace_back(switch_n_node);
  629. GELOGI("Bypass SwitchN node %s success.", switch_n_node->GetName().c_str());
  630. return SUCCESS;
  631. }
  632. ///
  633. /// @brief Attach stream_label & batch_label for batch branch
  634. /// @param [in] switch_case_node
  635. /// @return Status
  636. ///
  637. Status MultiBatchPass::AttachLabel(const NodePtr &switch_case_node) {
  638. std::vector<std::string> stream_label_list;
  639. for (uint32_t i = 0; i < static_cast<uint32_t>(batch_head_nodes_.size()); i++) {
  640. if (AttachBatchLabel(i) != SUCCESS) {
  641. GELOGE(FAILED, "[Attach][BatchLabel] failed, batch_idx=%u", i);
  642. return FAILED;
  643. }
  644. const std::string &stream_label = "stream_label_batch_" + std::to_string(i);
  645. if (AttachStreamLabel(i, stream_label) != SUCCESS) {
  646. GELOGE(FAILED, "[Attach][StreamLabel] failed, stream_label=%s, batch_idx=%u", stream_label.c_str(), i);
  647. return FAILED;
  648. }
  649. stream_label_list.emplace_back(stream_label);
  650. }
  651. return switch_case_node == nullptr ? SUCCESS : SetActiveLabelList(switch_case_node, stream_label_list);
  652. }
  653. ///
  654. /// @brief Attach batch_label for batch branch
  655. /// @param [in] batch_idx
  656. /// @return Status
  657. ///
  658. Status MultiBatchPass::AttachBatchLabel(uint32_t batch_idx) {
  659. std::stack<NodePtr> nodes;
  660. for (const auto &node : batch_head_nodes_[batch_idx]) {
  661. nodes.push(node);
  662. }
  663. const std::string &batch_label = "Batch_" + std::to_string(batch_idx);
  664. std::unordered_set<NodePtr> handled_nodes;
  665. while (!nodes.empty()) {
  666. NodePtr cur_node = nodes.top();
  667. nodes.pop();
  668. if (handled_nodes.count(cur_node) > 0) {
  669. continue;
  670. }
  671. OpDescPtr cur_desc = cur_node->GetOpDesc();
  672. GE_CHECK_NOTNULL(cur_desc);
  673. if (cur_desc->HasAttr(ATTR_NAME_BATCH_LABEL)) {
  674. std::string tmp_label;
  675. if (!AttrUtils::GetStr(cur_desc, ATTR_NAME_BATCH_LABEL, tmp_label)) {
  676. REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_NAME_BATCH_LABEL.c_str(),
  677. cur_desc->GetName().c_str(), cur_desc->GetType().c_str());
  678. GELOGE(FAILED, "[Get][Attr] %s from op:%s(%s) failed", ATTR_NAME_BATCH_LABEL.c_str(),
  679. cur_desc->GetName().c_str(), cur_desc->GetType().c_str());
  680. return FAILED;
  681. }
  682. if (tmp_label != batch_label) {
  683. REPORT_INNER_ERROR("E19999", "Attr:%s from op:%s(%s) value:%s not equal to expect:%s, check invalid",
  684. ATTR_NAME_BATCH_LABEL.c_str(), cur_desc->GetName().c_str(), cur_desc->GetType().c_str(),
  685. tmp_label.c_str(), batch_label.c_str());
  686. GELOGE(FAILED, "[Check][Param] Attr:%s from op:%s(%s) value:%s not equal to expect:%s",
  687. ATTR_NAME_BATCH_LABEL.c_str(), cur_desc->GetName().c_str(), cur_desc->GetType().c_str(),
  688. tmp_label.c_str(), batch_label.c_str());
  689. return FAILED;
  690. }
  691. }
  692. GELOGD("Attach batch_label %s to node %s.", batch_label.c_str(), cur_desc->GetName().c_str());
  693. if (!AttrUtils::SetStr(cur_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
  694. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_BATCH_LABEL.c_str(),
  695. cur_desc->GetName().c_str(), cur_desc->GetType().c_str());
  696. GELOGE(FAILED, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NAME_BATCH_LABEL.c_str(),
  697. cur_desc->GetName().c_str(), cur_desc->GetType().c_str());
  698. return FAILED;
  699. }
  700. for (const auto &out_node : cur_node->GetOutAllNodes()) {
  701. OpDescPtr op_desc = out_node->GetOpDesc();
  702. GE_CHECK_NOTNULL(op_desc);
  703. const std::string &type = op_desc->GetType();
  704. if ((type == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) {
  705. continue;
  706. }
  707. if (type == NETOUTPUT) {
  708. REPORT_CALL_ERROR("E19999", "SReach net_output without Merge, cur_node:%s(%s), check invalid",
  709. cur_node->GetName().c_str(), cur_node->GetType().c_str());
  710. GELOGE(FAILED, "[Check][Param] Reach net_output without Merge, cur_node:%s.", cur_node->GetName().c_str());
  711. return FAILED;
  712. }
  713. nodes.push(out_node);
  714. }
  715. (void)handled_nodes.insert(cur_node);
  716. }
  717. return SUCCESS;
  718. }
  719. ///
  720. /// @brief Attach stream_label for batch branch
  721. /// @param [in] batch_idx
  722. /// @param [in] stream_label
  723. /// @return Status
  724. ///
  725. Status MultiBatchPass::AttachStreamLabel(uint32_t batch_idx, const std::string &stream_label) {
  726. std::stack<NodePtr> nodes;
  727. for (const auto &node : batch_head_nodes_[batch_idx]) {
  728. nodes.push(node);
  729. }
  730. std::unordered_set<NodePtr> handled_nodes;
  731. while (!nodes.empty()) {
  732. NodePtr cur_node = nodes.top();
  733. nodes.pop();
  734. OpDescPtr cur_desc = cur_node->GetOpDesc();
  735. GE_CHECK_NOTNULL(cur_desc);
  736. if ((handled_nodes.count(cur_node) > 0) || (cur_desc->HasAttr(ATTR_NAME_STREAM_LABEL))) {
  737. continue;
  738. }
  739. GELOGD("Attach stream_label %s to node %s.", stream_label.c_str(), cur_desc->GetName().c_str());
  740. if (SetStreamLabel(cur_node, stream_label) != SUCCESS) {
  741. REPORT_CALL_ERROR("E19999", "Set stream_label:%s to op:%s(%s) failed",
  742. stream_label.c_str(), cur_node->GetName().c_str(), cur_node->GetType().c_str());
  743. GELOGE(FAILED, "[Set][StreamLabel] %s to op:%s(%s) failed",
  744. stream_label.c_str(), cur_node->GetName().c_str(), cur_node->GetType().c_str());
  745. return FAILED;
  746. }
  747. for (const auto &out_node : cur_node->GetOutAllNodes()) {
  748. nodes.push(out_node);
  749. }
  750. (void)handled_nodes.insert(cur_node);
  751. }
  752. return SUCCESS;
  753. }
  754. ///
  755. /// @brief move edges from old_node to new_node
  756. /// @param [in] old_node
  757. /// @param [in] new_node
  758. /// @return Status
  759. ///
  760. Status MultiBatchPass::MoveCtrlEdges(const NodePtr &old_node, const NodePtr &new_node) {
  761. if (old_node == new_node) {
  762. return SUCCESS;
  763. }
  764. for (const NodePtr &in_ctrl_node : old_node->GetInControlNodes()) {
  765. GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), old_node->GetInControlAnchor()),
  766. "[Remove][ControlEdge] between %s and %s failed.",
  767. in_ctrl_node->GetName().c_str(), old_node->GetName().c_str());
  768. GE_CHK_STATUS(GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), new_node->GetInControlAnchor()),
  769. "[Add][ControlEdge] between %s and %s failed.",
  770. in_ctrl_node->GetName().c_str(), new_node->GetName().c_str());
  771. }
  772. for (const NodePtr &out_ctrl_node : old_node->GetOutControlNodes()) {
  773. GE_CHK_STATUS(GraphUtils::RemoveEdge(old_node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()),
  774. "[Remove][ControlEdge] between %s and %s failed.",
  775. old_node->GetName().c_str(), out_ctrl_node->GetName().c_str());
  776. GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()),
  777. "[Add][ControlEdge] between %s and %s failed.",
  778. new_node->GetName().c_str(), out_ctrl_node->GetName().c_str());
  779. }
  780. return SUCCESS;
  781. }
  782. ///
  783. /// @brief attach stream_label & batch_label without change structure of graph
  784. /// @param [in] batch_num
  785. /// @return void
  786. ///
  787. Status MultiBatchPass::AttachLabelOnly(uint32_t batch_num) {
  788. std::vector<NodePtr> output_nodes;
  789. for (uint32_t i = 0; i < batch_num; i++) {
  790. output_nodes.clear();
  791. for (const NodePtr &node : switch_n_nodes_) {
  792. // idx is promised to be valid
  793. OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(i);
  794. GE_CHECK_NOTNULL(out_data_anchor);
  795. for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  796. output_nodes.emplace_back(peer_in_anchor->GetOwnerNode());
  797. }
  798. }
  799. batch_head_nodes_.emplace_back(output_nodes);
  800. }
  801. return AttachLabel(nullptr);
  802. }
  803. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示: