
multi_batch_clone_pass.cc 72 kB

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/passes/multi_batch_clone_pass.h"

#include "common/formats/utils/formats_trans_utils.h"
#include "common/ge/ge_util.h"
#include "graph/common/local_context.h"
#include "graph/preprocess/multi_batch_options.h"
#include "graph/utils/node_utils.h"
#include "graph/utils/op_desc_utils.h"
#include "graph/utils/tensor_utils.h"
#include "graph/utils/type_utils.h"
#include "register/op_registry.h"
#include "graph/common/omg_util.h"

namespace ge {
namespace {
constexpr uint8_t kDataInIndex = 0;
constexpr uint8_t kDataOutIndex = 0;
constexpr uint8_t kCaseArgIndex = 1;
const int kDivisionConst = 2;
const size_t kNumOfGetnextNode = 1;

const std::string kMultiBatchCaseNode = "ascend_mbatch_shape_case";
const std::string kMultiBatchDataNode = "ascend_mbatch_shape_data";
const std::string kMultiBatchGetDynamicDimsNode = "ascend_mbatch_get_dynamic_dims_node";
const std::string kMultiBatchConstNode = "ascend_mbatch_shape_const";
const std::string kMultiBatchMapIndexNode = "ascend_mbatch_shape_mapindex";
const std::string kMultiBatchNodePostfix = "_ascend_mbatch_batch_";
const char *const kGetNextName = "IteratorV2";
const char *const kMbatchCaseName = "mbatch-switch-name";
}  // namespace

inline bool IsGetNextType(const NodePtr &node) {
  std::string original_type;
  GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS,
                  GELOGW("Get original type failed."); return false);
  return (original_type == kGetNextName);
}

Status MultiBatchClonePass::Run(ComputeGraphPtr graph) {
  GE_IF_BOOL_EXEC(graph == nullptr,
                  REPORT_INNER_ERROR("E19999", "Param graph is nullptr, check invalid");
                  GELOGE(FAILED, "[Check][Param] Original graph is nullptr"); return FAILED);
  if (graph->GetParentGraph() != nullptr) {
    GELOGD("Subgraph %s skip the MultiBatchClonePass", graph->GetName().c_str());
    return SUCCESS;
  }
  if (!GetLocalOmgContext().need_multi_batch) {
    GELOGI("No need to process_multi for no_train graph.");
    return SUCCESS;
  }
  std::vector<NodePtr> data_nodes;
  std::vector<NodePtr> getnext_nosink_nodes;
  std::vector<NodePtr> getnext_sink_nodes;
  if (multibatch::CheckSequenceOfOptions(graph, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) {
    GELOGE(PARAM_INVALID, "[Train_Dynamic] [Check][SequenceOfOptions] failed, graph:%s.", graph->GetName().c_str());
    return PARAM_INVALID;
  }
  if (multibatch::UpdateNameOfInputShape(graph, data_nodes, getnext_nosink_nodes, getnext_sink_nodes) != SUCCESS) {
    GELOGE(PARAM_INVALID, "[Train_Dynamic] [Update][Name] Of InputShape failed, graph:%s.", graph->GetName().c_str());
    return PARAM_INVALID;
  }
  if (multibatch::DeleteIdentityInsertByAdapter(graph) != SUCCESS) {
    GELOGE(PARAM_INVALID, "[Train_Dynamic] [Delete][IdentityInsertByAdapter] failed, graph:%s.",
           graph->GetName().c_str());
    return PARAM_INVALID;
  }
  if (!multibatch::InitDynamicParams(batch_shapes_)) {
    GELOGD("There is no multi-batch options, no need clone multi-batch graph");
    return SUCCESS;
  }
  if (multibatch::CheckNegativeCountOfOptions(batch_shapes_) != SUCCESS) {
    GELOGE(PARAM_INVALID, "[Train_Dynamic] [Check][Param] Input_shape and dynamic_dims should set correct params.");
    return PARAM_INVALID;
  }
  GELOGD("Begin to run Multi-batch clone on graph: %s", graph->GetName().c_str());
  GE_CHK_STATUS_RET(multibatch::CheckDynamicParams(batch_shapes_), "[Check][Params] Invalid multi-batch param");
  if (CollectIoNodes(graph) != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "[Collect][IoNodes] failed, graph:%s", graph->GetName().c_str());
    return INTERNAL_ERROR;
  }
  // parse the dynamic shape info of each data node from the atc parameter --input_shape
  if (CheckAndParseDynamicData() != SUCCESS) {
    GELOGE(PARAM_INVALID, "[CheckAndParse][DynamicData] failed");
    return PARAM_INVALID;
  }
  (void)AttrUtils::GetStr(graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_);
  ComputeGraphPtr branch = MakeShared<ComputeGraph>(graph->GetName());
  GE_IF_BOOL_EXEC(branch == nullptr,
                  REPORT_CALL_ERROR("E19999", "New ComputeGraph failed");
                  GELOGE(OUT_OF_MEMORY, "[New][ComputeGraph] failed"); return OUT_OF_MEMORY);
  (void)AttrUtils::SetStr(branch, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id_);

  graph->InValid();  // Will modify, need topological again.
  graph->Swap(*branch);
  GE_CHK_STATUS_RET(CreateRootGraph(graph), "[Construct][RootGraph] for graph:%s failed.", graph->GetName().c_str());
  GE_CHK_STATUS_RET(CreateOriGraph(branch), "[Construct][OriGraph] for graph:%s failed.", graph->GetName().c_str());
  GE_CHK_STATUS_RET(CreateSubgraphs(graph, branch),
                    "[Construct][Subgraphs] for graph:%s failed.", graph->GetName().c_str());
  GE_CHK_STATUS_RET(PruneDirectOutput(graph), "[Prune][DirectOutput] for graph:%s failed.", graph->GetName().c_str());
  GE_CHK_STATUS_RET(UpdateSubgraphOutput(), "[Update][SubgraphOutput] failed, graph:%s", graph->GetName().c_str());
  GELOGD("MultiBatchClonePass Leave");
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Collect input and output nodes from the original graph.
/// @param [in] const ComputeGraphPtr &graph: original graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CollectIoNodes(const ComputeGraphPtr &graph) {
  for (const auto &node : graph->GetDirectNode()) {
    if (!GetLocalOmgContext().dynamic_node_type.empty() && IsGetNextType(node)) {
      all_data_nodes_.emplace_back(node);
      GE_CHK_STATUS_RET(InitParamsOfGetNext(node), "[Init][Params] of %s failed.", node->GetName().c_str());
    }
    if (node->GetType() == DATA) {
      all_data_nodes_.emplace_back(node);
    } else if (node->GetType() == CONSTANT || node->GetType() == CONSTANTOP) {
      all_const_nodes_.emplace_back(node);
    } else if (node->GetType() == NETOUTPUT) {
      all_output_nodes_.emplace_back(node);
    }
    // If the node is saved as an input/output node, delete the record.
    (void)graph->RemoveInputNode(node);
    (void)graph->RemoveOutputNode(node);
  }

  if (all_data_nodes_.empty() || all_output_nodes_.size() != 1) {
    REPORT_INNER_ERROR("E19999", "Data node num is 0 or output node num != 1, graph:%s, check invalid",
                       graph->GetName().c_str());
    GELOGE(FAILED, "[Check][Param] Data node num is 0 or output node num != 1, graph:%s", graph->GetName().c_str());
    return FAILED;
  }

  int64_t data_index = 0;
  size_t getnext_node_count = 0;
  for (size_t i = 0; i < all_data_nodes_.size(); ++i) {
    if (IsGetNextType(all_data_nodes_[i])) {
      // just one getnext node in graph
      getnext_node_count++;
      continue;
    }
    const auto &op_desc = all_data_nodes_[i]->GetOpDesc();
    if (!AttrUtils::GetInt(op_desc, ATTR_NAME_INDEX, data_index)) {
      (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, i - getnext_node_count);
    }
  }

  const auto &output = all_output_nodes_[0];
  for (size_t i = 0; i < output->GetAllInDataAnchorsSize(); ++i) {
    const auto in_anchor = output->GetInDataAnchor(i);
    const auto out_anchor = in_anchor->GetPeerOutAnchor();
    const auto data_node = out_anchor->GetOwnerNode();
    if (data_node->GetType() == DATA) {
      direct_output_[i] = data_node->GetName();
      GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(data_node->GetOutDataAnchor(kDataOutIndex),
                                                     output->GetInDataAnchor(i)),
                              "[Remove][Edge] between %s(index:%u) and %s(index:%zu) failed",
                              data_node->GetName().c_str(), kDataOutIndex, output->GetName().c_str(), i);
    }
  }
  GELOGD("Data count is %zu, const count is %zu, getnext count is %zu, output count is %zu, direct out count is %zu.",
         all_data_nodes_.size(), all_const_nodes_.size(), getnext_node_count, all_output_nodes_.size(),
         direct_output_.size());
  return SUCCESS;
}

// Parse the dynamic dims info of every unknown-shape data node from the user options
// (--input_shape together with --dynamic_batch_size / --dynamic_image_size / --dynamic_dims).
Status MultiBatchClonePass::CheckAndParseDynamicData() {
  size_t unknown_shape_count = 0;
  auto data_name_and_shape = GetLocalOmgContext().user_input_dims;
  std::vector<std::string> data_name_order;
  for (auto &item : data_name_and_shape) {
    data_name_order.push_back(item.first);
  }
  if (!getnext_sink_dynamic_dims_) {
    for (const auto &node : all_data_nodes_) {
      auto data_desc = NodeUtils::GetOutputDesc(*node, kDataOutIndex);
      auto data_shape = data_desc.GetShape();
      auto data_format = data_desc.GetFormat() == Format::FORMAT_NCHW ? "NCHW" :
                         data_desc.GetFormat() == Format::FORMAT_NHWC ? "NHWC" : "Others";
      auto data_name = node->GetName();
      const auto &data_shape_dims = data_shape.GetDims();
      if (std::all_of(data_shape_dims.begin(), data_shape_dims.end(), [](int64_t val) { return val >= 0; })) {
        continue;
      }
      ++unknown_shape_count;
      auto iter = find(data_name_order.begin(), data_name_order.end(), data_name);
      if (iter == data_name_order.end()) {
        if (!GetLocalOmgContext().dynamic_batch_size.empty()) {
          auto ret = multibatch::CheckDynamicBatchShape(data_shape_dims, data_name);
          GE_IF_BOOL_EXEC(ret == false,
                          GELOGE(PARAM_INVALID, "[Check][DynamicBatchShape] of %s failed.", data_name.c_str());
                          return PARAM_INVALID);
        } else if (!GetLocalOmgContext().dynamic_image_size.empty()) {
          auto ret = multibatch::CheckDynamicImageSizeShape(data_shape_dims, data_name, data_format);
          GE_IF_BOOL_EXEC(ret == false,
                          GELOGE(PARAM_INVALID, "[Check][DynamicImageSizeShape] of %s failed.", data_name.c_str());
                          return PARAM_INVALID);
        } else if (!GetLocalOmgContext().dynamic_dims.empty()) {
          ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"},
              {"--dynamic_dims", data_name, "all dynamic node must be set in --input_shape, please check"});
          GELOGE(INTERNAL_ERROR, "[Check][Param] data:%s shape:%s must be set int --input_shape",
                 node->GetName().c_str(), data_shape.ToString().c_str());
          return INTERNAL_ERROR;
        }
        data_name_and_shape.emplace_back(data_name, data_shape_dims);
      }
    }
  }
  auto ret = multibatch::ParserDataToDynamicInfo(batch_shapes_, data_name_and_shape, data_to_dynamic_info_);
  GE_CHK_STATUS_RET(ret, "[Parser][DataToDynamicInfo] failed.");
  if (!getnext_sink_dynamic_dims_ && unknown_shape_count == 0) {
    ErrorManager::GetInstance().ATCReportErrMessage("E10040");
    GELOGE(PARAM_INVALID, "[Check][Param] Need unknow shape data "
           "when user set --dynamic_batch_size, --dynamic_image_size or --dynamic_dims");
    return PARAM_INVALID;
  }
  return SUCCESS;
}

// Record the data count and the dynamic-dims flag of the GetNext(IteratorV2) node, and collect
// the const nodes hanging on its out-control edges.
Status MultiBatchClonePass::InitParamsOfGetNext(const NodePtr &node) {
  data_count_from_getnext_ = 0;
  getnext_sink_dynamic_dims_ = false;
  GE_CHECK_NOTNULL(node->GetOpDesc());
  data_count_from_getnext_ = node->GetOpDesc()->GetOutputsSize();
  if (GetLocalOmgContext().dynamic_node_type == GETNEXT) {
    // In getnext-sink mode every data output is paired with a shape output, so only half of the outputs are real data.
    data_count_from_getnext_ = data_count_from_getnext_ / kDivisionConst;
    for (size_t i = 0; i < data_count_from_getnext_; ++i) {
      GeTensorDesc output_desc = node->GetOpDesc()->GetOutputDesc(i);
      GELOGD("The %zu data shape from getnext sink is %s.", i,
             formats::JoinToString(output_desc.GetShape().GetDims()).c_str());
      const auto &dims = output_desc.GetShape().GetDims();
      if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) {
        GELOGD("The %zu data from %s is static.", i, node->GetName().c_str());
      } else {
        getnext_sink_dynamic_dims_ = true;
        GELOGD("Dynamic dims in the pattern of getnext sink.");
      }
    }
  }
  if (node->GetOutControlAnchor() != nullptr) {
    for (const auto &peer_in_control_anchor : node->GetOutControlAnchor()->GetPeerInControlAnchors()) {
      NodePtr next_node = peer_in_control_anchor->GetOwnerNode();
      GE_CHECK_NOTNULL(next_node);
      if (next_node->GetType() == CONSTANTOP) {
        out_control_nodes_.insert(next_node);
        GELOGD("Control edge: %s connect with %s.", node->GetName().c_str(), next_node->GetName().c_str());
      }
    }
  }
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Create nodes for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateRootGraph(const ComputeGraphPtr &graph) {
  GELOGD("Start create root graph of %s.", graph->GetName().c_str());
  uint32_t input_num = all_data_nodes_.size() + all_const_nodes_.size();
  if (data_count_from_getnext_ != 0) {
    input_num = input_num + data_count_from_getnext_ - kNumOfGetnextNode;
  }
  uint32_t output_num = all_output_nodes_[0]->GetAllInDataAnchorsSize();

  OpDescBuilder op_builder(kMultiBatchCaseNode, CASE);
  op_builder.AddInput("branch_index").AddDynamicInput("input", input_num).AddDynamicOutput("output", output_num);
  const OpDescPtr op_desc = op_builder.Build();
  if (op_desc == nullptr) {
    REPORT_INNER_ERROR("E19999", "Build op:%s(%s) failed", kMultiBatchCaseNode.c_str(), CASE);
    GELOGE(OUT_OF_MEMORY, "[Build][Op] %s(%s) failed", kMultiBatchCaseNode.c_str(), CASE);
    return OUT_OF_MEMORY;
  }
  op_desc->RegisterSubgraphIrName("branches", kDynamic);
  case_node_ = graph->AddNode(op_desc);
  if (case_node_ == nullptr) {
    REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
                      op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
    GELOGE(OUT_OF_MEMORY, "[Add][Node] %s(%s) to graph:%s failed",
           op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
    return OUT_OF_MEMORY;
  }

  uint32_t batch_num = static_cast<uint32_t>(batch_shapes_.size());
  if (!AttrUtils::SetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) {
    REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_BATCH_NUM.c_str(),
                      op_desc->GetName().c_str(), op_desc->GetType().c_str());
    GELOGE(FAILED, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NAME_BATCH_NUM.c_str(),
           op_desc->GetName().c_str(), op_desc->GetType().c_str());
    return FAILED;
  }

  for (uint32_t i = 0; i < batch_num; i++) {
    const std::string &attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i);
    if (!AttrUtils::SetListInt(op_desc, attr_name, batch_shapes_[i])) {
      REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", attr_name.c_str(),
                        op_desc->GetName().c_str(), op_desc->GetType().c_str());
      GELOGE(FAILED, "[Set][Attr] %s to op:%s(%s) failed", attr_name.c_str(),
             op_desc->GetName().c_str(), op_desc->GetType().c_str());
      return FAILED;
    }
  }

  std::vector<std::string> data_name_order;
  for (auto &item : GetLocalOmgContext().user_input_dims) {
    data_name_order.push_back(item.first);
  }
  if (!AttrUtils::SetListStr(op_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, data_name_order)) {
    REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_USER_DESIGNEATE_SHAPE_ORDER.c_str(),
                      op_desc->GetName().c_str(), op_desc->GetType().c_str());
    GELOGE(FAILED, "[Set][Attr] %s to op:%s(%s) failed", ATTR_USER_DESIGNEATE_SHAPE_ORDER.c_str(),
           op_desc->GetName().c_str(), op_desc->GetType().c_str());
    return FAILED;
  }

  if (!AttrUtils::SetBool(op_desc, ATTR_INSERT_BY_MBATCH, true)) {
    REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_INSERT_BY_MBATCH.c_str(),
                      op_desc->GetName().c_str(), op_desc->GetType().c_str());
    GELOGE(INTERNAL_ERROR, "[Set][Attr] %s to op:%s(%s) failed", ATTR_INSERT_BY_MBATCH.c_str(),
           op_desc->GetName().c_str(), op_desc->GetType().c_str());
    return INTERNAL_ERROR;
  }

  GE_CHK_STATUS_RET(multibatch::StampDynamicType(op_desc),
                    "[Call][StampDynamicType] for op:%s(%s) failed",
                    op_desc->GetName().c_str(), op_desc->GetType().c_str());
  GE_CHK_STATUS_RET(CreateIndexNode(graph), "[Create][IndexNode] for graph:%s failed", graph->GetName().c_str());
  GE_CHK_STATUS_RET(CreateInputNode(graph), "[Create][InputNode] for graph:%s failed", graph->GetName().c_str());
  GE_CHK_STATUS_RET(CreateConstNode(graph), "[Create][ConstNode] for graph:%s failed", graph->GetName().c_str());
  GE_CHK_STATUS_RET(CreateOutputNode(graph), "[Create][OutputNode] for graph:%s failed", graph->GetName().c_str());
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Create index data node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @param [out] NodePtr &shape_node: created index data node.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateIndexDataNode(const ComputeGraphPtr &graph, NodePtr &shape_node) {
  const OpDescPtr data_desc = MakeShared<OpDesc>(kMultiBatchDataNode, DATA);
  if (data_desc == nullptr) {
    REPORT_CALL_ERROR("E19999", "New OpDesc failed");
    GELOGE(OUT_OF_MEMORY, "[New][OpDesc] failed");
    return FAILED;
  }

  GeTensorDesc data_tensor(GeShape({static_cast<int64_t>(batch_shapes_[0].size())}), FORMAT_ND, DT_INT32);
  if (data_desc->AddInputDesc(data_tensor) != GRAPH_SUCCESS) {
    REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
                      data_desc->GetName().c_str(), data_desc->GetType().c_str());
    GELOGE(FAILED, "[Add][InputDesc] to op:%s(%s) failed",
           data_desc->GetName().c_str(), data_desc->GetType().c_str());
    return FAILED;
  }
  if (data_desc->AddOutputDesc(data_tensor) != GRAPH_SUCCESS) {
    REPORT_CALL_ERROR("E19999", "Add ouput desc to op:%s(%s) failed",
                      data_desc->GetName().c_str(), data_desc->GetType().c_str());
    GELOGE(FAILED, "[Add][OutputDesc] to op:%s(%s) failed",
           data_desc->GetName().c_str(), data_desc->GetType().c_str());
    return FAILED;
  }

  size_t data_index = all_data_nodes_.size();
  data_index = data_count_from_getnext_ != 0 ? data_index - kNumOfGetnextNode : data_index;
  (void)AttrUtils::SetInt(data_desc, ATTR_NAME_INDEX, data_index);
  (void)AttrUtils::SetBool(data_desc, ATTR_INSERT_BY_MBATCH, true);

  shape_node = graph->AddNode(data_desc);
  if (shape_node == nullptr) {
    REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
                      data_desc->GetName().c_str(), data_desc->GetType().c_str(), graph->GetName().c_str());
    GELOGE(OUT_OF_MEMORY, "[Add][Node] %s(%s) to graph:%s failed",
           data_desc->GetName().c_str(), data_desc->GetType().c_str(), graph->GetName().c_str());
    return OUT_OF_MEMORY;
  }
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Create index const node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @param [in] NodePtr node: index const node.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateIndexConstNode(const ComputeGraphPtr &graph, NodePtr &node) {
  const OpDescPtr const_desc = MakeShared<OpDesc>(kMultiBatchConstNode, CONSTANT);
  if (const_desc == nullptr) {
    REPORT_CALL_ERROR("E19999", "New OpDesc failed");
    GELOGE(OUT_OF_MEMORY, "[New][OpDesc] failed");
    return FAILED;
  }

  int64_t count = batch_shapes_.size() * batch_shapes_[0].size();
  std::unique_ptr<int32_t[]> addr(new (std::nothrow) int32_t[count]);
  GE_CHECK_NOTNULL(addr);

  size_t i = 0;
  for (auto &batch_shape : batch_shapes_) {
    for (int64_t dim : batch_shape) {
      addr[i++] = static_cast<int32_t>(dim);
    }
  }

  GeTensorDesc const_tensor(GeShape({count}), FORMAT_ND, DT_INT32);
  GeTensor tensor(const_tensor);
  (void)tensor.SetData(reinterpret_cast<uint8_t *>(addr.get()), count * sizeof(int32_t));
  if (!AttrUtils::SetTensor(const_desc, ATTR_NAME_WEIGHTS, tensor)) {
    REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_WEIGHTS.c_str(),
                      const_desc->GetName().c_str(), const_desc->GetType().c_str());
    GELOGE(OUT_OF_MEMORY, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NAME_WEIGHTS.c_str(),
           const_desc->GetName().c_str(), const_desc->GetType().c_str());
    return FAILED;
  }

  if (const_desc->AddOutputDesc(const_tensor) != GRAPH_SUCCESS) {
    REPORT_CALL_ERROR("E19999", "Add ouput desc to op:%s(%s) failed",
                      const_desc->GetName().c_str(), const_desc->GetType().c_str());
    GELOGE(OUT_OF_MEMORY, "[Add][OutputDesc] to op:%s(%s) failed",
           const_desc->GetName().c_str(), const_desc->GetType().c_str());
    return FAILED;
  }

  node = graph->AddNode(const_desc);
  if (node == nullptr) {
    REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
                      const_desc->GetName().c_str(), const_desc->GetType().c_str(), graph->GetName().c_str());
    GELOGE(OUT_OF_MEMORY, "[Add][Node] %s(%s) to graph:%s failed",
           const_desc->GetName().c_str(), const_desc->GetType().c_str(), graph->GetName().c_str());
    return OUT_OF_MEMORY;
  }
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Create index node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateIndexNode(const ComputeGraphPtr &graph) {
  // Data/GetDynamicDims --> MapIndex --> Case
  if (!getnext_sink_dynamic_dims_) {
    GE_CHK_STATUS_RET(CreateIndexDataNode(graph, shape_node_),
                      "[Create][IndexDataNode] failed, graph:%s", graph->GetName().c_str());
  } else {
    GE_CHK_STATUS_RET(CreateGetDynamicDimsNode(graph, shape_node_),
                      "[Create][GetDynamicDimsNode] failed, graph:%s", graph->GetName().c_str());
  }

  NodePtr const_node;
  GE_CHK_STATUS_RET(CreateIndexConstNode(graph, const_node),
                    "[Create][ConstNode] failed, graph:%s", graph->GetName().c_str());
  GELOGD("Shape node name is %s, type is %s, const node name is %s.", shape_node_->GetName().c_str(),
         shape_node_->GetType().c_str(), const_node->GetName().c_str());

  OpDescBuilder op_builder(kMultiBatchMapIndexNode, "MapIndex");
  op_builder.AddInput("x", shape_node_->GetOpDesc()->GetOutputDesc(0))
            .AddInput("data_seq", const_node->GetOpDesc()->GetOutputDesc(0))
            .AddOutput("y", GeTensorDesc(GeShape(), FORMAT_ND, DT_INT32));
  const OpDescPtr op_desc = op_builder.Build();
  if (op_desc == nullptr) {
    REPORT_INNER_ERROR("E19999", "Build op:%s(%s) failed", kMultiBatchMapIndexNode.c_str(), "MapIndex");
    GELOGE(OUT_OF_MEMORY, "[Build][Op] %s(MapIndex) failed", kMultiBatchMapIndexNode.c_str());
    return FAILED;
  }
  NodePtr index_node = graph->AddNode(op_desc);
  if (index_node == nullptr) {
    REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
                      op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
    GELOGE(OUT_OF_MEMORY, "[Add][Node] %s(%s) to graph:%s failed",
           op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
    return OUT_OF_MEMORY;
  }

  GE_CHK_STATUS_RET(AddAttrForGetDynamicDims(shape_node_), "[Add][Attr] for %s failed.",
                    shape_node_->GetName().c_str());
  if (GraphUtils::AddEdge(shape_node_->GetOutDataAnchor(0), index_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
    REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:0) failed",
                      shape_node_->GetName().c_str(), shape_node_->GetType().c_str(),
                      index_node->GetName().c_str(), index_node->GetType().c_str());
    GELOGE(FAILED, "[Add][Edge] between op:%s(%s)(index:0) and op:%s(%s)(index:0) failed",
           shape_node_->GetName().c_str(), shape_node_->GetType().c_str(),
           index_node->GetName().c_str(), index_node->GetType().c_str());
    return FAILED;
  }
  if (GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), index_node->GetInDataAnchor(1)) != GRAPH_SUCCESS) {
    REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:1) failed",
                      const_node->GetName().c_str(), const_node->GetType().c_str(),
                      index_node->GetName().c_str(), index_node->GetType().c_str());
    GELOGE(FAILED, "[Add][Edge] between node:%s to MapIndex:%s", const_node->GetName().c_str(),
           index_node->GetName().c_str());
    return FAILED;
  }
  if (GraphUtils::AddEdge(index_node->GetOutDataAnchor(0), case_node_->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
    REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:0) failed",
                      index_node->GetName().c_str(), index_node->GetType().c_str(),
                      case_node_->GetName().c_str(), case_node_->GetType().c_str());
    GELOGE(FAILED, "[Add][Edge] between op:%s(%s)(index:0) and op:%s(%s)(index:0) failed",
           index_node->GetName().c_str(), index_node->GetType().c_str(),
           case_node_->GetName().c_str(), case_node_->GetType().c_str());
    return FAILED;
  }
  return SUCCESS;
}

Status MultiBatchClonePass::CreateGetDynamicDimsNode(const ComputeGraphPtr &graph, NodePtr &shape_node) {
  const OpDescPtr data_desc = MakeShared<OpDesc>(kMultiBatchGetDynamicDimsNode, GETDYNAMICDIMS);
  if (data_desc == nullptr) {
    REPORT_CALL_ERROR("E19999", "New OpDesc failed");
    GELOGE(OUT_OF_MEMORY, "[New][OpDesc] failed");
    return OUT_OF_MEMORY;
  }

  // input of GetDynamicDims is shape_of_each_data, output is gear_info
  for (size_t i = 0; i < GetLocalOmgContext().user_input_dims.size(); ++i) {
    size_t input_shape_dims = GetLocalOmgContext().user_input_dims.at(i).second.size();
    // add input desc without GeShape for const input, value of input_shape is 1 transferred by adapter
    if (input_shape_dims == 1 && GetLocalOmgContext().user_input_dims.at(i).second.at(0) == 0) {
      GeTensorDesc tensor_desc;
      tensor_desc.SetFormat(FORMAT_ND);
      tensor_desc.SetDataType(DT_INT32);
      auto ret = data_desc->AddInputDesc(tensor_desc);
      GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS,
                      REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
                                        data_desc->GetName().c_str(), data_desc->GetType().c_str());
                      GELOGE(INTERNAL_ERROR, "[Add][InputDesc] to op:%s(%s) failed",
                             data_desc->GetName().c_str(), data_desc->GetType().c_str());
                      return FAILED);
      continue;
    }
    GeTensorDesc tensor_desc(GeShape({static_cast<int32_t>(input_shape_dims)}), FORMAT_ND, DT_INT32);
    auto ret = data_desc->AddInputDesc(tensor_desc);
    GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS,
                    REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
                                      data_desc->GetName().c_str(), data_desc->GetType().c_str());
                    GELOGE(INTERNAL_ERROR, "[Add][InputDesc] to op:%s(%s) failed",
                           data_desc->GetName().c_str(), data_desc->GetType().c_str());
                    return FAILED);
  }
  GeTensorDesc tensor_desc(GeShape({static_cast<int32_t>(batch_shapes_.at(0).size())}), FORMAT_ND, DT_INT32);
  auto ret = data_desc->AddOutputDesc(tensor_desc);
  GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS,
                  REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed",
                                    data_desc->GetName().c_str(), data_desc->GetType().c_str());
                  GELOGE(INTERNAL_ERROR, "[Add][OutputDesc] to op:%s(%s) failed",
                         data_desc->GetName().c_str(), data_desc->GetType().c_str());
                  return FAILED);
  (void)AttrUtils::SetBool(data_desc, ATTR_INSERT_BY_MBATCH, true);

  shape_node = graph->AddNode(data_desc);
  if (shape_node == nullptr) {
    REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
                      data_desc->GetName().c_str(), data_desc->GetType().c_str(), graph->GetName().c_str());
    GELOGE(OUT_OF_MEMORY, "[Add][Node] %s(%s) to graph:%s failed",
           data_desc->GetName().c_str(), data_desc->GetType().c_str(), graph->GetName().c_str());
    return OUT_OF_MEMORY;
  }
  return SUCCESS;
}

Status MultiBatchClonePass::AddAttrForGetDynamicDims(const NodePtr &shape_node) {
  if (!getnext_sink_dynamic_dims_) {
    GELOGD("No need to add attr when not insert get dynamic dims node.");
    return SUCCESS;
  }
  GELOGD("Add attr for :%s, type is %s:", shape_node->GetName().c_str(), shape_node->GetType().c_str());
  if (!AttrUtils::SetInt(shape_node->GetOpDesc(), ATTR_GETNEXT_SINK_DATA_COUNT, data_count_from_getnext_)) {
    REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_GETNEXT_SINK_DATA_COUNT.c_str(),
                      shape_node->GetName().c_str(), shape_node->GetType().c_str());
    GELOGE(INTERNAL_ERROR, "[Set][Attr] %s to op:%s(%s) failed", ATTR_GETNEXT_SINK_DATA_COUNT.c_str(),
           shape_node->GetName().c_str(), shape_node->GetType().c_str());
    return INTERNAL_ERROR;
  }
  vector<int64_t> shape_info;
  for (size_t i = 0; i < GetLocalOmgContext().user_input_dims.size(); ++i) {
    if (GetLocalOmgContext().user_input_dims.at(i).second.size() == 1 &&
        GetLocalOmgContext().user_input_dims.at(i).second.at(0) == 0) {
      shape_info.emplace_back(0);
      continue;
    }
    shape_info.emplace_back(GetLocalOmgContext().user_input_dims.at(i).second.size());
    for (size_t j = 0; j < GetLocalOmgContext().user_input_dims.at(i).second.size(); ++j) {
      shape_info.emplace_back(GetLocalOmgContext().user_input_dims.at(i).second.at(j));
    }
  }
  if (!AttrUtils::SetListInt(shape_node->GetOpDesc(), ATTR_GETNEXT_SINK_SHAPE_INFO, shape_info)) {
    REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_GETNEXT_SINK_SHAPE_INFO.c_str(),
                      shape_node->GetName().c_str(), shape_node->GetType().c_str());
    GELOGE(INTERNAL_ERROR, "[Set][Attr] %s to op:%s(%s) failed", ATTR_GETNEXT_SINK_SHAPE_INFO.c_str(),
           shape_node->GetName().c_str(), shape_node->GetType().c_str());
    return INTERNAL_ERROR;
  }
  return SUCCESS;
}

Status MultiBatchClonePass::LinkGetNextToGetDynamicDims(const NodePtr &getnext_node, const NodePtr &shape_node) {
  GELOGD("Start relink shape anchor of %s to %s.", getnext_node->GetName().c_str(), shape_node->GetName().c_str());
  size_t input_index = 0;
  size_t data_count = getnext_node->GetAllOutDataAnchors().size() / kDivisionConst;
  for (size_t out_index = data_count; out_index < getnext_node->GetAllOutDataAnchors().size(); ++out_index,
       ++input_index) {
    GELOGD("Start add %s of %zu out_anchor to %s of %zu in_anchor.", getnext_node->GetName().c_str(), out_index,
           shape_node->GetName().c_str(), input_index);
    auto out_data_anchor = getnext_node->GetOutDataAnchor(out_index);
    auto ret = GraphUtils::AddEdge(out_data_anchor, shape_node->GetInDataAnchor(input_index));
    GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS,
                    REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%zu) and op:%s(%s)(index:%zu) failed",
                                      getnext_node->GetName().c_str(), getnext_node->GetType().c_str(), out_index,
                                      shape_node->GetName().c_str(), shape_node->GetType().c_str(), input_index);
                    GELOGE(INTERNAL_ERROR, "[Add][Edge] between op:%s(%s)(index:%zu) and op:%s(%s)(index:%zu) failed",
                           getnext_node->GetName().c_str(), getnext_node->GetType().c_str(), out_index,
                           shape_node->GetName().c_str(), shape_node->GetType().c_str(), input_index);
                    return INTERNAL_ERROR);
  }
  return SUCCESS;
}

Status MultiBatchClonePass::LinkGetDynamicDimsToNetOutput(const NodePtr &output_node) {
  if (!GetLocalOmgContext().dynamic_node_type.empty()) {
    if (!AttrUtils::SetStr(output_node->GetOpDesc(), ATTR_ALL_GEARS_INFO, GetLocalOmgContext().dynamic_dims)) {
      REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_ALL_GEARS_INFO.c_str(),
                        output_node->GetName().c_str(), output_node->GetType().c_str());
      GELOGE(INTERNAL_ERROR, "[Set][Attr] %s to op:%s(%s) failed", ATTR_ALL_GEARS_INFO.c_str(),
             output_node->GetName().c_str(), output_node->GetType().c_str());
      return INTERNAL_ERROR;
    }
  }
  if (getnext_sink_dynamic_dims_) {
    GELOGD("Start link %s to %s.", shape_node_->GetName().c_str(), output_node->GetName().c_str());
    size_t input_index = output_node->GetAllInDataAnchors().size();
    if (NodeUtils::AppendInputAnchor(output_node, input_index + 1) != GRAPH_SUCCESS) {
      REPORT_CALL_ERROR("E19999", "Append input anchor to op:%s(%s) failed, size:%zu",
                        output_node->GetName().c_str(), output_node->GetType().c_str(), input_index + 1);
      GELOGE(INTERNAL_ERROR, "[Append][InputAnchor] to op:%s(%s) failed, size:%zu",
             output_node->GetName().c_str(), output_node->GetType().c_str(), input_index + 1);
      return INTERNAL_ERROR;
    }
    auto ret = GraphUtils::AddEdge(shape_node_->GetOutDataAnchor(kDataOutIndex),
                                   output_node->GetInDataAnchor(input_index));
    GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS,
                    REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%zu) failed",
                                      shape_node_->GetName().c_str(), shape_node_->GetType().c_str(), kDataOutIndex,
                                      output_node->GetName().c_str(), output_node->GetType().c_str(), input_index);
                    GELOGE(INTERNAL_ERROR, "[Add][Edge] between op:%s(%s)(index:%d) and op:%s(%s)(index:%zu) failed",
                           shape_node_->GetName().c_str(), shape_node_->GetType().c_str(), kDataOutIndex,
                           output_node->GetName().c_str(), output_node->GetType().c_str(), input_index);
                    return INTERNAL_ERROR);
    if (!AttrUtils::SetBool(output_node->GetOpDesc(), ATTR_GETNEXT_SINK_DYNMAIC, true)) {
      REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_GETNEXT_SINK_DYNMAIC.c_str(),
                        output_node->GetName().c_str(), output_node->GetType().c_str());
      GELOGE(INTERNAL_ERROR, "[Set][Attr] %s to op:%s(%s) failed", ATTR_GETNEXT_SINK_DYNMAIC.c_str(),
             output_node->GetName().c_str(), output_node->GetType().c_str());
      return INTERNAL_ERROR;
    }
  }
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Create input node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateInputNode(const ComputeGraphPtr &graph) {
  // Data --> Case
  std::vector<NodePtr> all_data_nodes;
  size_t case_input_index = kCaseArgIndex;
  NodePtr getnext_node = nullptr;
  size_t input_index_of_getnext = 0;
  for (size_t i = 0; i < all_data_nodes_.size(); ++i, ++case_input_index) {
    const auto &node = all_data_nodes_[i];
    const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc());
    if (op_desc == nullptr) {
      REPORT_CALL_ERROR("E19999", "Copy op_desc from op:%s(%s) failed",
                        node->GetName().c_str(), node->GetType().c_str());
      GELOGE(OUT_OF_MEMORY, "[Copy][OpDesc] from op:%s(%s) failed",
             node->GetName().c_str(), node->GetType().c_str());
      return FAILED;
    }
    if (GraphUtils::CopyTensorAttrs(op_desc, node) != GRAPH_SUCCESS) {
      REPORT_CALL_ERROR("E19999", "Copy tensor attr from op:%s(%s) failed",
                        node->GetName().c_str(), node->GetType().c_str());
      GELOGE(OUT_OF_MEMORY, "[Copy][TensorAttrs] from op:%s(%s) failed",
             node->GetName().c_str(), node->GetType().c_str());
      return FAILED;
    }

    op_desc->SetName(node->GetName());
    const NodePtr &data = graph->AddNode(op_desc);
    GE_CHK_BOOL_EXEC(data != nullptr,
                     REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
                                       op_desc->GetName().c_str(), op_desc->GetType().c_str(),
                                       graph->GetName().c_str());
                     return FAILED,
                     "[Add][Node] %s(%s) to graph:%s failed",
                     op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
    if (IsGetNextType(node)) {
      getnext_node = data;
      input_index_of_getnext = case_input_index;
      case_input_index = case_input_index + data_count_from_getnext_;
      continue;
    } else {
      if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(case_input_index)) !=
          GRAPH_SUCCESS) {
        REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%zu) failed",
                          data->GetName().c_str(), data->GetType().c_str(),
                          case_node_->GetName().c_str(), case_node_->GetType().c_str(), case_input_index);
        GELOGE(FAILED, "[Add][Edge] between op:%s(%s)(index:0) and op:%s(%s)(index:%zu) failed",
               data->GetName().c_str(), data->GetType().c_str(),
               case_node_->GetName().c_str(), case_node_->GetType().c_str(), case_input_index);
        return FAILED;
      }
    }

    if (SetMaxShape(data) != SUCCESS) {
      GELOGE(FAILED, "[Set][MaxShape] of %s failed.", data->GetName().c_str());
      return FAILED;
    }
    all_data_nodes.emplace_back(data);
  }
  if (getnext_node != nullptr) {
    if (LinkEdgeForGetNext(getnext_node, input_index_of_getnext) != SUCCESS) {
      GELOGE(FAILED, "[Link][Edge] for %s failed.", getnext_node->GetName().c_str());
      return FAILED;
    }
    if (SetMaxShape(getnext_node) != SUCCESS) {
      GELOGE(FAILED, "[Set][MaxShape] of %s failed.", getnext_node->GetName().c_str());
      return FAILED;
    }
    all_data_nodes.emplace_back(getnext_node);
  }

  all_data_nodes_.swap(all_data_nodes);
  return SUCCESS;
}

Status MultiBatchClonePass::LinkEdgeForGetNext(const NodePtr &getnext_node, size_t &case_input_index) {
  GELOGD("Start link edge for %s, which is the %zu input of %s.", getnext_node->GetName().c_str(),
         case_input_index, case_node_->GetName().c_str());
  for (size_t out_index = 0; out_index < data_count_from_getnext_; ++out_index, ++case_input_index) {
    if (GraphUtils::AddEdge(getnext_node->GetOutDataAnchor(out_index),
                            case_node_->GetInDataAnchor(case_input_index)) != GRAPH_SUCCESS) {
      REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%zu) and op:%s(%s)(index:%zu) failed",
                        getnext_node->GetName().c_str(), getnext_node->GetType().c_str(), out_index,
                        case_node_->GetName().c_str(), case_node_->GetType().c_str(), case_input_index);
      GELOGE(FAILED, "[Add][Edge] between op:%s(%s)(index:%zu) and op:%s(%s)(index:%zu) failed",
             getnext_node->GetName().c_str(), getnext_node->GetType().c_str(), out_index,
             case_node_->GetName().c_str(), case_node_->GetType().c_str(), case_input_index);
      return FAILED;
    }
  }
  if (getnext_sink_dynamic_dims_) {
    GE_CHK_STATUS_RET(LinkGetNextToGetDynamicDims(getnext_node, shape_node_), "[Add][Link] for %s failed.",
                      shape_node_->GetName().c_str());
  }
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Create Const node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateConstNode(const ComputeGraphPtr &graph) {
  // Const --> Case
  std::vector<NodePtr> all_const_nodes;
  size_t arg_index = kCaseArgIndex + all_data_nodes_.size();
  if (data_count_from_getnext_ != 0) {
    arg_index = arg_index + data_count_from_getnext_ - kNumOfGetnextNode;
  }

  for (size_t i = 0; i < all_const_nodes_.size(); ++i) {
    const auto &node = all_const_nodes_[i];
    const OpDescPtr op_desc = AttrUtils::CopyOpDesc(node->GetOpDesc());
    if (op_desc == nullptr) {
      REPORT_CALL_ERROR("E19999", "Copy op_desc from op:%s(%s) failed",
                        node->GetName().c_str(), node->GetType().c_str());
      GELOGE(OUT_OF_MEMORY, "[Copy][OpDesc] from op:%s(%s) failed", node->GetName().c_str(), node->GetType().c_str());
      return FAILED;
    }

    op_desc->SetName(node->GetName());
    if (GraphUtils::CopyTensorAttrs(op_desc, node) != GRAPH_SUCCESS) {
      REPORT_CALL_ERROR("E19999", "Copy tensor attr from op:%s(%s) failed",
                        node->GetName().c_str(), node->GetType().c_str());
      GELOGE(OUT_OF_MEMORY, "[Copy][TensorAttrs] from op:%s(%s) failed",
             node->GetName().c_str(), node->GetType().c_str());
      return FAILED;
    }

    const NodePtr &data = graph->AddNode(op_desc);
    GE_CHK_BOOL_EXEC(data != nullptr,
                     REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
                                       op_desc->GetName().c_str(), op_desc->GetType().c_str(),
                                       graph->GetName().c_str());
                     return FAILED,
                     "[Add][Node] %s(%s) to graph:%s failed",
                     op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
    if (GraphUtils::AddEdge(data->GetOutDataAnchor(0), case_node_->GetInDataAnchor(arg_index + i)) != GRAPH_SUCCESS) {
      REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%zu) failed",
                        data->GetName().c_str(), data->GetType().c_str(),
                        case_node_->GetName().c_str(), case_node_->GetType().c_str(), arg_index + i);
      GELOGE(FAILED, "[Add][Edge] between op:%s(%s)(index:0) and op:%s(%s)(index:%zu) failed",
             data->GetName().c_str(), data->GetType().c_str(),
             case_node_->GetName().c_str(), case_node_->GetType().c_str(), arg_index + i);
      return FAILED;
    }
    all_const_nodes.emplace_back(data);
  }
  ChangeConstToData();
  all_const_nodes_.swap(all_const_nodes);
  return SUCCESS;
}

void MultiBatchClonePass::ChangeConstToData() {
  size_t data_index = all_data_nodes_.size();
  if (data_count_from_getnext_ != 0) {
    data_index = data_index + data_count_from_getnext_ - kNumOfGetnextNode;
  }
  for (size_t i = 0; i < all_const_nodes_.size(); ++i, ++data_index) {  // Trans subgraph Const to Data.
    auto &const_node = all_const_nodes_[i];
    bool need_change_type = true;
    if (out_control_nodes_.find(const_node) != out_control_nodes_.end()) {
      GELOGD("No need to change %s to data type.", const_node->GetName().c_str());
      need_change_type = false;
      break;
    }
    if (!need_change_type) {
      continue;
    }
    const OpDescPtr &op_desc = all_const_nodes_[i]->GetOpDesc();
    op_desc->SetType(DATA);
    (void)op_desc->DelAttr(ATTR_NAME_WEIGHTS);  // Delete weight.
    // Const no InputDesc, Data need InputDesc.
    (void)op_desc->AddInputDesc(op_desc->GetOutputDesc(kDataOutIndex));
    (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index);
    (void)NodeUtils::AppendInputAnchor(all_const_nodes_[i], 1);
  }
}

///
/// @ingroup ge
/// @brief Create output node for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateOutputNode(const ComputeGraphPtr &graph) {
  const auto &output = all_output_nodes_[0];
  const OpDescPtr op_desc = AttrUtils::CopyOpDesc(output->GetOpDesc());
  if (op_desc == nullptr) {
    REPORT_CALL_ERROR("E19999", "Copy op_desc from op:%s(%s) failed",
                      output->GetName().c_str(), output->GetType().c_str());
    GELOGE(OUT_OF_MEMORY, "[Copy][OpDesc] from op:%s(%s) failed",
           output->GetName().c_str(), output->GetType().c_str());
    return FAILED;
  }
  if (GraphUtils::CopyTensorAttrs(op_desc, output) != GRAPH_SUCCESS) {
    REPORT_CALL_ERROR("E19999", "Copy tensor attr from op:%s(%s) failed",
                      output->GetName().c_str(), output->GetType().c_str());
    GELOGE(OUT_OF_MEMORY, "[Copy][TensorAttrs] from op:%s(%s) failed",
           output->GetName().c_str(), output->GetType().c_str());
    return FAILED;
  }

  op_desc->SetName(output->GetName());
  const NodePtr &node = graph->AddNode(op_desc);
  GE_CHK_BOOL_EXEC(node != nullptr,
                   REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
                                     op_desc->GetName().c_str(), op_desc->GetType().c_str(),
                                     graph->GetName().c_str());
                   return FAILED,
                   "[Add][Node] %s(%s) to graph:%s failed",
                   op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());

  for (size_t i = 0; i < case_node_->GetAllOutDataAnchorsSize(); ++i) {
    const auto it = direct_output_.find(i);
    if (it == direct_output_.end()) {
      if (GraphUtils::AddEdge(case_node_->GetOutDataAnchor(i), node->GetInDataAnchor(i)) != GRAPH_SUCCESS) {
        REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%zu) and op:%s(%s)(index:%zu) failed",
                          case_node_->GetName().c_str(), case_node_->GetType().c_str(), i,
                          node->GetName().c_str(), node->GetType().c_str(), i);
        GELOGE(FAILED, "[Add][Edge] between op:%s(%s)(index:%zu) and op:%s(%s)(index:%zu) failed",
               case_node_->GetName().c_str(), case_node_->GetType().c_str(), i,
               node->GetName().c_str(), node->GetType().c_str(), i);
        return FAILED;
      }
    } else {
      const auto data_node = graph->FindNode(it->second);
      if (data_node == nullptr) {
        REPORT_CALL_ERROR("E19999", "Find node:%s from graph:%s failed", it->second.c_str(), graph->GetName().c_str());
        GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "[Check][Param] Data node:%s not found in graph:%s",
               it->second.c_str(), graph->GetName().c_str());
        return GE_GRAPH_GRAPH_NODE_NULL;
      }
      if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(kDataOutIndex), node->GetInDataAnchor(i)) != GRAPH_SUCCESS) {
        REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%zu) failed",
                          data_node->GetName().c_str(), data_node->GetType().c_str(), kDataOutIndex,
                          node->GetName().c_str(), node->GetType().c_str(), i);
        GELOGE(FAILED, "[Add][Edge] between op:%s(%s)(index:%d) and op:%s(%s)(index:%zu) failed",
               data_node->GetName().c_str(), data_node->GetType().c_str(), kDataOutIndex,
               node->GetName().c_str(), node->GetType().c_str(), i);
        return FAILED;
      }
    }
  }
  GE_CHK_STATUS_RET(LinkGetDynamicDimsToNetOutput(node), "[Add][Edge] between %s and netoutput:%s failed.",
                    shape_node_->GetName().c_str(), output->GetName().c_str());
  all_output_nodes_.clear();
  all_output_nodes_.emplace_back(node);
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Set max shape to Data node in root graph.
/// @param [in] const NodePtr &data: data in Root/Case graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::SetMaxShape(const NodePtr &data) {
  GELOGD("Start set max shape for %s.", data->GetName().c_str());
  if (!IsGetNextType(data)) {
    if (SetMaxShapeToData(data, kDataOutIndex) != SUCCESS) {
      GELOGE(PARAM_INVALID, "[Update][MaxShape] of %s failed.", data->GetName().c_str());
      return PARAM_INVALID;
    }
  } else {
    for (size_t out_anchor_index = 0; out_anchor_index < data_count_from_getnext_; ++out_anchor_index) {
      if (SetMaxShapeToData(data, out_anchor_index) != SUCCESS) {
        GELOGE(PARAM_INVALID, "[Update][MaxShape] of %s failed.", data->GetName().c_str());
        return PARAM_INVALID;
      }
    }
  }
  return SUCCESS;
}

Status MultiBatchClonePass::SetMaxShapeToData(const NodePtr &node, size_t out_anchor_index) {
  GELOGD("Start update max shape of %s, %zu output.", node->GetName().c_str(), out_anchor_index);
  auto data_shape = NodeUtils::GetOutputDesc(*node, out_anchor_index).GetShape();
  string data_name = node->GetName();
  if (IsGetNextType(node)) {
    data_name.append("_").append(std::to_string(out_anchor_index));
  }
  GELOGD("Update max shape of %s, shape dims is %s.", data_name.c_str(),
         formats::JoinToString(data_shape.GetDims()).c_str());
  const auto &dims = data_shape.GetDims();
  if (!IsGetNextType(node)) {
    if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) {
      GELOGD("No need to do anything for static data.");
      return SUCCESS;
    }
  } else {
    if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) {
      if (getnext_sink_dynamic_dims_) {
        // need to update shape of Shape_node when getnext node has dynamic data
        GE_CHK_STATUS_RET(UpdateShapeOfShapeNode(node, out_anchor_index),
                          "[Update][Shape] of shape node:%s failed, out_anchor_index:%zu",
                          node->GetName().c_str(), out_anchor_index);
      }
      return SUCCESS;
    }
  }
  (void)AttrUtils::SetListInt(node->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims());
  if (!AttrUtils::SetStr(node->GetOpDesc(), kMbatchCaseName, case_node_->GetName())) {
    REPORT_CALL_ERROR("E19999", "Set Attr:%s to node:%s(%s) failed",
                      kMbatchCaseName, node->GetName().c_str(), node->GetType().c_str());
    GELOGE(INTERNAL_ERROR, "[Set][Attr] %s to node:%s(%s) failed",
           kMbatchCaseName, node->GetName().c_str(), node->GetType().c_str());
    return INTERNAL_ERROR;
  }
  GeTensorDesc tensor(NodeUtils::GetOutputDesc(*node, kDataOutIndex));
  std::vector<std::string> input_dims_str;
  for (size_t i = 0; i < batch_shapes_.size(); ++i) {
    auto shape = data_shape;
    auto ret = multibatch::CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape);
    if (ret != SUCCESS) {
      GELOGE(ret, "[Calculate][Shape] for data node %s failed, the shape may not match", node->GetName().c_str());
      return ret;
    }
    tensor.SetShape(shape);
    int64_t tensor_size = 0;
    (void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size);
    string input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" +
                       TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + node->GetName() + ":" +
                       std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" +
                       formats::JoinToString(tensor.GetShape().GetDims());
    input_dims_str.emplace_back(input_str);
  }
  (void)AttrUtils::SetListStr(node->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str);
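  // Pick, among all configured batch gears, the shape with the largest element count as the "max" shape,
  // guarding every multiplication against int64 overflow before it is performed.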
  size_t max_shape_index = 0;
  int64_t max_size = 0;
  for (size_t i = 0; i < batch_shapes_.size(); ++i) {
    int64_t size = 1;
    for (auto dim : data_to_dynamic_info_.at(data_name).at(i)) {
      if (INT64_MAX / dim < size) {
        REPORT_INNER_ERROR("E19999", "The shape %s size will overflow after multiplication",
                           formats::ShapeToString(data_to_dynamic_info_.at(data_name).at(i)).c_str());
        GELOGE(PARAM_INVALID, "[Check][Param] The shape %s size overflow",
               formats::ShapeToString(data_to_dynamic_info_.at(data_name).at(i)).c_str());
        return PARAM_INVALID;
      }
      size *= dim;
    }
    if (size > max_size) {
      max_size = size;
      max_shape_index = i;
    }
  }
  return SetShapeToData(data_to_dynamic_info_.at(data_name).at(max_shape_index), node, data_shape, out_anchor_index);
}

///
/// @ingroup ge
/// @brief Set shape to Data/GetNext node in root graph.
/// @param [in] const std::vector<int64_t> &shapes: dims of shape.
/// @param [in] const NodePtr &data: data in Root/Case graph.
/// @param [in] GeShape &data_shape: dims of data node.
/// @param [in] size_t out_anchor_index: out anchor index of data node.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::SetShapeToData(const std::vector<int64_t> &shapes, const NodePtr &data, GeShape &data_shape,
                                           size_t out_anchor_index) {
  GELOGD("Start set shape to %zu out of %s.", out_anchor_index, data->GetName().c_str());
  if (multibatch::CalcShape(shapes, data_shape) != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "[Calculate][Shape] for data node %s failed, the shapes may not match",
           data->GetName().c_str());
    return INTERNAL_ERROR;
  }
  if (NodeUtils::UpdateOutputShape(*data, out_anchor_index, data_shape) != GRAPH_SUCCESS) {
    REPORT_CALL_ERROR("E19999", "Update output desc shape to op:%s(%s) failed, index:%zu",
                      data->GetName().c_str(), data->GetType().c_str(), out_anchor_index);
    GELOGE(INTERNAL_ERROR, "[Update][OutputShape] to op:%s(%s) failed, index:%zu",
           data->GetName().c_str(), data->GetType().c_str(), out_anchor_index);
    return INTERNAL_ERROR;
  }
  if (!IsGetNextType(data)) {
    if (NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape) != GRAPH_SUCCESS) {
      REPORT_CALL_ERROR("E19999", "Update input desc shape to op:%s(%s) failed, index:%u",
                        data->GetName().c_str(), data->GetType().c_str(), kDataInIndex);
      GELOGE(INTERNAL_ERROR, "[Update][InputShape] to op:%s(%s) failed, index:%u",
             data->GetName().c_str(), data->GetType().c_str(), kDataInIndex);
      return INTERNAL_ERROR;
    }
  } else {
    if (getnext_sink_dynamic_dims_) {
      // need to update shape of Shape_node when getnext_sink_dynamic
      GE_CHK_STATUS_RET(UpdateShapeOfShapeNode(data, out_anchor_index),
                        "[Update][ShapeOfShapeNode] for %s(%s) failed, index:%zu",
                        data->GetName().c_str(), data->GetType().c_str(), out_anchor_index);
    }
  }
  GELOGI("Update the data %s input/output shape to the max %s", data->GetName().c_str(),
         formats::ShapeToString(data_shape).c_str());
  return SUCCESS;
}

Status MultiBatchClonePass::UpdateShapeOfShapeNode(const NodePtr &node, size_t out_anchor_index) {
  GELOGD("Start update output shape of shape node inserted by adapter, which is the %zu out of %s.", out_anchor_index,
         node->GetName().c_str());
  auto data_shape = NodeUtils::GetOutputDesc(*node, out_anchor_index).GetShape();
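  // Assumption: in getnext-sink dynamic-dims mode the out anchors are laid out as N data outputs followed by
  // N matching shape outputs, so the shape output paired with this data output is found at
  // out_anchor_index + (total output count / kDivisionConst), with kDivisionConst presumably equal to 2.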
  size_t shape_index = out_anchor_index + (node->GetAllOutDataAnchors().size() / kDivisionConst);
  GeTensorDesc output_desc = node->GetOpDesc()->GetOutputDesc(shape_index);
  std::vector<int64_t> output_dims = {static_cast<int64_t>(data_shape.GetDims().size())};
  GeShape output_shape(output_dims);
  output_desc.SetShape(output_shape);
  if (node->GetOpDesc()->UpdateOutputDesc(shape_index, output_desc) != SUCCESS) {
    REPORT_CALL_ERROR("E19999", "Update output desc to op:%s(%s) failed, index:%zu",
                      node->GetName().c_str(), node->GetType().c_str(), shape_index);
    GELOGE(FAILED, "[Update][OutputDesc] to op:%s(%s) failed, index:%zu",
           node->GetName().c_str(), node->GetType().c_str(), shape_index);
    return FAILED;
  }
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Update Data node in Subgraph.
/// @param [in] const NodePtr &data: data in Subgraph.
/// @param [in] size_t batch_index: The batch index.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::UpdateSubgraphData(const NodePtr &data, size_t batch_index) {
  int node_index = -1;
  if (!AttrUtils::GetInt(data->GetOpDesc(), ATTR_NAME_INDEX, node_index)) {
    REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_NAME_INDEX.c_str(),
                      data->GetName().c_str(), data->GetType().c_str());
    GELOGE(FAILED, "[Get][Attr] %s from op:%s(%s) failed", ATTR_NAME_INDEX.c_str(),
           data->GetName().c_str(), data->GetType().c_str());
    return FAILED;
  }
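  // Map the subgraph data index to the parent Case node input with an offset of one; input 0 of the Case
  // node is presumably reserved for the batch-selection (pred) input.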
  int parent_index = node_index + 1;
  if (!AttrUtils::SetInt(data->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) {
    REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_PARENT_NODE_INDEX.c_str(),
                      data->GetName().c_str(), data->GetType().c_str());
    GELOGE(FAILED, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NAME_PARENT_NODE_INDEX.c_str(),
           data->GetName().c_str(), data->GetType().c_str());
    return FAILED;
  }
  auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
  const auto &dims = data_shape.GetDims();
  GELOGD("Start update shape of %s, batch index is %zu, dims is %s.", data->GetName().c_str(), batch_index,
         formats::JoinToString(dims).c_str());
  if (std::all_of(dims.begin(), dims.end(), [](int64_t val) { return val >= 0; })) {
    return SUCCESS;
  }
  (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims());
  auto data_name = data->GetName();
  size_t pos = data_name.find(kMultiBatchNodePostfix);
  if (pos == string::npos) {
    REPORT_INNER_ERROR("E19999", "Cannot find key string [%s] of multi-batch in name of virtual input node:%s(%s)",
                       kMultiBatchNodePostfix.c_str(), data->GetName().c_str(), data->GetType().c_str());
    GELOGE(FAILED, "[Check][Param] Cannot find key string [%s] of multi-batch in name of virtual input node, "
           "node name: %s.", kMultiBatchNodePostfix.c_str(), data_name.c_str());
    return FAILED;
  }
  auto parent_name = data_name.substr(0, pos);
  return SetShapeToData(data_to_dynamic_info_.at(parent_name).at(batch_index), data, data_shape, kDataOutIndex);
}

Status MultiBatchClonePass::CreateOriGraph(const ComputeGraphPtr &graph) {
  if (data_count_from_getnext_ == 0) {
    GELOGD("No need to change original graph without getnext node.");
    return SUCCESS;
  }
  GELOGD("Start change original graph: %s when getnext node exists.", graph->GetName().c_str());
  size_t data_index = all_data_nodes_.size() - kNumOfGetnextNode;
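  // Each data output of the getnext node is replaced below by a standalone Data node; the new Data nodes
  // continue the index sequence after the existing data nodes (kNumOfGetnextNode presumably accounts for
  // the getnext node itself in all_data_nodes_).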
  for (const auto &node : graph->GetDirectNode()) {
    if (IsGetNextType(node)) {
      for (size_t out_index = 0; out_index < data_count_from_getnext_; ++out_index, ++data_index) {
        auto out_data_anchor = node->GetOutDataAnchor(out_index);
        GE_IF_BOOL_EXEC(out_data_anchor == nullptr, continue);
        NodePtr data_node = CreateDataNode(graph, out_data_anchor, data_index);
        GE_IF_BOOL_EXEC(data_node == nullptr,
                        REPORT_CALL_ERROR("E19999", "Create data node in graph:%s failed", graph->GetName().c_str());
                        GELOGE(INTERNAL_ERROR, "[Create][DataNode] in graph:%s failed", graph->GetName().c_str());
                        return INTERNAL_ERROR);
        for (auto &in_anchor : out_data_anchor->GetPeerInDataAnchors()) {
          GE_IF_BOOL_EXEC(in_anchor == nullptr, continue);
          NodePtr dst_node = in_anchor->GetOwnerNode();
          if (GraphUtils::RemoveEdge(out_data_anchor, in_anchor) != GRAPH_SUCCESS) {
            REPORT_CALL_ERROR("E19999", "Remove edge between op:%s(%s)(index:%zu) and op:%s(%s)(index:%d) failed",
                              node->GetName().c_str(), node->GetType().c_str(), out_index,
                              dst_node->GetName().c_str(), dst_node->GetType().c_str(), in_anchor->GetIdx());
            GELOGE(INTERNAL_ERROR, "[Remove][Edge] between op:%s(%s)(index:%zu) and op:%s(%s)(index:%d) failed",
                   node->GetName().c_str(), node->GetType().c_str(), out_index,
                   dst_node->GetName().c_str(), dst_node->GetType().c_str(), in_anchor->GetIdx());
            return INTERNAL_ERROR;
          }
          if (GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), dst_node->GetInDataAnchor(in_anchor->GetIdx())) !=
              GRAPH_SUCCESS) {
            REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed",
                              data_node->GetName().c_str(), data_node->GetType().c_str(),
                              dst_node->GetName().c_str(), dst_node->GetType().c_str(), in_anchor->GetIdx());
            GELOGE(INTERNAL_ERROR, "[Add][Edge] between op:%s(%s)(index:0) and op:%s(%s)(index:%d) failed",
                   data_node->GetName().c_str(), data_node->GetType().c_str(),
                   dst_node->GetName().c_str(), dst_node->GetType().c_str(), in_anchor->GetIdx());
            return INTERNAL_ERROR;
          }
        }
      }
      if (graph->RemoveNode(node) != GRAPH_SUCCESS) {
        REPORT_CALL_ERROR("E19999", "Remove node:%s(%s) from graph:%s failed",
                          node->GetName().c_str(), node->GetType().c_str(), graph->GetName().c_str());
        GELOGE(GRAPH_FAILED, "[Remove][Node] %s(%s) from graph:%s failed",
               node->GetName().c_str(), node->GetType().c_str(), graph->GetName().c_str());
        return GRAPH_FAILED;
      }
      break;
    }
  }
  return SUCCESS;
}

NodePtr MultiBatchClonePass::CreateDataNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor,
                                            size_t data_index) {
  size_t out_anchor_index = out_data_anchor->GetIdx();
  std::string node_name = out_data_anchor->GetOwnerNode()->GetName() + "_" + std::to_string(out_anchor_index);
  OpDescPtr op_desc = MakeShared<OpDesc>(node_name, DATA);
  if (op_desc == nullptr) {
    REPORT_CALL_ERROR("E19999", "New OpDesc failed");
    GELOGE(OUT_OF_MEMORY, "[New][OpDesc] failed.");
    return nullptr;
  }
  (void)AttrUtils::SetInt(op_desc, ATTR_NAME_INDEX, data_index);
  OpDescPtr getnext_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc();
  if (getnext_op_desc == nullptr) {
    REPORT_INNER_ERROR("E19999", "Param out_data_anchor's owner node has no op_desc, check invalid");
    GELOGE(OUT_OF_MEMORY, "[Get][OpDesc] failed, Param out_data_anchor's owner node has no op_desc.");
    return nullptr;
  }
  if (op_desc->AddInputDesc(getnext_op_desc->GetOutputDesc(out_anchor_index)) != GRAPH_SUCCESS) {
    REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
                      op_desc->GetName().c_str(), op_desc->GetType().c_str());
    GELOGE(INTERNAL_ERROR, "[Add][InputDesc] to op:%s(%s) failed",
           op_desc->GetName().c_str(), op_desc->GetType().c_str());
    return nullptr;
  }
  if (op_desc->AddOutputDesc(getnext_op_desc->GetOutputDesc(out_anchor_index)) != GRAPH_SUCCESS) {
    REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed",
                      op_desc->GetName().c_str(), op_desc->GetType().c_str());
    GELOGE(INTERNAL_ERROR, "[Add][OutputDesc] to op:%s(%s) failed",
           op_desc->GetName().c_str(), op_desc->GetType().c_str());
    return nullptr;
  }
  NodePtr data_node = graph->AddNode(op_desc);
  // Guard against AddNode failure before dereferencing the returned node.
  if (data_node == nullptr) {
    REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
                      op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
    GELOGE(INTERNAL_ERROR, "[Add][Node] %s(%s) to graph:%s failed",
           op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
    return nullptr;
  }
  GELOGD("Success to create data node %s.", data_node->GetName().c_str());
  return data_node;
}

///
/// @ingroup ge
/// @brief Create subgraphs for root graph.
/// @param [in] const ComputeGraphPtr &graph: Root/Case graph.
/// @param [in] const ComputeGraphPtr &branch: original graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const ComputeGraphPtr &branch) {
  GELOGD("Start create subgraphs for %s.", graph->GetName().c_str());
  const auto &op_desc = case_node_->GetOpDesc();
  for (size_t i = 0; i < batch_shapes_.size(); ++i) {
    std::vector<NodePtr> input_nodes;
    std::vector<NodePtr> output_nodes;
    const std::string postfix = kMultiBatchNodePostfix + std::to_string(i);
    ComputeGraphPtr subgraph = (i == 0) ? branch : GraphUtils::CloneGraph(branch, postfix, input_nodes, output_nodes);
    GE_IF_BOOL_EXEC(subgraph == nullptr,
                    REPORT_CALL_ERROR("E19999", "Clone graph from graph:%s failed", branch->GetName().c_str());
                    GELOGE(FAILED, "[Clone][Graph] from graph:%s failed", branch->GetName().c_str());
                    return FAILED);
    subgraph->SetName("Batch_" + std::to_string(i));
    subgraph->SetParentNode(case_node_);
    subgraph->SetParentGraph(graph);
    graph->AddSubgraph(subgraph->GetName(), subgraph);
    all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT);
    const string key_name = "branches" + std::to_string(i);
    op_desc->AddSubgraphName(key_name);
    op_desc->SetSubgraphInstanceName(i, subgraph->GetName());
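    // Gear i is registered on the Case op as subgraph name "branches<i>" with instance graph "Batch_<i>";
    // gear 0 reuses the original branch graph directly instead of a clone.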
  1200. GELOGD("The %s has %zu input, %zu output.", subgraph->GetName().c_str(), input_nodes.size(), output_nodes.size());
  1201. for (const auto &data : input_nodes) {
  1202. GE_CHK_STATUS_RET(UpdateSubgraphData(data, i),
  1203. "[Update][SubgraphData] in subgraph:%s failed, node:%s, index:%zu",
  1204. subgraph->GetName().c_str(), data->GetName().c_str(), i);
  1205. }
  1206. }
  1207. // Origninal graph take as first subgraph, update node name.
  1208. for (const auto &n : branch->GetDirectNode()) {
  1209. const auto &op_desc = n->GetOpDesc();
  1210. op_desc->SetName(n->GetName() + kMultiBatchNodePostfix + "0");
  1211. if (n->GetType() == DATA) {
  1212. GE_CHK_STATUS_RET(UpdateSubgraphData(n, 0),
  1213. "[Update][SubgraphData] in graph:%s failed, node:%s, index:0",
  1214. branch->GetName().c_str(), n->GetName().c_str());
  1215. }
  1216. }
  1217. return SUCCESS;
  1218. }

///
/// @ingroup ge
/// @brief Update output_node in Subgraph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::UpdateSubgraphOutput() {
  for (const auto &item : all_branch_output_) {
    const auto &output_node = item.second;
    const auto &op_desc = output_node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);
    for (size_t index = 0; index < op_desc->GetInputsSize(); ++index) {
      GeTensorDescPtr tensor = op_desc->MutableInputDesc(index);
      GE_CHECK_NOTNULL(tensor);
      if (!AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) {
        REPORT_CALL_ERROR("E19999", "Set Attr:%s to input:%zu tensor of op:%s(%s) failed",
                          ATTR_NAME_PARENT_NODE_INDEX.c_str(), index,
                          op_desc->GetName().c_str(), op_desc->GetType().c_str());
        GELOGE(FAILED, "[Set][Attr] %s to input:%zu tensor of op:%s(%s) failed",
               ATTR_NAME_PARENT_NODE_INDEX.c_str(), index,
               op_desc->GetName().c_str(), op_desc->GetType().c_str());
        return FAILED;
      }
    }
  }
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Remove suspended (unused) output anchors of the subgraphs.
/// @param [in] ComputeGraphPtr &graph: Parent compute graph.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::PruneDirectOutput(const ComputeGraphPtr &graph) {
  GELOGD("Start prune direct output.");
  const auto &func_desc = case_node_->GetOpDesc();
  uint32_t unused_num = 0;
  uint32_t output_num = func_desc->GetOutputsSize();
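  // An output of the Case node counts as unused only if the corresponding NetOutput input is unconnected in
  // every branch; used outputs that follow unused ones are compacted forward via UpdateOutputTensor.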
  for (size_t i = 0; i < output_num; ++i) {
    bool is_unused_tensor = true;
    for (const auto &item : all_branch_output_) {
      const auto &netoutput = item.second;
      GE_CHECK_NOTNULL(netoutput);
      const auto in_anchor = netoutput->GetInDataAnchor(i);
      if (in_anchor->GetPeerOutAnchor() != nullptr) {
        is_unused_tensor = false;
        break;
      }
    }
    if (is_unused_tensor) {
      unused_num++;
      continue;
    }
    GE_CHK_STATUS_RET(UpdateOutputTensor(i, unused_num),
                      "[Update][OutputTensor] in graph:%s failed, parent_index:%zu, unused_num:%u",
                      graph->GetName().c_str(), i, unused_num);
  }
  if (unused_num == 0) {
    return SUCCESS;
  }
  GE_CHK_GRAPH_STATUS_RET(NodeUtils::RemoveOutputAnchor(case_node_, output_num - unused_num),
                          "[Remove][OutputAnchor] for node:%s failed", case_node_->GetName().c_str());
  for (const auto &item : all_branch_output_) {
    GE_CHK_GRAPH_STATUS_RET(NodeUtils::RemoveInputAnchor(item.second, output_num - unused_num),
                            "[Remove][InputAnchor] for node:%s failed", item.second->GetName().c_str());
  }
  return SUCCESS;
}

///
/// @ingroup ge
/// @brief Update subgraph suspended output tensor.
/// @param [in] parent_index: parent index for check.
/// @param [in] unused_num: number of unused tensors counted so far.
/// @return 0: SUCCESS / others: FAILED
///
Status MultiBatchClonePass::UpdateOutputTensor(uint32_t parent_index, uint32_t unused_num) {
  if (unused_num == 0) {
    GELOGD("No need to update output tensor.");
    return SUCCESS;
  }
  uint32_t update_index = parent_index - unused_num;
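  // Compact the used output forward: slot parent_index moves to parent_index - unused_num, on both the
  // branch NetOutput inputs and the Case node outputs, before the trailing anchors are removed by the caller.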
  for (const auto &item : all_branch_output_) {
    const auto &node = item.second;
    const auto &new_anchor = node->GetInDataAnchor(update_index);
    const auto &old_anchor = node->GetInDataAnchor(parent_index);
    const auto &out_anchor = old_anchor->GetPeerOutAnchor();
    const auto &out_node = out_anchor->GetOwnerNode();
    const auto &op_desc = node->GetOpDesc();
    (void)op_desc->UpdateInputDesc(update_index, op_desc->GetInputDesc(parent_index));
    GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(out_anchor, new_anchor),
                            "[Add][Edge] between %s(index:%d) and %s(index:%u) failed",
                            out_node->GetName().c_str(), out_anchor->GetIdx(),
                            new_anchor->GetOwnerNode()->GetName().c_str(), update_index);
    GELOGI("Add edge success, func node: %s, node: %s, parent index: %u, update index: %u",
           case_node_->GetName().c_str(), out_node->GetName().c_str(), parent_index, update_index);
    GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(out_anchor, old_anchor),
                            "[Remove][Edge] between %s(index:%d) and %s(index:%u) failed",
                            out_node->GetName().c_str(), out_anchor->GetIdx(),
                            old_anchor->GetOwnerNode()->GetName().c_str(), parent_index);
    GELOGI("Remove edge success, func node: %s, node: %s", case_node_->GetName().c_str(), out_node->GetName().c_str());
  }
  const auto &new_anchor = case_node_->GetOutDataAnchor(update_index);
  const auto &old_anchor = case_node_->GetOutDataAnchor(parent_index);
  for (const auto in_anchor : old_anchor->GetPeerInDataAnchors()) {
    const auto &in_node = in_anchor->GetOwnerNode();
    GE_CHK_GRAPH_STATUS_RET(GraphUtils::RemoveEdge(old_anchor, in_anchor),
                            "[Remove][Edge] between %s(index:%u) and %s(index:%d) failed",
                            case_node_->GetName().c_str(), parent_index,
                            in_node->GetName().c_str(), in_anchor->GetIdx());
    GELOGI("Remove edge success, func node: %s, node: %s", case_node_->GetName().c_str(), in_node->GetName().c_str());
    GE_CHK_GRAPH_STATUS_RET(GraphUtils::AddEdge(new_anchor, in_anchor),
                            "[Add][Edge] between %s(index:%u) and %s(index:%d) failed",
                            case_node_->GetName().c_str(), update_index,
                            in_node->GetName().c_str(), in_anchor->GetIdx());
    GELOGI("Add edge success, func node: %s, node: %s, parent index: %u, update index: %u",
           case_node_->GetName().c_str(), in_node->GetName().c_str(), parent_index, update_index);
  }
  return SUCCESS;
}
} // namespace ge

The Graph Engine (GE) module is a submodule of MindSpore. Implemented in C++, it sits between the front-end module (ME) and the underlying hardware and serves as the bridge between them. GE takes the graph delivered by ME as input, performs a series of deep graph optimizations, and finally outputs a graph that can run efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor to fully exploit its computing power. During model training and inference, GE is invoked automatically and is transparent to the user. GE mainly consists of two parts, GE API and GE Core; the detailed architecture is shown in the diagram below.