You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

multi_batch_clone_pass.cc 29 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/multi_batch_clone_pass.h"
  17. #include "common/formats/utils/formats_trans_utils.h"
  18. #include "common/ge/ge_util.h"
  19. #include "graph/common/local_context.h"
  20. #include "graph/preprocess/multi_batch_options.h"
  21. #include "graph/utils/node_utils.h"
  22. #include "graph/utils/op_desc_utils.h"
  23. #include "graph/utils/tensor_utils.h"
  24. #include "graph/utils/type_utils.h"
  25. #include "register/op_registry.h"
namespace ge {
namespace {
// Anchor indexes of a Data node: single input (0) and single output (0).
constexpr uint8_t kDataInIndex = 0;
constexpr uint8_t kDataOutIndex = 0;
// Case node input 0 carries the branch index; user args start at index 1.
constexpr uint8_t kCaseArgIndex = 1;
// Fixed names for the auxiliary nodes inserted into the root graph.
const std::string kMultiBatchCaseNode = "ascend_mbatch_shape_case";
const std::string kMultiBatchDataNode = "ascend_mbatch_shape_data";
const std::string kMultiBatchConstNode = "ascend_mbatch_shape_const";
const std::string kMultiBatchMapIndexNode = "ascend_mbatch_shape_mapindex";
// Suffix appended to every cloned node name, followed by the batch index.
const std::string kMultiBatchNodePostfix = "_ascend_mbatch_batch_";
}  // namespace
///
/// @brief Entry point: clone the root graph into one Case branch per dynamic
///        batch profile. Only runs on the root graph; subgraphs are skipped.
/// @param [in] graph: graph to transform in place.
/// @return SUCCESS (also when multi-batch is not configured) / others: FAILED
///
Status MultiBatchClonePass::Run(ComputeGraphPtr graph) {
  // Subgraphs are reached through their parent; only the root graph is cloned.
  if (graph->GetParentGraph() != nullptr) {
    GELOGD("Subgraph %s skip the MultiBatchClonePass", graph->GetName().c_str());
    return SUCCESS;
  }
  // No dynamic batch options configured: nothing to do.
  if (!multibatch::InitDynamicParams(batch_shapes_)) {
    GELOGD("There is no multi-batch options, no need clone multi-batch graph");
    return SUCCESS;
  }
  GELOGD("Begin to run Multi-batch clone on graph: %s", graph->GetName().c_str());
  GE_CHK_STATUS_RET(multibatch::CheckDynamicParams(batch_shapes_), "Invalid multi-batch param");
  if (CollectIoNodes(graph) != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Collect input output nodes failed");
    return INTERNAL_ERROR;
  }
  // parser data dynamic info from atc parameter --input_shape
  if (multibatch::ParserDataToDynmaicInfo(batch_shapes_, GetLocalOmgContext().user_input_dims,
                                          data_to_dynamic_info_) != SUCCESS) {
    GELOGE(PARAM_INVALID, "Parse each data's own dynamic info failed");
    return PARAM_INVALID;
  }
  (void)AttrUtils::GetStr(graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_);
  ComputeGraphPtr branch = MakeShared<ComputeGraph>(graph->GetName());
  if (branch == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch graph failed");
    return OUT_OF_MEMORY;
  }
  (void)AttrUtils::SetStr(branch, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_);
  // Move the original graph content into `branch`; `graph` is then rebuilt
  // as the Case root graph and `branch` becomes the first (batch 0) subgraph.
  graph->InValid();  // Will modify, need topological again.
  graph->Swap(*branch);
  if (CreateRootGraph(graph) != SUCCESS) {
    return FAILED;
  }
  if (CreateSubgraphs(graph, branch) != SUCCESS) {
    return FAILED;
  }
  // Drop Case outputs that no branch actually produces.
  GE_CHK_STATUS_RET(PruneDirectOutput(graph), "Prune direct output failed");
  GELOGD("MultiBatchClonePass Leave");
  return SUCCESS;
}
  77. ///
  78. /// @ingroup ge
  79. /// @brief Collect input output node from original graph.
  80. /// @param [in] const ComputeGraphPtr &graph: original graph.
  81. /// @return 0: SUCCESS / others: FAILED
  82. ///
  83. Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) {
  84. for (const auto &node : graph->GetDirectNode()) {
  85. if (node->GetType() == DATA) {
  86. all_data_nodes_.emplace_back(node);
  87. } else if (node->GetType() == CONSTANT) {
  88. all_const_nodes_.emplace_back(node);
  89. } else if (node->GetType() == NETOUTPUT) {
  90. all_output_nodes_.emplace_back(node);
  91. }
  92. // If the node save as input/output node, delete record.
  93. (void)graph->RemoveInputNode(node);
  94. (void)graph->RemoveOutputNode(node);
  95. }
  96. if (all_data_nodes_.empty() || all_output_nodes_.size() != 1) {
  97. GELOGE(FAILED, "data nodes: %zu, output nodes: %zu", all_data_nodes_.size(), all_output_nodes_.size());
  98. return FAILED;
  99. }
  100. int64_t data_index = 0;
  101. for (size_t i = 0; i < all_data_nodes_.size(); ++i) {
  102. const auto &op_desc = all_data_nodes_[i]->GetOpDesc();
  103. if (!AttrUtils::GetInt(op_desc, ATTR_NAME_INDEX, data_index)) {
  104. (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, i);
  105. }
  106. }
  107. const auto &output = all_output_nodes_[0];
  108. for (size_t i = 0; i < output->GetAllInDataAnchorsSize(); ++i) {
  109. const auto in_anchor = output->GetInDataAnchor(i);
  110. const auto out_anchor = in_anchor->GetPeerOutAnchor();
  111. const auto data_node = out_anchor->GetOwnerNode();
  112. if (data_node->GetType() == DATA) {
  113. direct_output_[i] = data_node->GetName();
  114. GE_CHK_GRAPH_STATUS_RET(
  115. GraphUtils::RemoveEdge(data_node->GetOutDataAnchor(kDataOutIndex), output->GetInDataAnchor(i)),
  116. "Remove edge failed");
  117. }
  118. }
  119. return SUCCESS;
  120. }
///
/// @ingroup ge
/// @brief Create nodes for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateRootGraph(const ComputeGraphPtr &graph) {
  // Case inputs mirror all original Data + Const nodes; the branch-index input
  // is declared separately below. NOTE(review): size_t -> uint32_t narrowing,
  // assumed safe for realistic input counts.
  uint32_t input_num = all_data_nodes_.size() + all_const_nodes_.size();
  uint32_t output_num = all_output_nodes_[0]->GetAllInDataAnchorsSize();
  OpDescBuilder op_builder(kMultiBatchCaseNode, CASE);
  op_builder.AddInput("branch_index").AddDynamicInput("input", input_num).AddDynamicOutput("output", output_num);
  const OpDescPtr op_desc = op_builder.Build();
  if (op_desc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch case desc failed");
    return OUT_OF_MEMORY;
  }
  // One dynamic subgraph instance per batch profile will be registered later.
  op_desc->RegisterSubgraphIrName("branches", kDynamic);
  case_node_ = graph->AddNode(op_desc);
  if (case_node_ == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch case node failed");
    return OUT_OF_MEMORY;
  }
  // Stamp the batch count and the dim list of each batch profile on the Case
  // node so downstream passes/executors can recover the gear table.
  uint32_t batch_num = static_cast<uint32_t>(batch_shapes_.size());
  if (!AttrUtils::SetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) {
    GELOGE(FAILED, "Set attr ATTR_NAME_BATCH_NUM failed, Case: %s.", op_desc->GetName().c_str());
    return FAILED;
  }
  for (uint32_t i = 0; i < batch_num; i++) {
    const std::string &attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i);
    if (!AttrUtils::SetListInt(op_desc, attr_name, batch_shapes_[i])) {
      GELOGE(FAILED, "Set attr ATTR_NAME_PRED_VALUE failed, Case: %s.", op_desc->GetName().c_str());
      return FAILED;
    }
  }
  // Preserve the user-specified input order from --input_shape.
  std::vector<std::string> data_name_order;
  for (auto &item : GetLocalOmgContext().user_input_dims) {
    data_name_order.push_back(item.first);
  }
  if (!AttrUtils::SetListStr(op_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, data_name_order)) {
    GELOGE(FAILED, "Failed to add user designate shape order attr on case node %s",
           op_desc->GetName().c_str());
    return FAILED;
  }
  GE_CHK_STATUS_RET(multibatch::StampDynamicType(op_desc), "Set dynamic type failed");
  // Build the surrounding root-graph topology around the Case node.
  GE_CHK_STATUS_RET(CreateIndexNode(graph), "Create index node failed");
  GE_CHK_STATUS_RET(CreateInputNode(graph), "Create input node failed");
  GE_CHK_STATUS_RET(CreateConstNode(graph), "Create const node failed");
  GE_CHK_STATUS_RET(CreateOutputNode(graph), "Create output node failed");
  return SUCCESS;
}
  171. ///
  172. /// @ingroup ge
  173. /// @brief Create index data node for root graph.
  174. /// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
  175. /// @param [in] NodePtr node: index data node.
  176. /// @return 0: SUCCESS / others: FAILED
  177. ///
  178. Status MultiBatchClonePass::CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &node) {
  179. const OpDescPtr data_desc = MakeShared<OpDesc>(kMultiBatchDataNode, DATA);
  180. if (data_desc == nullptr) {
  181. GELOGE(OUT_OF_MEMORY, "Create multi-batch data node failed");
  182. return FAILED;
  183. }
  184. GeTensorDesc data_tensor(GeShape({static_cast<int64_t>(batch_shapes_[0].size())}), FORMAT_ND, DT_INT32);
  185. if (data_desc->AddInputDesc(data_tensor) != GRAPH_SUCCESS) {
  186. GELOGE(FAILED, "Add input desc failed");
  187. return FAILED;
  188. }
  189. if (data_desc->AddOutputDesc(data_tensor) != GRAPH_SUCCESS) {
  190. GELOGE(FAILED, "Add output desc failed");
  191. return FAILED;
  192. }
  193. size_t data_index = all_data_nodes_.size();
  194. (void)AttrUtils::SetInt(data_desc, ATTR_NAME_INDEX, data_index);
  195. (void)AttrUtils::SetBool(data_desc, ATTR_INSERT_BY_MBATCH, true);
  196. node = graph->AddNode(data_desc);
  197. if (node == nullptr) {
  198. GELOGE(OUT_OF_MEMORY, "Create multi-batch data node failed");
  199. return OUT_OF_MEMORY;
  200. }
  201. return SUCCESS;
  202. }
///
/// @ingroup ge
/// @brief Create index const node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @param [out] NodePtr &node: created index const node.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateIndexConstNode(const ComputeGraphPtr &graph, NodePtr &node) {
  const OpDescPtr const_desc = MakeShared<OpDesc>(kMultiBatchConstNode, CONSTANT);
  if (const_desc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch const node failed");
    return FAILED;
  }
  // Flatten all batch profiles into one int32 table: batch_num * dims_per_batch
  // entries, row-major (profile 0 dims first). MapIndex matches the runtime
  // shape against this table. Assumes every profile has the same rank as
  // profile 0 (validated by CheckDynamicParams).
  int64_t count = batch_shapes_.size() * batch_shapes_[0].size();
  std::unique_ptr<int32_t[]> addr(new (std::nothrow) int32_t[count]);
  GE_CHECK_NOTNULL(addr);
  size_t i = 0;
  for (auto &batch_shape : batch_shapes_) {
    for (int64_t dim : batch_shape) {
      addr[i++] = static_cast<int32_t>(dim);
    }
  }
  GeTensorDesc const_tensor(GeShape({count}), FORMAT_ND, DT_INT32);
  GeTensor tensor(const_tensor);
  (void)tensor.SetData(reinterpret_cast<uint8_t *>(addr.get()), count * sizeof(int32_t));
  if (!AttrUtils::SetTensor(const_desc, ATTR_NAME_WEIGHTS, tensor)) {
    GELOGE(OUT_OF_MEMORY, "Failed to init tensor value for const %s", const_desc->GetName().c_str());
    return FAILED;
  }
  if (const_desc->AddOutputDesc(const_tensor) != GRAPH_SUCCESS) {
    GELOGE(OUT_OF_MEMORY, "Failed to add output desc for const node %s", const_desc->GetName().c_str());
    return FAILED;
  }
  node = graph->AddNode(const_desc);
  if (node == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch const node failed");
    return OUT_OF_MEMORY;
  }
  return SUCCESS;
}
  243. ///
  244. /// @ingroup ge
  245. /// @brief Create index node for root graph.
  246. /// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
  247. /// @return 0: SUCCESS / others: FAILED
  248. ///
  249. Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) {
  250. // Data --> MapIndex --> Case
  251. NodePtr data_node;
  252. GE_CHK_STATUS_RET(CreateIndexDataNode(graph, data_node), "Create data node failed");
  253. NodePtr const_node;
  254. GE_CHK_STATUS_RET(CreateIndexConstNode(graph, const_node), "Create const node failed");
  255. OpDescBuilder op_builder(kMultiBatchMapIndexNode, "MapIndex");
  256. op_builder.AddInput("x", data_node->GetOpDesc()->GetOutputDesc(0))
  257. .AddInput("data_seq", const_node->GetOpDesc()->GetOutputDesc(0))
  258. .AddOutput("y", GeTensorDesc(GeShape(), FORMAT_ND, DT_INT32));
  259. const OpDescPtr op_desc = op_builder.Build();
  260. if (op_desc == nullptr) {
  261. GELOGE(OUT_OF_MEMORY, "Create multi-batch index desc failed");
  262. return FAILED;
  263. }
  264. NodePtr index_node = graph->AddNode(op_desc);
  265. if (index_node == nullptr) {
  266. GELOGE(OUT_OF_MEMORY, "Create multi-batch index node failed");
  267. return OUT_OF_MEMORY;
  268. }
  269. if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), index_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
  270. GELOGE(FAILED, "Failed to add edge between node:%s to MapIndex:%s", data_node->GetName().c_str(),
  271. index_node->GetName().c_str());
  272. return FAILED;
  273. }
  274. if (GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), index_node->GetInDataAnchor(1)) != GRAPH_SUCCESS) {
  275. GELOGE(FAILED, "Failed to add edge between node:%s to MapIndex:%s", const_node->GetName().c_str(),
  276. index_node->GetName().c_str());
  277. return FAILED;
  278. }
  279. if (GraphUtils::AddEdge(index_node->GetOutDataAnchor(0), case_node_->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
  280. GELOGE(FAILED, "Failed to add edge between MapIndex:%s to Case:%s", index_node->GetName().c_str(),
  281. case_node_->GetName().c_str());
  282. return FAILED;
  283. }
  284. return SUCCESS;
  285. }
  286. ///
  287. /// @ingroup ge
  288. /// @brief Create input node for root graph.
  289. /// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
  290. /// @return 0: SUCCESS / others: FAILED
  291. ///
  292. Status MultiBatchClonePass::CreateInputNode(const ComputeGraphPtr &graph) {
  293. // Data --> Case
  294. std::vector<NodePtr> all_data_nodes;
  295. const size_t arg_index = kCaseArgIndex;
  296. for (size_t i = 0; i < all_data_nodes_.size(); ++i) {
  297. const auto &node = all_data_nodes_[i];
  298. const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc());
  299. if (op_desc == nullptr) {
  300. GELOGE(OUT_OF_MEMORY, "Create multi-batch Data node failed, name: %s", node->GetName().c_str());
  301. return FAILED;
  302. }
  303. if (GraphUtils::CopyTensorAttrs(op_desc, node) != GRAPH_SUCCESS) {
  304. return FAILED;
  305. }
  306. op_desc->SetName(node->GetName());
  307. const NodePtr &data = graph->AddNode(op_desc);
  308. GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str());
  309. if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(arg_index + i)) != GRAPH_SUCCESS) {
  310. GELOGE(FAILED, "Failed to add edge between Data:%s to Case:%s",
  311. data->GetName().c_str(), case_node_->GetName().c_str());
  312. return FAILED;
  313. }
  314. if (SetMaxShapeToData(data) != SUCCESS) {
  315. return FAILED;
  316. }
  317. all_data_nodes.emplace_back(data);
  318. }
  319. all_data_nodes_.swap(all_data_nodes);
  320. return SUCCESS;
  321. }
///
/// @ingroup ge
/// @brief Create Const node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateConstNode(const ComputeGraphPtr &graph) {
  // Clone each original Const into the root graph and wire it into the Case
  // node after all Data inputs: Const --> Case
  std::vector<NodePtr> all_const_nodes;
  const size_t arg_index = kCaseArgIndex + all_data_nodes_.size();
  for (size_t i = 0; i < all_const_nodes_.size(); ++i) {
    const auto &node = all_const_nodes_[i];
    const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc());
    if (op_desc == nullptr) {
      GELOGE(OUT_OF_MEMORY, "Create multi-batch Const node failed, name: %s", node->GetName().c_str());
      return FAILED;
    }
    op_desc->SetName(node->GetName());
    if (GraphUtils::CopyTensorAttrs(op_desc, node) != GRAPH_SUCCESS) {
      return FAILED;
    }
    const NodePtr &data = graph->AddNode(op_desc);
    GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str());
    if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(arg_index + i)) != GRAPH_SUCCESS) {
      GELOGE(FAILED, "Failed to add edge between Const:%s to Case:%s",
             data->GetName().c_str(), case_node_->GetName().c_str());
      return FAILED;
    }
    all_const_nodes.emplace_back(data);
  }
  // At this point all_const_nodes_ still holds the ORIGINAL (branch) Const
  // nodes: convert them to Data so they become subgraph inputs fed by the
  // root-graph Consts created above.
  size_t data_index = all_data_nodes_.size();
  for (size_t i = 0; i < all_const_nodes_.size(); ++i, ++data_index) {  // Trans subgraph Const to Data.
    const OpDescPtr &op_desc = all_const_nodes_[i]->GetOpDesc();
    op_desc->SetType(DATA);
    (void)op_desc->DelAttr(ATTR_NAME_WEIGHTS);  // Delete weight.
    // Const no InputDesc, Data need InputDesc.
    (void)op_desc->AddInputDesc(op_desc->GetOutputDesc(kDataOutIndex));
    (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index);
    (void)NodeUtils::AppendInputAnchor(all_const_nodes_[i], 1);
  }
  // Swap in the new root-graph Const nodes.
  all_const_nodes_.swap(all_const_nodes);
  return SUCCESS;
}
  365. ///
  366. /// @ingroup ge
  367. /// @brief Create output node for root graph.
  368. /// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
  369. /// @return 0: SUCCESS / others: FAILED
  370. ///
  371. Status MultiBatchClonePass::CreateOutputNode(const ComputeGraphPtr &graph) {
  372. const auto &output = all_output_nodes_[0];
  373. const OpDescPtr op_desc = AttrUtils::CopyOpDesc(output->GetOpDesc());
  374. if (op_desc == nullptr) {
  375. GELOGE(OUT_OF_MEMORY, "Create multi-batch output node failed");
  376. return FAILED;
  377. }
  378. if (GraphUtils::CopyTensorAttrs(op_desc, output) != GRAPH_SUCCESS) {
  379. return FAILED;
  380. }
  381. op_desc->SetName(output->GetName());
  382. const NodePtr &node = graph->AddNode(op_desc);
  383. GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str());
  384. for (size_t i = 0; i < case_node_->GetAllOutDataAnchorsSize(); ++i) {
  385. const auto it = direct_output_.find(i);
  386. if (it == direct_output_.end()) {
  387. if (GraphUtils::AddEdge(case_node_->GetOutDataAnchor(i), node->GetInDataAnchor(i)) != GRAPH_SUCCESS) {
  388. GELOGE(FAILED, "Failed to add edge between Case:%s to NetOutput:%s",
  389. case_node_->GetName().c_str(), node->GetName().c_str());
  390. return FAILED;
  391. }
  392. } else {
  393. const auto data_node = graph->FindNode(it->second);
  394. if (data_node == nullptr) {
  395. GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "Data node:%s not found", it->second.c_str());
  396. return GE_GRAPH_GRAPH_NODE_NULL;
  397. }
  398. if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(kDataOutIndex), node->GetInDataAnchor(i)) != GRAPH_SUCCESS) {
  399. GELOGE(FAILED, "Failed to add edge between Data:%s to NetOutput:%s",
  400. data_node->GetName().c_str(), node->GetName().c_str());
  401. return FAILED;
  402. }
  403. }
  404. }
  405. all_output_nodes_.clear();
  406. all_output_nodes_.emplace_back(node);
  407. return SUCCESS;
  408. }
///
/// @ingroup ge
/// @brief Set max shape to Data node in root graph.
/// @param [in] const NodePtr &data: data in Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) {
  auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
  auto data_name = data->GetName();
  const auto &dims = data_shape.GetDims();
  // Fully static shape: nothing to resolve.
  if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) {
    return SUCCESS;
  }
  // Keep the original (dynamic) dims for later recovery.
  (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims());

  // Record, per gear, a serialized "format:dtype:name:size:rank:dims" string.
  GeTensorDesc tensor(NodeUtils::GetOutputDesc(*data, kDataOutIndex));
  std::vector<std::string> input_dims_str;
  for (size_t i = 0; i < batch_shapes_.size(); ++i) {
    auto shape = data_shape;
    auto ret = multibatch::CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape);
    if (ret != SUCCESS) {
      GELOGE(ret, "Failed to calculate the shape for data node %s, the shape may not match", data->GetName().c_str());
      return ret;
    }
    tensor.SetShape(shape);
    int64_t tensor_size = 0;
    (void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size);
    string input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" +
                       TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + data->GetName() + ":" +
                       std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" +
                       formats::JoinToString(tensor.GetShape().GetDims());
    input_dims_str.emplace_back(input_str);
  }
  (void)AttrUtils::SetListStr(data->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str);

  // Pick the gear with the largest element count; the root-graph Data node is
  // shaped to the max so every branch fits. NOTE(review): assumes dims are
  // positive (validated by CheckDynamicParams) — dim == 0 here would divide
  // by zero in the overflow guard; confirm upstream validation.
  size_t max_shape_index = 0;
  int64_t max_size = 0;
  for (size_t i = 0; i < batch_shapes_.size(); ++i) {
    int64_t size = 1;
    for (auto dim : data_to_dynamic_info_.at(data_name).at(i)) {
      // Overflow guard for the running product.
      if (INT64_MAX / dim < size) {
        GELOGE(PARAM_INVALID, "The shape %s size overflow",
               formats::ShapeToString(data_to_dynamic_info_.at(data_name).at(i)).c_str());
        return PARAM_INVALID;
      }
      size *= dim;
    }
    if (size > max_size) {
      max_size = size;
      max_shape_index = i;
    }
  }
  return SetShapeToData(data_to_dynamic_info_.at(data_name).at(max_shape_index), data, data_shape);
}
///
/// @ingroup ge
/// @brief Update Data node in Subgraph.
/// @param [in] const NodePtr &data: data in Subgraph.
/// @param [in] size_t index: The batch index.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t index) {
  int node_index = -1;
  if (!AttrUtils::GetInt(data->GetOpDesc(), ATTR_NAME_INDEX, node_index)) {
    GELOGE(FAILED, "Failed to get index from data[%s]", data->GetName().c_str());
    return FAILED;
  }
  // +1 skips the Case node's branch-index input: subgraph Data i maps to
  // Case parent input i + 1 (see kCaseArgIndex).
  int parent_index = node_index + 1;
  if (!AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
    GELOGE(FAILED, "Failed to set parent index for node %s", data->GetName().c_str());
    return FAILED;
  }
  auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
  const auto &dims = data_shape.GetDims();
  // Static shape: no per-batch resolution needed.
  if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) {
    return SUCCESS;
  }
  (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims());
  // Cloned nodes are named "<parent>_ascend_mbatch_batch_<i>"; strip the
  // postfix to recover the original Data name used as the lookup key.
  auto data_name = data->GetName();
  size_t pos = data_name.find(kMultiBatchNodePostfix);
  if (pos == string::npos) {
    GELOGE(FAILED, "Cannot find key string [%s] of multi-batch in name of virtual input node, node name: %s.",
           kMultiBatchNodePostfix.c_str(), data_name.c_str());
    return FAILED;
  }
  auto parent_name = data_name.substr(0, pos);
  return SetShapeToData(data_to_dynamic_info_.at(parent_name).at(index), data, data_shape);
}
  495. ///
  496. /// @ingroup ge
  497. /// @brief Set max shape to Data node in root graph.
  498. /// @param [in] const std::vector<int64_t> &shapes: dims of shape.
  499. /// @param [in] const NodePtr &data: data in Root/Case graph.
  500. /// @param [in] GeShape &data_shape: dims of data node.
  501. /// @return 0: SUCCESS / others: FAILED
  502. ///
  503. Status MultiBatchClonePass::SetShapeToData(const vector<int64_t> &shapes, const NodePtr &data, GeShape &data_shape) {
  504. // must not be error, the calc result has been checked in function InsertSwitchNForData
  505. if (multibatch::CalcShape(shapes, data_shape) != SUCCESS) {
  506. return INTERNAL_ERROR;
  507. }
  508. if (NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape) != GRAPH_SUCCESS) {
  509. GELOGE(INTERNAL_ERROR, "Failed to update input shape for data %s", data->GetName().c_str());
  510. return INTERNAL_ERROR;
  511. }
  512. if (NodeUtils::UpdateOutputShape(*data, kDataOutIndex, data_shape) != GRAPH_SUCCESS) {
  513. GELOGE(INTERNAL_ERROR, "Failed to update output shape for data %s", data->GetName().c_str());
  514. return INTERNAL_ERROR;
  515. }
  516. GELOGI("Update %s input/output shape to %s", data->GetName().c_str(), formats::ShapeToString(data_shape).c_str());
  517. return SUCCESS;
  518. }
///
/// @ingroup ge
/// @brief Create one Case branch subgraph per batch profile.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @param [in] const ComputeGraphPtr &branch: original graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const ComputeGraphPtr &branch) {
  const auto &op_desc = case_node_->GetOpDesc();
  for (size_t i = 0; i < batch_shapes_.size(); ++i) {
    std::vector<NodePtr> input_nodes;
    std::vector<NodePtr> output_nodes;
    const std::string postfix = kMultiBatchNodePostfix + std::to_string(i);
    // Batch 0 reuses the original graph directly; other batches are clones.
    ComputeGraphPtr subgraph = (i == 0) ? branch : GraphUtils::CloneGraph(branch, postfix, input_nodes, output_nodes);
    if (subgraph == nullptr) {
      GELOGE(FAILED, "Create multi-batch case node failed");
      return FAILED;
    }
    subgraph->SetName("Batch_" + std::to_string(i));
    subgraph->SetParentNode(case_node_);
    subgraph->SetParentGraph(graph);
    graph->AddSubgraph(subgraph->GetName(), subgraph);
    all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT);
    GE_CHK_STATUS_RET(UpdateSubgraphOutput(all_branch_output_[subgraph]),
                      "Update %s failed", all_branch_output_[subgraph]->GetName().c_str());
    // Register the subgraph instance under the Case node's "branches" IR name.
    const string key_name = "branches" + std::to_string(i);
    op_desc->AddSubgraphName(key_name);
    op_desc->SetSubgraphInstanceName(i, subgraph->GetName());
    // input_nodes is empty for i == 0 (no clone): batch 0 Data nodes are
    // handled in the loop below instead.
    for (const auto &data : input_nodes) {
      GE_CHK_STATUS_RET(UpdateSubgraphData(data, i), "Update %s failed", subgraph->GetName().c_str());
    }
  }
  // Origninal graph take as first subgraph, update node name.
  // (Names get the "_ascend_mbatch_batch_0" postfix so UpdateSubgraphData can
  // recover the parent name; note the inner op_desc shadows the Case desc.)
  for (const auto &n : branch->GetDirectNode()) {
    const auto &op_desc = n->GetOpDesc();
    op_desc->SetName(n->GetName() + kMultiBatchNodePostfix + "0");
    if (n->GetType() == DATA) {
      GE_CHK_STATUS_RET(UpdateSubgraphData(n, 0), "Update %s failed", branch->GetName().c_str());
    }
  }
  return SUCCESS;
}
  561. ///
  562. /// @ingroup ge
  563. /// @brief Update output_node in Subgraph.
  564. /// @param [in] const NodePtr &output_node: output_node in Subgraph.
  565. /// @return 0: SUCCESS / others: FAILED
  566. ///
  567. Status MultiBatchClonePass::UpdateSubgraphOutput(const NodePtr &output_node) {
  568. const auto &op_desc = output_node->GetOpDesc();
  569. GE_CHECK_NOTNULL(op_desc);
  570. for (size_t index = 0; index < op_desc->GetInputsSize(); ++index) {
  571. GeTensorDescPtr tensor = op_desc->MutableInputDesc(index);
  572. GE_CHECK_NOTNULL(tensor);
  573. if (!AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) {
  574. GELOGE(FAILED, "Failed to set parent index for node %s", output_node->GetName().c_str());
  575. return FAILED;
  576. }
  577. }
  578. return SUCCESS;
  579. }
  580. ///
  581. /// @ingroup ge
  582. /// @brief Remove subgraph suspend output anchor.
  583. /// @param [in] ComputeGraphPtr &graph: Parent compute graph.
  584. /// @return 0: SUCCESS / others: FAILED
  585. ///
  586. Status MultiBatchClonePass::PruneDirectOutput(const ComputeGraphPtr &graph) {
  587. const auto &func_desc = case_node_->GetOpDesc();
  588. uint32_t unused_num = 0;
  589. uint32_t output_num = func_desc->GetOutputsSize();
  590. for (size_t i = 0; i < output_num; ++i) {
  591. bool is_unused_tensor = true;
  592. for (const auto &item : all_branch_output_) {
  593. const auto &netoutput = item.second;
  594. GE_CHECK_NOTNULL(netoutput);
  595. const auto in_anchor = netoutput->GetInDataAnchor(i);
  596. if (in_anchor->GetPeerOutAnchor() != nullptr) {
  597. is_unused_tensor = false;
  598. break;
  599. }
  600. }
  601. if (is_unused_tensor) {
  602. unused_num++;
  603. continue;
  604. }
  605. GE_CHK_STATUS_RET(UpdateOutputTensor(i, unused_num), "Graph:%s Update output failed", graph->GetName().c_str());
  606. }
  607. if (unused_num == 0) {
  608. return SUCCESS;
  609. }
  610. GE_CHK_STATUS_RET(NodeUtils::RemoveOutputAnchor(case_node_, output_num - unused_num), "Remove output failed");
  611. for (const auto &item : all_branch_output_) {
  612. GE_CHK_STATUS_RET(NodeUtils::RemoveInputAnchor(item.second, output_num - unused_num), "Remove input failed");
  613. }
  614. return SUCCESS;
  615. }
  616. ///
  617. /// @ingroup ge
  618. /// @brief Update subgraph suspend output tensor.
  619. /// @param [in] parent_index: parent index for check.
  620. /// @param [in] unused_num: total unused tensor.
  621. /// @return 0: SUCCESS / others: FAILED
  622. ///
  623. Status MultiBatchClonePass::UpdateOutputTensor(uint32_t parent_index, uint32_t unused_num) {
  624. if (unused_num == 0) {
  625. return SUCCESS;
  626. }
  627. uint32_t update_index = parent_index - unused_num;
  628. for (const auto &item : all_branch_output_) {
  629. const auto &node = item.second;
  630. const auto &new_anchor = node->GetInDataAnchor(update_index);
  631. const auto &old_anchor = node->GetInDataAnchor(parent_index);
  632. const auto &out_anchor = old_anchor->GetPeerOutAnchor();
  633. const auto &out_node = out_anchor->GetOwnerNode();
  634. const auto &op_desc = node->GetOpDesc();
  635. (void)op_desc->UpdateInputDesc(update_index, op_desc->GetInputDesc(parent_index));
  636. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_anchor, new_anchor), "Add edge failed");
  637. GELOGI("Add edge success, func node: %s, node: %s, parent index: %u, update index: %u",
  638. case_node_->GetName().c_str(), out_node->GetName().c_str(), parent_index, update_index);
  639. GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, old_anchor), "Remove edge failed");
  640. GELOGI("Remove edge success, func node: %s, node: %s", case_node_->GetName().c_str(), out_node->GetName().c_str());
  641. }
  642. const auto &new_anchor = case_node_->GetOutDataAnchor(update_index);
  643. const auto &old_anchor = case_node_->GetOutDataAnchor(parent_index);
  644. for (const auto in_anchor : old_anchor->GetPeerInDataAnchors()) {
  645. const auto &in_node = in_anchor->GetOwnerNode();
  646. GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(old_anchor, in_anchor), "Remove edge failed");
  647. GELOGI("Remove edge success, func node: %s, node: %s", case_node_->GetName().c_str(), in_node->GetName().c_str());
  648. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(new_anchor, in_anchor), "Add edge failed");
  649. GELOGI("Add edge success, func node: %s, node: %s, parent index: %u, update index: %u",
  650. case_node_->GetName().c_str(), in_node->GetName().c_str(), parent_index, update_index);
  651. }
  652. return SUCCESS;
  653. }
  654. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示