
multi_batch_clone_pass.cc

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/passes/multi_batch_clone_pass.h"

#include "common/formats/utils/formats_trans_utils.h"
#include "common/ge/ge_util.h"
#include "graph/common/local_context.h"
#include "graph/preprocess/multi_batch_options.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "register/op_registry.h"

namespace ge {
namespace {
constexpr uint8_t kDataInIndex = 0;
constexpr uint8_t kDataOutIndex = 0;
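// Case node input layout: index 0 is the branch index produced by MapIndex;
// the original Data nodes start at kCaseArgIndex, followed by the Const nodes.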
constexpr uint8_t kCaseArgIndex = 1;

const std::string kMultiBatchCaseNode = "ascend_mbatch_shape_case";
const std::string kMultiBatchDataNode = "ascend_mbatch_shape_data";
const std::string kMultiBatchConstNode = "ascend_mbatch_shape_const";
const std::string kMultiBatchMapIndexNode = "ascend_mbatch_shape_mapindex";
const std::string kMultiBatchNodePostfix = "_ascend_mbatch_batch_";
}  // namespace

Status MultiBatchClonePass::Run(ComputeGraphPtr graph) {
  if (graph->GetParentGraph() != nullptr) {
    GELOGD("Subgraph %s skip the MultiBatchClonePass", graph->GetName().c_str());
    return SUCCESS;
  }

  if (!multibatch::InitDynamicParams(batch_shapes_)) {
    GELOGD("There is no multi-batch options, no need clone multi-batch graph");
    return SUCCESS;
  }

  GELOGD("Begin to run Multi-batch clone on graph: %s", graph->GetName().c_str());
  GE_CHK_STATUS_RET(multibatch::CheckDynamicParams(batch_shapes_), "Invalid multi-batch param");
  if (CollectIoNodes(graph) != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Collect input output nodes failed");
    return INTERNAL_ERROR;
  }

  // Parse each data node's dynamic info from the ATC parameter --input_shape.
  if (multibatch::ParserDataToDynmaicInfo(batch_shapes_, GetLocalOmgContext().user_input_dims,
                                          data_to_dynamic_info_) != SUCCESS) {
    GELOGE(PARAM_INVALID, "Parse each data's own dynamic info failed");
    return PARAM_INVALID;
  }

  (void)AttrUtils::GetStr(graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_);
  ComputeGraphPtr branch = MakeShared<ComputeGraph>(graph->GetName());
  if (branch == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch graph failed");
    return OUT_OF_MEMORY;
  }
  (void)AttrUtils::SetStr(branch, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_);
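
  // Swap the original graph's contents into "branch"; "graph" is now empty and
  // will be rebuilt below as the root graph that holds the Case node.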
  graph->InValid();  // Will modify, need topological again.
  graph->Swap(*branch);
  if (CreateRootGraph(graph) != SUCCESS) {
    return FAILED;
  }

  if (CreateSubgraphs(graph, branch) != SUCCESS) {
    return FAILED;
  }

  GE_CHK_STATUS_RET(PruneDirectOutput(graph), "Prune direct output failed");
  GELOGD("MultiBatchClonePass Leave");
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Collect input and output nodes from the original graph.
/// @param [in] const ComputeGraphPtr &graph: original graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) {
  for (const auto &node : graph->GetDirectNode()) {
    if (node->GetType() == DATA) {
      all_data_nodes_.emplace_back(node);
    } else if (node->GetType() == CONSTANT) {
      all_const_nodes_.emplace_back(node);
    } else if (node->GetType() == NETOUTPUT) {
      all_output_nodes_.emplace_back(node);
    }

    // If the node is recorded as an input/output node of the graph, delete the record.
    (void)graph->RemoveInputNode(node);
    (void)graph->RemoveOutputNode(node);
  }

  if (all_data_nodes_.empty() || all_output_nodes_.size() != 1) {
    GELOGE(FAILED, "data nodes: %zu, output nodes: %zu", all_data_nodes_.size(), all_output_nodes_.size());
    return FAILED;
  }

  int64_t data_index = 0;
  for (size_t i = 0; i < all_data_nodes_.size(); ++i) {
    const auto &op_desc = all_data_nodes_[i]->GetOpDesc();
    if (!AttrUtils::GetInt(op_desc, ATTR_NAME_INDEX, data_index)) {
      (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, i);
    }
  }

  const auto &output = all_output_nodes_[0];
  for (size_t i = 0; i < output->GetAllInDataAnchorsSize(); ++i) {
    const auto in_anchor = output->GetInDataAnchor(i);
    const auto out_anchor = in_anchor->GetPeerOutAnchor();
    const auto data_node = out_anchor->GetOwnerNode();
    if (data_node->GetType() == DATA) {
      direct_output_[i] = data_node->GetName();
      GE_CHK_GRAPH_STATUS_RET(
          GraphUtils::RemoveEdge(data_node->GetOutDataAnchor(kDataOutIndex), output->GetInDataAnchor(i)),
          "Remove edge failed");
    }
  }

  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Create nodes for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateRootGraph(const ComputeGraphPtr &graph) {
  uint32_t input_num = all_data_nodes_.size() + all_const_nodes_.size();
  uint32_t output_num = all_output_nodes_[0]->GetAllInDataAnchorsSize();

  OpDescBuilder op_builder(kMultiBatchCaseNode, CASE);
  op_builder.AddInput("branch_index").AddDynamicInput("input", input_num).AddDynamicOutput("output", output_num);
  const OpDescPtr op_desc = op_builder.Build();
  if (op_desc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch case desc failed");
    return OUT_OF_MEMORY;
  }

  op_desc->RegisterSubgraphIrName("branches", kDynamic);
  case_node_ = graph->AddNode(op_desc);
  if (case_node_ == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch case node failed");
    return OUT_OF_MEMORY;
  }

  uint32_t batch_num = static_cast<uint32_t>(batch_shapes_.size());
  if (!AttrUtils::SetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) {
    GELOGE(FAILED, "Set attr ATTR_NAME_BATCH_NUM failed, Case: %s.", op_desc->GetName().c_str());
    return FAILED;
  }

  for (uint32_t i = 0; i < batch_num; i++) {
    const std::string &attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i);
    if (!AttrUtils::SetListInt(op_desc, attr_name, batch_shapes_[i])) {
      GELOGE(FAILED, "Set attr ATTR_NAME_PRED_VALUE failed, Case: %s.", op_desc->GetName().c_str());
      return FAILED;
    }
  }

  std::vector<std::string> data_name_order;
  for (auto &item : GetLocalOmgContext().user_input_dims) {
    data_name_order.push_back(item.first);
  }
  if (!AttrUtils::SetListStr(op_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, data_name_order)) {
    GELOGE(FAILED, "Failed to add user designate shape order attr on case node %s", op_desc->GetName().c_str());
    return FAILED;
  }

  GE_CHK_STATUS_RET(multibatch::StampDynamicType(op_desc), "Set dynamic type failed");

  GE_CHK_STATUS_RET(CreateIndexNode(graph), "Create index node failed");
  GE_CHK_STATUS_RET(CreateInputNode(graph), "Create input node failed");
  GE_CHK_STATUS_RET(CreateConstNode(graph), "Create const node failed");
  GE_CHK_STATUS_RET(CreateOutputNode(graph), "Create output node failed");
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Create index data node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @param [out] NodePtr &node: index data node.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &node) {
  const OpDescPtr data_desc = MakeShared<OpDesc>(kMultiBatchDataNode, DATA);
  if (data_desc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch data node failed");
    return FAILED;
  }

  GeTensorDesc data_tensor(GeShape({static_cast<int64_t>(batch_shapes_[0].size())}), FORMAT_ND, DT_INT32);
  if (data_desc->AddInputDesc(data_tensor) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Add input desc failed");
    return FAILED;
  }
  if (data_desc->AddOutputDesc(data_tensor) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Add output desc failed");
    return FAILED;
  }

  size_t data_index = all_data_nodes_.size();
  (void)AttrUtils::SetInt(data_desc, ATTR_NAME_INDEX, data_index);
  (void)AttrUtils::SetBool(data_desc, ATTR_INSERT_BY_MBATCH, true);

  node = graph->AddNode(data_desc);
  if (node == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch data node failed");
    return OUT_OF_MEMORY;
  }
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Create index const node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @param [out] NodePtr &node: index const node.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateIndexConstNode(const ComputeGraphPtr &graph, NodePtr &node) {
  const OpDescPtr const_desc = MakeShared<OpDesc>(kMultiBatchConstNode, CONSTANT);
  if (const_desc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch const node failed");
    return FAILED;
  }

  int64_t count = batch_shapes_.size() * batch_shapes_[0].size();
  std::unique_ptr<int32_t[]> addr(new (std::nothrow) int32_t[count]);
  GE_CHECK_NOTNULL(addr);

  size_t i = 0;
  for (auto &batch_shape : batch_shapes_) {
    for (int64_t dim : batch_shape) {
      addr[i++] = static_cast<int32_t>(dim);
    }
  }

  GeTensorDesc const_tensor(GeShape({count}), FORMAT_ND, DT_INT32);
  GeTensor tensor(const_tensor);
  (void)tensor.SetData(reinterpret_cast<uint8_t *>(addr.get()), count * sizeof(int32_t));
  if (!AttrUtils::SetTensor(const_desc, ATTR_NAME_WEIGHTS, tensor)) {
    GELOGE(OUT_OF_MEMORY, "Failed to init tensor value for const %s", const_desc->GetName().c_str());
    return FAILED;
  }

  if (const_desc->AddOutputDesc(const_tensor) != GRAPH_SUCCESS) {
    GELOGE(OUT_OF_MEMORY, "Failed to add output desc for const node %s", const_desc->GetName().c_str());
    return FAILED;
  }

  node = graph->AddNode(const_desc);
  if (node == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch const node failed");
    return OUT_OF_MEMORY;
  }
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Create index node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) {
  // Data --> MapIndex --> Case
  NodePtr data_node;
  GE_CHK_STATUS_RET(CreateIndexDataNode(graph, data_node), "Create data node failed");

  NodePtr const_node;
  GE_CHK_STATUS_RET(CreateIndexConstNode(graph, const_node), "Create const node failed");

  OpDescBuilder op_builder(kMultiBatchMapIndexNode, "MapIndex");
  op_builder.AddInput("x", data_node->GetOpDesc()->GetOutputDesc(0))
      .AddInput("data_seq", const_node->GetOpDesc()->GetOutputDesc(0))
      .AddOutput("y", GeTensorDesc(GeShape(), FORMAT_ND, DT_INT32));

  const OpDescPtr op_desc = op_builder.Build();
  if (op_desc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch index desc failed");
    return FAILED;
  }
  NodePtr index_node = graph->AddNode(op_desc);
  if (index_node == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch index node failed");
    return OUT_OF_MEMORY;
  }

  if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), index_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Failed to add edge between node:%s to MapIndex:%s", data_node->GetName().c_str(),
           index_node->GetName().c_str());
    return FAILED;
  }
  if (GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), index_node->GetInDataAnchor(1)) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Failed to add edge between node:%s to MapIndex:%s", const_node->GetName().c_str(),
           index_node->GetName().c_str());
    return FAILED;
  }
  if (GraphUtils::AddEdge(index_node->GetOutDataAnchor(0), case_node_->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
    GELOGE(FAILED, "Failed to add edge between MapIndex:%s to Case:%s", index_node->GetName().c_str(),
           case_node_->GetName().c_str());
    return FAILED;
  }

  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Create input node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateInputNode(const ComputeGraphPtr &graph) {
  // Data --> Case
  std::vector<NodePtr> all_data_nodes;
  const size_t arg_index = kCaseArgIndex;
  for (size_t i = 0; i < all_data_nodes_.size(); ++i) {
    const auto &node = all_data_nodes_[i];
    const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc());
    if (op_desc == nullptr) {
      GELOGE(OUT_OF_MEMORY, "Create multi-batch Data node failed, name: %s", node->GetName().c_str());
      return FAILED;
    }

    if (GraphUtils::CopyTensorAttrs(op_desc, node) != GRAPH_SUCCESS) {
      return FAILED;
    }

    op_desc->SetName(node->GetName());
    const NodePtr &data = graph->AddNode(op_desc);
    GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str());
    if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(arg_index + i)) != GRAPH_SUCCESS) {
      GELOGE(FAILED, "Failed to add edge between Data:%s to Case:%s",
             data->GetName().c_str(), case_node_->GetName().c_str());
      return FAILED;
    }

    if (SetMaxShapeToData(data) != SUCCESS) {
      return FAILED;
    }
    all_data_nodes.emplace_back(data);
  }

  all_data_nodes_.swap(all_data_nodes);
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Create Const node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateConstNode(const ComputeGraphPtr &graph) {
  // Const --> Case
  std::vector<NodePtr> all_const_nodes;
  const size_t arg_index = kCaseArgIndex + all_data_nodes_.size();
  for (size_t i = 0; i < all_const_nodes_.size(); ++i) {
    const auto &node = all_const_nodes_[i];
    const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc());
    if (op_desc == nullptr) {
      GELOGE(OUT_OF_MEMORY, "Create multi-batch Const node failed, name: %s", node->GetName().c_str());
      return FAILED;
    }

    op_desc->SetName(node->GetName());
    if (GraphUtils::CopyTensorAttrs(op_desc, node) != GRAPH_SUCCESS) {
      return FAILED;
    }

    const NodePtr &data = graph->AddNode(op_desc);
    GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str());
    if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(arg_index + i)) != GRAPH_SUCCESS) {
      GELOGE(FAILED, "Failed to add edge between Const:%s to Case:%s",
             data->GetName().c_str(), case_node_->GetName().c_str());
      return FAILED;
    }
    all_const_nodes.emplace_back(data);
  }
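
  // The original Const nodes now live in the branch graph; turn them into Data
  // nodes so each subgraph receives the constant value through the Case input.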
  size_t data_index = all_data_nodes_.size();
  for (size_t i = 0; i < all_const_nodes_.size(); ++i, ++data_index) {  // Transform subgraph Const to Data.
    const OpDescPtr &op_desc = all_const_nodes_[i]->GetOpDesc();
    op_desc->SetType(DATA);
    (void)op_desc->DelAttr(ATTR_NAME_WEIGHTS);  // Delete the weight.

    // Const has no InputDesc, but Data needs one.
    (void)op_desc->AddInputDesc(op_desc->GetOutputDesc(kDataOutIndex));
    (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index);
    (void)NodeUtils::AppendInputAnchor(all_const_nodes_[i], 1);
  }

  all_const_nodes_.swap(all_const_nodes);
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Create output node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateOutputNode(const ComputeGraphPtr &graph) {
  const auto &output = all_output_nodes_[0];
  const OpDescPtr op_desc = AttrUtils::CopyOpDesc(output->GetOpDesc());
  if (op_desc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Create multi-batch output node failed");
    return FAILED;
  }

  if (GraphUtils::CopyTensorAttrs(op_desc, output) != GRAPH_SUCCESS) {
    return FAILED;
  }

  op_desc->SetName(output->GetName());
  const NodePtr &node = graph->AddNode(op_desc);
  GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str());

  for (size_t i = 0; i < case_node_->GetAllOutDataAnchorsSize(); ++i) {
    const auto it = direct_output_.find(i);
    if (it == direct_output_.end()) {
      if (GraphUtils::AddEdge(case_node_->GetOutDataAnchor(i), node->GetInDataAnchor(i)) != GRAPH_SUCCESS) {
        GELOGE(FAILED, "Failed to add edge between Case:%s to NetOutput:%s",
               case_node_->GetName().c_str(), node->GetName().c_str());
        return FAILED;
      }
    } else {
      const auto data_node = graph->FindNode(it->second);
      if (data_node == nullptr) {
        GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "Data node:%s not found", it->second.c_str());
        return GE_GRAPH_GRAPH_NODE_NULL;
      }
      if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(kDataOutIndex), node->GetInDataAnchor(i)) != GRAPH_SUCCESS) {
        GELOGE(FAILED, "Failed to add edge between Data:%s to NetOutput:%s",
               data_node->GetName().c_str(), node->GetName().c_str());
        return FAILED;
      }
    }
  }

  all_output_nodes_.clear();
  all_output_nodes_.emplace_back(node);
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Set max shape to Data node in root graph.
/// @param [in] const NodePtr &data: data in Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) {
  auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
  auto data_name = data->GetName();
  const auto &dims = data_shape.GetDims();
  if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) {
    return SUCCESS;
  }

  (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims());
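
  // Pick the batch shape with the largest element count and use it as the
  // static (max) shape of this Data node in the root graph.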
  size_t max_shape_index = 0;
  int64_t max_size = 0;
  for (size_t i = 0; i < batch_shapes_.size(); ++i) {
    int64_t size = 1;
    for (auto dim : data_to_dynamic_info_.at(data_name).at(i)) {
      if (INT64_MAX / dim < size) {
        GELOGE(PARAM_INVALID, "The shape %s size overflow",
               formats::ShapeToString(data_to_dynamic_info_.at(data_name).at(i)).c_str());
        return PARAM_INVALID;
      }
      size *= dim;
    }
    if (size > max_size) {
      max_size = size;
      max_shape_index = i;
    }
  }

  return SetShapeToData(data_to_dynamic_info_.at(data_name).at(max_shape_index), data, data_shape);
}

///
/// @ingroup ge
/// @brief Set shape to Data node in branch.
/// @param [in] const NodePtr &data: data in branch.
/// @param [in] size_t index: The batch index.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::UpdateShapeToData(const NodePtr &data, size_t index) {
  auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
  const auto &dims = data_shape.GetDims();
  if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) {
    return SUCCESS;
  }

  (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims());

  auto data_name = data->GetName();
  size_t pos = data_name.find(kMultiBatchNodePostfix);
  if (pos == string::npos) {
    GELOGE(FAILED, "Cannot find key string [%s] of multi-batch in name of virtual input node, node name: %s.",
           kMultiBatchNodePostfix.c_str(), data_name.c_str());
    return FAILED;
  }

  auto parent_name = data_name.substr(0, pos);
  return SetShapeToData(data_to_dynamic_info_.at(parent_name).at(index), data, data_shape);
}

///
/// @ingroup ge
/// @brief Set shape to Data node in root graph.
/// @param [in] const std::vector<int64_t> &shapes: dims of shape.
/// @param [in] const NodePtr &data: data in Root/Case graph.
/// @param [in] GeShape &data_shape: dims of data node.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::SetShapeToData(const vector<int64_t> &shapes, const NodePtr &data, GeShape &data_shape) {
  // Must not fail here: the calculated result has already been checked in InsertSwitchNForData.
  if (multibatch::CalcShape(shapes, data_shape) != SUCCESS) {
    return INTERNAL_ERROR;
  }

  if (NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape) != GRAPH_SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Failed to update input shape for data %s", data->GetName().c_str());
    return INTERNAL_ERROR;
  }
  if (NodeUtils::UpdateOutputShape(*data, kDataOutIndex, data_shape) != GRAPH_SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Failed to update output shape for data %s", data->GetName().c_str());
    return INTERNAL_ERROR;
  }

  GELOGI("Update %s input/output shape to %s", data->GetName().c_str(), formats::ShapeToString(data_shape).c_str());
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Create batch subgraphs for the Case node from the original graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @param [in] const ComputeGraphPtr &branch: original graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const ComputeGraphPtr &branch) {
  const auto &op_desc = case_node_->GetOpDesc();
  for (size_t i = 0; i < batch_shapes_.size(); ++i) {
    std::vector<NodePtr> input_nodes;
    std::vector<NodePtr> output_nodes;
    const std::string postfix = kMultiBatchNodePostfix + std::to_string(i);
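    // Batch 0 reuses the original graph directly; every other batch gets a
    // clone whose node names carry the batch postfix.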
    ComputeGraphPtr subgraph = (i == 0) ? branch : GraphUtils::CloneGraph(branch, postfix, input_nodes, output_nodes);
    if (subgraph == nullptr) {
      GELOGE(FAILED, "Create multi-batch case node failed");
      return FAILED;
    }

    subgraph->SetName("Batch_" + std::to_string(i));
    subgraph->SetParentNode(case_node_);
    subgraph->SetParentGraph(graph);
    graph->AddSubgraph(subgraph->GetName(), subgraph);
    all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT);

    const string key_name = "branches" + std::to_string(i);
    op_desc->AddSubgraphName(key_name);
    op_desc->SetSubgraphInstanceName(i, subgraph->GetName());

    for (const auto &data : input_nodes) {
      GE_CHK_STATUS_RET(UpdateShapeToData(data, i), "Update %s failed", subgraph->GetName().c_str());
    }
  }

  // The original graph is taken as the first subgraph; update its node names.
  for (const auto &n : branch->GetDirectNode()) {
    const auto &op_desc = n->GetOpDesc();
    op_desc->SetName(n->GetName() + kMultiBatchNodePostfix + "0");
    if (n->GetType() == DATA) {
      GE_CHK_STATUS_RET(UpdateShapeToData(n, 0), "Update %s failed", branch->GetName().c_str());
    }
  }

  return PostProcSubgraph(graph);
}

///
/// @ingroup ge
/// @brief Assign parent index for branches.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) {
  auto func_desc = case_node_->GetOpDesc();
  domi::ParseSubgraphFuncV2 parse_func_v2 = nullptr;
  auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType());
  if (post_func == nullptr) {
    GELOGW("The subgraph post func for node %s type %s is null.", case_node_->GetName().c_str(),
           case_node_->GetType().c_str());
    if (domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType(), parse_func_v2) != SUCCESS ||
        parse_func_v2 == nullptr) {
      GELOGW("The subgraph new post func v2 for node %s type %s is null", case_node_->GetName().c_str(),
             case_node_->GetType().c_str());
      return FAILED;
    }
  }

  for (const auto &name : func_desc->GetSubgraphInstanceNames()) {
    const auto &subgraph = graph->GetSubgraph(name);
    if (subgraph == nullptr) {
      GELOGE(FAILED, "Subgraph not found, name: %s", name.c_str());
      return FAILED;
    }

    std::string subgraph_name;
    GE_CHK_STATUS_RET(func_desc->GetSubgraphNameByInstanceName(subgraph->GetName(), subgraph_name),
                      "Subgraph: %s get subgraph name failed.", subgraph->GetName().c_str());

    auto graph = GraphUtils::CreateGraphFromComputeGraph(subgraph);
    Status ret = FAILED;
    if (post_func != nullptr) {
      ret = post_func(subgraph_name, graph);
    } else if (parse_func_v2 != nullptr) {
      ret = parse_func_v2(subgraph_name.c_str(), graph);
    }
    if (ret != SUCCESS) {
      GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", graph.GetName().c_str(),
             case_node_->GetName().c_str(), case_node_->GetType().c_str());
      return FAILED;
    }
  }

  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Remove unused (suspended) output anchors from the Case node and each branch NetOutput.
/// @param [in] ComputeGraphPtr &graph: Parent compute graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::PruneDirectOutput(const ComputeGraphPtr &graph) {
  const auto &func_desc = case_node_->GetOpDesc();
  uint32_t unused_num = 0;
  uint32_t output_num = func_desc->GetOutputsSize();
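  // An output index is unused only if no branch NetOutput has a peer connected
  // at that index; used outputs are shifted left past the unused ones.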
  for (size_t i = 0; i < output_num; ++i) {
    bool is_unused_tensor = true;
    for (const auto &item : all_branch_output_) {
      const auto &netoutput = item.second;
      GE_CHECK_NOTNULL(netoutput);
      const auto in_anchor = netoutput->GetInDataAnchor(i);
      if (in_anchor->GetPeerOutAnchor() != nullptr) {
        is_unused_tensor = false;
        break;
      }
    }

    if (is_unused_tensor) {
      unused_num++;
      continue;
    }

    GE_CHK_STATUS_RET(UpdateOutputTensor(i, unused_num), "Graph:%s Update output failed", graph->GetName().c_str());
  }

  if (unused_num == 0) {
    return SUCCESS;
  }

  GE_CHK_STATUS_RET(NodeUtils::RemoveOutputAnchor(case_node_, output_num - unused_num), "Remove output failed");
  for (const auto &item : all_branch_output_) {
    GE_CHK_STATUS_RET(NodeUtils::RemoveInputAnchor(item.second, output_num - unused_num), "Remove input failed");
  }
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Update a used output tensor, shifting it left past the unused outputs before it.
/// @param [in] parent_index: parent index for check.
/// @param [in] unused_num: number of unused outputs before this index.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::UpdateOutputTensor(uint32_t parent_index, uint32_t unused_num) {
  if (unused_num == 0) {
    return SUCCESS;
  }
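
  // Re-wire the edges at parent_index to the compacted index on every branch
  // NetOutput and on the Case node's output anchor.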
  uint32_t update_index = parent_index - unused_num;
  for (const auto &item : all_branch_output_) {
    const auto &node = item.second;
    const auto &new_anchor = node->GetInDataAnchor(update_index);
    const auto &old_anchor = node->GetInDataAnchor(parent_index);
    const auto &out_anchor = old_anchor->GetPeerOutAnchor();
    const auto &out_node = out_anchor->GetOwnerNode();

    const auto &op_desc = node->GetOpDesc();
    (void)op_desc->UpdateInputDesc(update_index, op_desc->GetInputDesc(parent_index));

    GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_anchor, new_anchor), "Add edge failed");
    GELOGI("Add edge success, func node: %s, node: %s, parent index: %u, update index: %u",
           case_node_->GetName().c_str(), out_node->GetName().c_str(), parent_index, update_index);
    GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, old_anchor), "Remove edge failed");
    GELOGI("Remove edge success, func node: %s, node: %s", case_node_->GetName().c_str(), out_node->GetName().c_str());
  }

  const auto &new_anchor = case_node_->GetOutDataAnchor(update_index);
  const auto &old_anchor = case_node_->GetOutDataAnchor(parent_index);
  for (const auto in_anchor : old_anchor->GetPeerInDataAnchors()) {
    const auto &in_node = in_anchor->GetOwnerNode();
    GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(old_anchor, in_anchor), "Remove edge failed");
    GELOGI("Remove edge success, func node: %s, node: %s", case_node_->GetName().c_str(), in_node->GetName().c_str());
    GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(new_anchor, in_anchor), "Add edge failed");
    GELOGI("Add edge success, func node: %s, node: %s, parent index: %u, update index: %u",
           case_node_->GetName().c_str(), in_node->GetName().c_str(), parent_index, update_index);
  }

  return SUCCESS;
}
}  // namespace ge

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the frontend module (ME) and the underlying hardware and acts as the bridge between the two. GE takes the graph issued by ME as input, applies a series of deep graph-optimization operations, and outputs a graph that runs efficiently on the underlying hardware. GE performs optimizations tailored to the hardware architecture of the Ascend AI processor so as to make full use of its computing power. During model training and inference, GE is invoked automatically and is transparent to the user. GE consists mainly of two parts: GE API and GE Core.
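
The file above also illustrates the shape every GE graph optimization takes: a pass class whose Run method receives a ComputeGraphPtr and returns a Status. Below is a minimal sketch of that pattern, modeled on MultiBatchClonePass::Run; the pass name and header path are hypothetical, and only calls that appear in the file above are used.

// Illustrative sketch only; MyLoggingPass and its header are hypothetical.
#include "graph/passes/my_logging_pass.h"  // hypothetical header

namespace ge {
Status MyLoggingPass::Run(ComputeGraphPtr graph) {
  // Like MultiBatchClonePass, only act on the root graph.
  if (graph->GetParentGraph() != nullptr) {
    GELOGD("Subgraph %s skip the MyLoggingPass", graph->GetName().c_str());
    return SUCCESS;
  }

  // Walk the direct nodes of the graph and log each node's type.
  for (const auto &node : graph->GetDirectNode()) {
    GELOGD("Node %s has type %s", node->GetName().c_str(), node->GetType().c_str());
  }
  return SUCCESS;
}
}  // namespace ge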