You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

multi_batch_clone_pass.cc 26 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/multi_batch_clone_pass.h"
  17. #include "common/formats/utils/formats_trans_utils.h"
  18. #include "common/ge/ge_util.h"
  19. #include "graph/preprocess/multi_batch_options.h"
  20. #include "graph/utils/node_utils.h"
  21. #include "graph/utils/op_desc_utils.h"
  22. #include "register/op_registry.h"
  23. namespace ge {
namespace {
// Anchor index of the (single) input tensor on a Data node.
constexpr uint8_t kDataInIndex = 0;
// Anchor index of the (single) output tensor on a Data/Const node.
constexpr uint8_t kDataOutIndex = 0;
// Input 0 of the Case node is the branch selector; real args start at 1.
constexpr uint8_t kCaseArgIndex = 1;

// Fixed names for the helper nodes this pass inserts into the root graph.
const std::string kMultiBatchCaseNode = "ascend_mbatch_shape_case";
const std::string kMultiBatchDataNode = "ascend_mbatch_shape_data";
const std::string kMultiBatchConstNode = "ascend_mbatch_shape_const";
const std::string kMultiBatchMapIndexNode = "ascend_mbatch_shape_mapindex";
}  // namespace
  33. Status MultiBatchClonePass::Run(ComputeGraphPtr graph) {
  34. if (graph->GetParentGraph() != nullptr) {
  35. GELOGD("Subgraph %s skip the MultiBatchClonePass", graph->GetName().c_str());
  36. return SUCCESS;
  37. }
  38. if (!multibatch::InitDynamicParams(batch_shapes_)) {
  39. GELOGD("There is no multi-batch options, no need clone multi-batch graph");
  40. return SUCCESS;
  41. }
  42. GELOGD("Begin to run Multi-batch clone on graph: %s", graph->GetName().c_str());
  43. GE_CHK_STATUS_RET(multibatch::CheckDynamicParams(batch_shapes_), "Invalid multi-batch param");
  44. if (CollectIoNodes(graph) != SUCCESS) {
  45. GELOGE(INTERNAL_ERROR, "Collect input output nodes failed");
  46. return INTERNAL_ERROR;
  47. }
  48. (void)AttrUtils::GetStr(graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_);
  49. ComputeGraphPtr branch = MakeShared<ComputeGraph>(graph->GetName());
  50. if (branch == nullptr) {
  51. GELOGE(OUT_OF_MEMORY, "Create multi-batch graph failed");
  52. return OUT_OF_MEMORY;
  53. }
  54. (void)AttrUtils::SetStr(branch, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_);
  55. graph->InValid(); // Will modify, need topological again.
  56. graph->Swap(*branch);
  57. if (CreateRootGraph(graph) != SUCCESS) {
  58. return FAILED;
  59. }
  60. if (CreateSubgraphs(graph, branch) != SUCCESS) {
  61. return FAILED;
  62. }
  63. GE_CHK_STATUS_RET(PruneDirectOutput(graph), "Prune direct output failed");
  64. GELOGD("MultiBatchClonePass Leave");
  65. return SUCCESS;
  66. }
  67. ///
  68. /// @ingroup ge
  69. /// @brief Collect input output node from original graph.
  70. /// @param [in] const ComputeGraphPtr &graph: original graph.
  71. /// @return 0: SUCCESS / others: FAILED
  72. ///
  73. Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) {
  74. for (const auto &node : graph->GetDirectNode()) {
  75. if (node->GetType() == DATA) {
  76. all_data_nodes_.emplace_back(node);
  77. } else if (node->GetType() == CONSTANT) {
  78. all_const_nodes_.emplace_back(node);
  79. } else if (node->GetType() == NETOUTPUT) {
  80. all_output_nodes_.emplace_back(node);
  81. }
  82. // If the node save as input/output node, delete record.
  83. (void)graph->RemoveInputNode(node);
  84. (void)graph->RemoveOutputNode(node);
  85. }
  86. if (all_data_nodes_.empty() || all_output_nodes_.size() != 1) {
  87. GELOGE(FAILED, "data nodes: %zu, output nodes: %zu", all_data_nodes_.size(), all_output_nodes_.size());
  88. return FAILED;
  89. }
  90. int64_t data_index = 0;
  91. for (size_t i = 0; i < all_data_nodes_.size(); ++i) {
  92. const auto &op_desc = all_data_nodes_[i]->GetOpDesc();
  93. if (!AttrUtils::GetInt(op_desc, ATTR_NAME_INDEX, data_index)) {
  94. (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, i);
  95. }
  96. }
  97. const auto &output = all_output_nodes_[0];
  98. for (size_t i = 0; i < output->GetAllInDataAnchorsSize(); ++i) {
  99. const auto in_anchor = output->GetInDataAnchor(i);
  100. const auto out_anchor = in_anchor->GetPeerOutAnchor();
  101. const auto data_node = out_anchor->GetOwnerNode();
  102. if (data_node->GetType() == DATA) {
  103. direct_output_[i] = data_node->GetName();
  104. GE_CHK_GRAPH_STATUS_RET(
  105. GraphUtils::RemoveEdge(data_node->GetOutDataAnchor(kDataOutIndex), output->GetInDataAnchor(i)),
  106. "Remove edge failed");
  107. }
  108. }
  109. return SUCCESS;
  110. }
  111. ///
  112. /// @ingroup ge
  113. /// @brief Create nodes for root graph.
  114. /// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
  115. /// @return 0: SUCCESS / others: FAILED
  116. ///
  117. Status MultiBatchClonePass::CreateRootGraph(const ComputeGraphPtr &graph) {
  118. uint32_t input_num = all_data_nodes_.size() + all_const_nodes_.size();
  119. uint32_t output_num = all_output_nodes_[0]->GetAllInDataAnchorsSize();
  120. OpDescBuilder op_builder(kMultiBatchCaseNode, CASE);
  121. op_builder.AddInput("branch_index").AddDynamicInput("input", input_num).AddDynamicOutput("output", output_num);
  122. const OpDescPtr op_desc = op_builder.Build();
  123. if (op_desc == nullptr) {
  124. GELOGE(OUT_OF_MEMORY, "Create multi-batch case desc failed");
  125. return OUT_OF_MEMORY;
  126. }
  127. op_desc->RegisterSubgraphIrName("branches", kDynamic);
  128. case_node_ = graph->AddNode(op_desc);
  129. if (case_node_ == nullptr) {
  130. GELOGE(OUT_OF_MEMORY, "Create multi-batch case node failed");
  131. return OUT_OF_MEMORY;
  132. }
  133. uint32_t batch_num = static_cast<uint32_t>(batch_shapes_.size());
  134. if (!AttrUtils::SetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) {
  135. GELOGE(FAILED, "Set attr ATTR_NAME_BATCH_NUM failed, Case: %s.", op_desc->GetName().c_str());
  136. return FAILED;
  137. }
  138. for (uint32_t i = 0; i < batch_num; i++) {
  139. const std::string &attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i);
  140. if (!AttrUtils::SetListInt(op_desc, attr_name, batch_shapes_[i])) {
  141. GELOGE(FAILED, "Set attr ATTR_NAME_PRED_VALUE failed, Case: %s.", op_desc->GetName().c_str());
  142. return FAILED;
  143. }
  144. }
  145. GE_CHK_STATUS_RET(multibatch::StampDynamicType(op_desc), "Set dynamic type failed");
  146. GE_CHK_STATUS_RET(CreateIndexNode(graph), "Create index node failed");
  147. GE_CHK_STATUS_RET(CreateInputNode(graph), "Create input node failed");
  148. GE_CHK_STATUS_RET(CreateConstNode(graph), "Create const node failed");
  149. GE_CHK_STATUS_RET(CreateOutputNode(graph), "Create output node failed");
  150. return SUCCESS;
  151. }
  152. ///
  153. /// @ingroup ge
  154. /// @brief Create index data node for root graph.
  155. /// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
  156. /// @param [in] NodePtr node: index data node.
  157. /// @return 0: SUCCESS / others: FAILED
  158. ///
  159. Status MultiBatchClonePass::CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &node) {
  160. const OpDescPtr data_desc = MakeShared<OpDesc>(kMultiBatchDataNode, DATA);
  161. if (data_desc == nullptr) {
  162. GELOGE(OUT_OF_MEMORY, "Create multi-batch data node failed");
  163. return FAILED;
  164. }
  165. GeTensorDesc data_tensor(GeShape({static_cast<int64_t>(batch_shapes_[0].size())}), FORMAT_ND, DT_INT32);
  166. if (data_desc->AddInputDesc(data_tensor) != GRAPH_SUCCESS) {
  167. GELOGE(FAILED, "Add input desc failed");
  168. return FAILED;
  169. }
  170. if (data_desc->AddOutputDesc(data_tensor) != GRAPH_SUCCESS) {
  171. GELOGE(FAILED, "Add output desc failed");
  172. return FAILED;
  173. }
  174. size_t data_index = all_data_nodes_.size();
  175. (void)AttrUtils::SetInt(data_desc, ATTR_NAME_INDEX, data_index);
  176. (void)AttrUtils::SetBool(data_desc, ATTR_INSERT_BY_MBATCH, true);
  177. node = graph->AddNode(data_desc);
  178. if (node == nullptr) {
  179. GELOGE(OUT_OF_MEMORY, "Create multi-batch data node failed");
  180. return OUT_OF_MEMORY;
  181. }
  182. return SUCCESS;
  183. }
  184. ///
  185. /// @ingroup ge
  186. /// @brief Create index const node for root graph.
  187. /// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
  188. /// @param [in] NodePtr node: index const node.
  189. /// @return 0: SUCCESS / others: FAILED
  190. ///
  191. Status MultiBatchClonePass::CreateIndexConstNode(const ComputeGraphPtr &graph, NodePtr &node) {
  192. const OpDescPtr const_desc = MakeShared<OpDesc>(kMultiBatchConstNode, CONSTANT);
  193. if (const_desc == nullptr) {
  194. GELOGE(OUT_OF_MEMORY, "Create multi-batch const node failed");
  195. return FAILED;
  196. }
  197. int64_t count = batch_shapes_.size() * batch_shapes_[0].size();
  198. std::unique_ptr<int32_t[]> addr(new (std::nothrow) int32_t[count]);
  199. GE_CHECK_NOTNULL(addr);
  200. size_t i = 0;
  201. for (auto &batch_shape : batch_shapes_) {
  202. for (int64_t dim : batch_shape) {
  203. addr[i++] = static_cast<int32_t>(dim);
  204. }
  205. }
  206. GeTensorDesc const_tensor(GeShape({count}), FORMAT_ND, DT_INT32);
  207. GeTensor tensor(const_tensor);
  208. (void)tensor.SetData(reinterpret_cast<uint8_t *>(addr.get()), count * sizeof(int32_t));
  209. if (!AttrUtils::SetTensor(const_desc, ATTR_NAME_WEIGHTS, tensor)) {
  210. GELOGE(OUT_OF_MEMORY, "Failed to init tensor value for const %s", const_desc->GetName().c_str());
  211. return FAILED;
  212. }
  213. if (const_desc->AddOutputDesc(const_tensor) != GRAPH_SUCCESS) {
  214. GELOGE(OUT_OF_MEMORY, "Failed to add output desc for const node %s", const_desc->GetName().c_str());
  215. return FAILED;
  216. }
  217. node = graph->AddNode(const_desc);
  218. if (node == nullptr) {
  219. GELOGE(OUT_OF_MEMORY, "Create multi-batch const node failed");
  220. return OUT_OF_MEMORY;
  221. }
  222. return SUCCESS;
  223. }
  224. ///
  225. /// @ingroup ge
  226. /// @brief Create index node for root graph.
  227. /// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
  228. /// @return 0: SUCCESS / others: FAILED
  229. ///
  230. Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) {
  231. // Data --> MapIndex --> Case
  232. NodePtr data_node;
  233. GE_CHK_STATUS_RET(CreateIndexDataNode(graph, data_node), "Create data node failed");
  234. NodePtr const_node;
  235. GE_CHK_STATUS_RET(CreateIndexConstNode(graph, const_node), "Create const node failed");
  236. OpDescBuilder op_builder(kMultiBatchMapIndexNode, "MapIndex");
  237. op_builder.AddInput("x", data_node->GetOpDesc()->GetOutputDesc(0))
  238. .AddInput("data_seq", const_node->GetOpDesc()->GetOutputDesc(0))
  239. .AddOutput("y", GeTensorDesc(GeShape(), FORMAT_ND, DT_INT32));
  240. const OpDescPtr op_desc = op_builder.Build();
  241. if (op_desc == nullptr) {
  242. GELOGE(OUT_OF_MEMORY, "Create multi-batch index desc failed");
  243. return FAILED;
  244. }
  245. NodePtr index_node = graph->AddNode(op_desc);
  246. if (index_node == nullptr) {
  247. GELOGE(OUT_OF_MEMORY, "Create multi-batch index node failed");
  248. return OUT_OF_MEMORY;
  249. }
  250. if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), index_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
  251. GELOGE(FAILED, "Failed to add edge between node:%s to MapIndex:%s", data_node->GetName().c_str(),
  252. index_node->GetName().c_str());
  253. return FAILED;
  254. }
  255. if (GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), index_node->GetInDataAnchor(1)) != GRAPH_SUCCESS) {
  256. GELOGE(FAILED, "Failed to add edge between node:%s to MapIndex:%s", const_node->GetName().c_str(),
  257. index_node->GetName().c_str());
  258. return FAILED;
  259. }
  260. if (GraphUtils::AddEdge(index_node->GetOutDataAnchor(0), case_node_->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
  261. GELOGE(FAILED, "Failed to add edge between MapIndex:%s to Case:%s", index_node->GetName().c_str(),
  262. case_node_->GetName().c_str());
  263. return FAILED;
  264. }
  265. return SUCCESS;
  266. }
  267. ///
  268. /// @ingroup ge
  269. /// @brief Create input node for root graph.
  270. /// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
  271. /// @return 0: SUCCESS / others: FAILED
  272. ///
  273. Status MultiBatchClonePass::CreateInputNode(const ComputeGraphPtr &graph) {
  274. // Data --> Case
  275. std::vector<NodePtr> all_data_nodes;
  276. const size_t arg_index = kCaseArgIndex;
  277. for (size_t i = 0; i < all_data_nodes_.size(); ++i) {
  278. const auto &node = all_data_nodes_[i];
  279. const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc());
  280. if (op_desc == nullptr) {
  281. GELOGE(OUT_OF_MEMORY, "Create multi-batch Data node failed, name: %s", node->GetName().c_str());
  282. return FAILED;
  283. }
  284. if (GraphUtils::CopyTensorAttrs(op_desc, node) != GRAPH_SUCCESS) {
  285. return FAILED;
  286. }
  287. op_desc->SetName(node->GetName());
  288. const NodePtr &data = graph->AddNode(op_desc);
  289. GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str());
  290. if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(arg_index + i)) != GRAPH_SUCCESS) {
  291. GELOGE(FAILED, "Failed to add edge between Data:%s to Case:%s", data->GetName().c_str(),
  292. case_node_->GetName().c_str());
  293. return FAILED;
  294. }
  295. if (SetMaxShapeToData(data) != SUCCESS) {
  296. return FAILED;
  297. }
  298. all_data_nodes.emplace_back(data);
  299. }
  300. all_data_nodes_.swap(all_data_nodes);
  301. return SUCCESS;
  302. }
  303. ///
  304. /// @ingroup ge
  305. /// @brief Create Const node for root graph.
  306. /// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
  307. /// @return 0: SUCCESS / others: FAILED
  308. ///
  309. Status MultiBatchClonePass::CreateConstNode(const ComputeGraphPtr &graph) {
  310. // Const --> Case
  311. std::vector<NodePtr> all_const_nodes;
  312. const size_t arg_index = kCaseArgIndex + all_data_nodes_.size();
  313. for (size_t i = 0; i < all_const_nodes_.size(); ++i) {
  314. const auto &node = all_const_nodes_[i];
  315. const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc());
  316. if (op_desc == nullptr) {
  317. GELOGE(OUT_OF_MEMORY, "Create multi-batch Const node failed, name: %s", node->GetName().c_str());
  318. return FAILED;
  319. }
  320. op_desc->SetName(node->GetName());
  321. if (GraphUtils::CopyTensorAttrs(op_desc, node) != GRAPH_SUCCESS) {
  322. return FAILED;
  323. }
  324. const NodePtr &data = graph->AddNode(op_desc);
  325. GE_CHK_BOOL_EXEC(data != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str());
  326. if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(arg_index + i)) != GRAPH_SUCCESS) {
  327. GELOGE(FAILED, "Failed to add edge between Const:%s to Case:%s", data->GetName().c_str(),
  328. case_node_->GetName().c_str());
  329. return FAILED;
  330. }
  331. all_const_nodes.emplace_back(data);
  332. }
  333. size_t data_index = all_data_nodes_.size();
  334. for (size_t i = 0; i < all_const_nodes_.size(); ++i, ++data_index) { // Trans subgraph Const to Data.
  335. const OpDescPtr &op_desc = all_const_nodes_[i]->GetOpDesc();
  336. op_desc->SetType(DATA);
  337. (void)op_desc->DelAttr(ATTR_NAME_WEIGHTS); // Delete weight.
  338. // Const no InputDesc, Data need InputDesc.
  339. (void)op_desc->AddInputDesc(op_desc->GetOutputDesc(kDataOutIndex));
  340. (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index);
  341. }
  342. all_const_nodes_.swap(all_const_nodes);
  343. return SUCCESS;
  344. }
  345. ///
  346. /// @ingroup ge
  347. /// @brief Create output node for root graph.
  348. /// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
  349. /// @return 0: SUCCESS / others: FAILED
  350. ///
  351. Status MultiBatchClonePass::CreateOutputNode(const ComputeGraphPtr &graph) {
  352. const auto &output = all_output_nodes_[0];
  353. const OpDescPtr op_desc = AttrUtils::CopyOpDesc(output->GetOpDesc());
  354. if (op_desc == nullptr) {
  355. GELOGE(OUT_OF_MEMORY, "Create multi-batch output node failed");
  356. return FAILED;
  357. }
  358. if (GraphUtils::CopyTensorAttrs(op_desc, output) != GRAPH_SUCCESS) {
  359. return FAILED;
  360. }
  361. op_desc->SetName(output->GetName());
  362. const NodePtr &node = graph->AddNode(op_desc);
  363. GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Add node[%s] to graph failed", op_desc->GetName().c_str());
  364. for (size_t i = 0; i < case_node_->GetAllOutDataAnchorsSize(); ++i) {
  365. const auto it = direct_output_.find(i);
  366. if (it == direct_output_.end()) {
  367. if (GraphUtils::AddEdge(case_node_->GetOutDataAnchor(i), node->GetInDataAnchor(i)) != GRAPH_SUCCESS) {
  368. GELOGE(FAILED, "Failed to add edge between Case:%s to NetOutput:%s", case_node_->GetName().c_str(),
  369. node->GetName().c_str());
  370. return FAILED;
  371. }
  372. } else {
  373. const auto data_node = graph->FindNode(it->second);
  374. if (data_node == nullptr) {
  375. GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "Data node:%s not found", it->second.c_str());
  376. return GE_GRAPH_GRAPH_NODE_NULL;
  377. }
  378. if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(kDataOutIndex), node->GetInDataAnchor(i)) != GRAPH_SUCCESS) {
  379. GELOGE(FAILED, "Failed to add edge between Data:%s to NetOutput:%s", data_node->GetName().c_str(),
  380. node->GetName().c_str());
  381. return FAILED;
  382. }
  383. }
  384. }
  385. all_output_nodes_.clear();
  386. all_output_nodes_.emplace_back(node);
  387. return SUCCESS;
  388. }
  389. ///
  390. /// @ingroup ge
  391. /// @brief Set max shape to Data node in root graph.
  392. /// @param [in] const NodePtr &data: data in Root/Case graph.
  393. /// @return 0: SUCCESS / others: FAILED
  394. ///
  395. Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &data) {
  396. auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
  397. const auto &dims = data_shape.GetDims();
  398. if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) {
  399. return SUCCESS;
  400. }
  401. (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims());
  402. size_t max_shape_index = 0;
  403. int64_t max_size = 0;
  404. for (size_t i = 0; i < batch_shapes_.size(); ++i) {
  405. int64_t size = 1;
  406. for (auto dim : batch_shapes_[i]) {
  407. if (INT64_MAX / dim < size) {
  408. GELOGE(PARAM_INVALID, "The shape %s size overflow", formats::ShapeToString(batch_shapes_[i]).c_str());
  409. return PARAM_INVALID;
  410. }
  411. size *= dim;
  412. }
  413. if (size > max_size) {
  414. max_size = size;
  415. max_shape_index = i;
  416. }
  417. }
  418. return SetShapeToData(batch_shapes_[max_shape_index], data, data_shape);
  419. }
  420. ///
  421. /// @ingroup ge
  422. /// @brief Set shape to Data node in branch.
  423. /// @param [in] const NodePtr &data: data in branch.
  424. /// @param [in] const std::vector<int64_t> &shapes: dims of shape.
  425. /// @return 0: SUCCESS / others: FAILED
  426. ///
  427. Status MultiBatchClonePass::UpdataShapeToData(const NodePtr &data, const vector<int64_t> &shapes) {
  428. auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
  429. const auto &dims = data_shape.GetDims();
  430. if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) {
  431. return SUCCESS;
  432. }
  433. (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims());
  434. return SetShapeToData(shapes, data, data_shape);
  435. }
  436. ///
  437. /// @ingroup ge
  438. /// @brief Set max shape to Data node in root graph.
  439. /// @param [in] const std::vector<int64_t> &shapes: dims of shape.
  440. /// @param [in] const NodePtr &data: data in Root/Case graph.
  441. /// @param [in] GeShape &data_shape: dims of data node.
  442. /// @return 0: SUCCESS / others: FAILED
  443. ///
  444. Status MultiBatchClonePass::SetShapeToData(const vector<int64_t> &shapes, const NodePtr &data, GeShape &data_shape) {
  445. // must not be error, the calc result has been checked in function InsertSwitchNForData
  446. if (multibatch::CalcShape(shapes, data_shape) != SUCCESS) {
  447. return INTERNAL_ERROR;
  448. }
  449. if (NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape) != GRAPH_SUCCESS) {
  450. GELOGE(INTERNAL_ERROR, "Failed to update input shape for data %s", data->GetName().c_str());
  451. return INTERNAL_ERROR;
  452. }
  453. if (NodeUtils::UpdateOutputShape(*data, kDataOutIndex, data_shape) != GRAPH_SUCCESS) {
  454. GELOGE(INTERNAL_ERROR, "Failed to update output shape for data %s", data->GetName().c_str());
  455. return INTERNAL_ERROR;
  456. }
  457. GELOGI("Update %s input/output shape to %s", data->GetName().c_str(), formats::ShapeToString(data_shape).c_str());
  458. return SUCCESS;
  459. }
///
/// @ingroup ge
/// @brief Create nodes for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @param [in] const ComputeGraphPtr &branch: original graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const ComputeGraphPtr &branch) {
  const std::string name = graph->GetName() + "_branche_";
  const auto &op_desc = case_node_->GetOpDesc();
  for (size_t i = 0; i < batch_shapes_.size(); ++i) {
    std::vector<NodePtr> input_nodes;
    std::vector<NodePtr> output_nodes;
    const std::string prefix = "branche_" + std::to_string(i) + "_";
    // Branch 0 reuses the original graph directly; every other branch is a
    // clone of it. For i == 0 CloneGraph is not called, so input_nodes stays
    // empty and the shape-update loop below is a no-op — branch 0's Data
    // nodes are fixed up in the second loop after this one.
    ComputeGraphPtr subgraph = (i == 0) ? branch : GraphUtils::CloneGraph(branch, prefix, input_nodes, output_nodes);
    if (subgraph == nullptr) {
      GELOGE(FAILED, "Create multi-batch case node failed");
      return FAILED;
    }
    subgraph->SetName(name + std::to_string(i));
    subgraph->SetParentNode(case_node_);
    subgraph->SetParentGraph(graph);
    (void)AttrUtils::SetStr(subgraph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_);
    // Remember each branch's NetOutput for later output pruning/rewiring.
    all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT);
    graph->AddSubgraph(subgraph->GetName(), subgraph);
    // Register the branch on the Case op: subgraph name key "branches<i>"
    // mapped to this subgraph instance.
    const std::string key_name = "branches" + std::to_string(i);
    op_desc->AddSubgraphName(key_name);
    op_desc->SetSubgraphInstanceName(i, subgraph->GetName());
    // Pin each cloned Data node to the concrete shape of batch profile i.
    for (const auto &data : input_nodes) {
      GE_CHK_STATUS_RET(UpdataShapeToData(data, batch_shapes_[i]), "Update %s failed", subgraph->GetName().c_str());
    }
  }
  // Original graph is taken as the first subgraph: rename its nodes with the
  // branch-0 prefix and pin its Data nodes to batch profile 0.
  for (const auto &n : branch->GetDirectNode()) {
    const auto &op_desc = n->GetOpDesc();
    op_desc->SetName("branche_0_" + n->GetName());
    if (n->GetType() == DATA) {
      GE_CHK_STATUS_RET(UpdataShapeToData(n, batch_shapes_[0]), "Update %s failed", branch->GetName().c_str());
    }
  }
  return PostProcSubgraph(graph);
}
  502. ///
  503. /// @ingroup ge
  504. /// @brief Assign parent index for branches.
  505. /// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
  506. /// @return 0: SUCCESS / others: FAILED
  507. ///
  508. Status MultiBatchClonePass::PostProcSubgraph(const ComputeGraphPtr &graph) {
  509. auto func_desc = case_node_->GetOpDesc();
  510. auto post_func = domi::OpRegistry::Instance()->GetParseSubgraphPostFunc(func_desc->GetType());
  511. if (post_func == nullptr) {
  512. GELOGW("The subgraph post func for node %s type %s is null.", case_node_->GetName().c_str(),
  513. case_node_->GetType().c_str());
  514. return FAILED;
  515. }
  516. for (const auto &name : func_desc->GetSubgraphInstanceNames()) {
  517. const auto &subgraph = graph->GetSubgraph(name);
  518. if (subgraph == nullptr) {
  519. GELOGE(FAILED, "Subgraph not found, name: %s", name.c_str());
  520. return FAILED;
  521. }
  522. std::string subgraph_name;
  523. GE_CHK_STATUS_RET(func_desc->GetSubgraphNameByInstanceName(subgraph->GetName(), subgraph_name),
  524. "Subgraph: %s get subgraph name failed.", subgraph->GetName().c_str());
  525. auto graph = GraphUtils::CreateGraphFromComputeGraph(subgraph);
  526. auto ret = post_func(subgraph_name, graph);
  527. if (ret != SUCCESS) {
  528. GELOGE(FAILED, "Failed to post-process subgraph %s on node %s type %s", graph.GetName().c_str(),
  529. case_node_->GetName().c_str(), case_node_->GetType().c_str());
  530. return FAILED;
  531. }
  532. }
  533. return SUCCESS;
  534. }
  535. ///
  536. /// @ingroup ge
  537. /// @brief Remove subgraph suspend output anchor.
  538. /// @param [in] ComputeGraphPtr &graph: Parent compute graph.
  539. /// @return 0: SUCCESS / others: FAILED
  540. ///
  541. Status MultiBatchClonePass::PruneDirectOutput(const ComputeGraphPtr &graph) {
  542. const auto &func_desc = case_node_->GetOpDesc();
  543. uint32_t unused_num = 0;
  544. uint32_t output_num = func_desc->GetOutputsSize();
  545. for (size_t i = 0; i < output_num; ++i) {
  546. bool is_unused_tensor = true;
  547. for (const auto &item : all_branch_output_) {
  548. const auto &netoutput = item.second;
  549. GE_CHECK_NOTNULL(netoutput);
  550. const auto in_anchor = netoutput->GetInDataAnchor(i);
  551. if (in_anchor->GetPeerOutAnchor() != nullptr) {
  552. is_unused_tensor = false;
  553. break;
  554. }
  555. }
  556. if (is_unused_tensor) {
  557. unused_num++;
  558. continue;
  559. }
  560. GE_CHK_STATUS_RET(UpdateOutputTensor(i, unused_num), "Graph:%s Update output failed", graph->GetName().c_str());
  561. }
  562. if (unused_num == 0) {
  563. return SUCCESS;
  564. }
  565. GE_CHK_STATUS_RET(NodeUtils::RemoveOutputAnchor(case_node_, output_num - unused_num), "Remove output failed");
  566. for (const auto &item : all_branch_output_) {
  567. GE_CHK_STATUS_RET(NodeUtils::RemoveInputAnchor(item.second, output_num - unused_num), "Remove input failed");
  568. }
  569. return SUCCESS;
  570. }
  571. ///
  572. /// @ingroup ge
  573. /// @brief Update subgraph suspend output tensor.
  574. /// @param [in] parent_index: parent index for check.
  575. /// @param [in] unused_num: total unused tensor.
  576. /// @return 0: SUCCESS / others: FAILED
  577. ///
  578. Status MultiBatchClonePass::UpdateOutputTensor(uint32_t parent_index, uint32_t unused_num) {
  579. if (unused_num == 0) {
  580. return SUCCESS;
  581. }
  582. uint32_t update_index = parent_index - unused_num;
  583. for (const auto &item : all_branch_output_) {
  584. const auto &node = item.second;
  585. const auto &new_anchor = node->GetInDataAnchor(update_index);
  586. const auto &old_anchor = node->GetInDataAnchor(parent_index);
  587. const auto &out_anchor = old_anchor->GetPeerOutAnchor();
  588. const auto &out_node = out_anchor->GetOwnerNode();
  589. const auto &op_desc = node->GetOpDesc();
  590. (void)op_desc->UpdateInputDesc(update_index, op_desc->GetInputDesc(parent_index));
  591. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_anchor, new_anchor), "Add edge failed");
  592. GELOGI("Add edge success, func node: %s, node: %s, parent index: %u, update index: %u",
  593. case_node_->GetName().c_str(), out_node->GetName().c_str(), parent_index, update_index);
  594. GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, old_anchor), "Remove edge failed");
  595. GELOGI("Remove edge success, func node: %s, node: %s", case_node_->GetName().c_str(), out_node->GetName().c_str());
  596. }
  597. const auto &new_anchor = case_node_->GetOutDataAnchor(update_index);
  598. const auto &old_anchor = case_node_->GetOutDataAnchor(parent_index);
  599. for (const auto in_anchor : old_anchor->GetPeerInDataAnchors()) {
  600. const auto &in_node = in_anchor->GetOwnerNode();
  601. GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(old_anchor, in_anchor), "Remove edge failed");
  602. GELOGI("Remove edge success, func node: %s, node: %s", case_node_->GetName().c_str(), in_node->GetName().c_str());
  603. GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(new_anchor, in_anchor), "Add edge failed");
  604. GELOGI("Add edge success, func node: %s, node: %s, parent index: %u, update index: %u",
  605. case_node_->GetName().c_str(), in_node->GetName().c_str(), parent_index, update_index);
  606. }
  607. return SUCCESS;
  608. }
  609. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示