You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

multi_batch_copy_graph.cc 51 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/preprocess/multi_batch_copy_graph.h"
  17. #include <queue>
  18. #include <set>
  19. #include <string>
  20. #include "common/formats/utils/formats_trans_utils.h"
  21. #include "common/ge/ge_util.h"
  22. #include "common/util/error_manager/error_manager.h"
  23. #include "framework/common/debug/ge_log.h"
  24. #include "framework/common/ge_inner_error_codes.h"
  25. #include "framework/common/string_util.h"
  26. #include "framework/common/types.h"
  27. #include "framework/omg/omg_inner_types.h"
  28. #include "graph/debug/ge_attr_define.h"
  29. #include "graph/ge_context.h"
  30. #include "graph/passes/multi_batch_clone_pass.h"
  31. #include "graph/passes/prune_pass.h"
  32. #include "graph/preprocess/multi_batch_options.h"
  33. #include "graph/utils/attr_utils.h"
  34. #include "graph/utils/graph_utils.h"
  35. #include "graph/utils/node_utils.h"
  36. #include "graph/utils/tensor_utils.h"
  37. #include "graph/utils/type_utils.h"
  38. #include "inc/pass_manager.h"
  39. #include "graph/common/local_context.h"
  40. using std::set;
  41. using std::string;
  42. using std::vector;
  43. namespace ge {
  44. namespace multibatch {
  45. namespace {
// Attribute key recording which SwitchN node a Data node was routed through.
const char *const kMbatchSwitchnName = "mbatch-switch-name";
// SwitchN anchor indices: input 0 carries the tensor, input 1 the predicate.
const int kSwitchNDataIndex = 0;
const int kSwitchNPredIndex = 1;
// Data nodes expose their tensor on anchor 0 (both input and output side).
const int kDataOutIndex = 0;
const int kDataInIndex = 0;
// Merge nodes emit the selected tensor on output anchor 0.
const int kMergeDataOutIndex = 0;
// NOTE(review): unused in this chunk; presumably marks a static (non-dynamic)
// output — confirm at the use sites further down the file.
const int kStaticOutput = -1;

// Returns true for node types that act as graph inputs (DATA or AIPP).
inline bool IsDataLikeType(const std::string &node_type) { return (node_type == DATA) || (node_type == AIPP); }
// Creates a MERGE node named `name` with `input_num` data inputs ("x0".."xN")
// and adds it to `graph`. The node gets outputs "y" (the merged tensor) and
// "value_index" (int32 index of the branch that produced it), and is tagged
// ATTR_INSERT_BY_MBATCH so later passes recognize it as multi-batch machinery.
// Returns the created node, or nullptr on any failure.
NodePtr InsertMergeNodeToGraph(const std::string &name, size_t input_num, const ComputeGraphPtr &graph) {
  OpDescPtr desc = MakeShared<OpDesc>();
  if (desc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Failed to insert merge node, name %s", name.c_str());
    return nullptr;
  }
  desc->SetName(name);
  desc->SetType(MERGE);
  GeTensorDesc tensor_desc;
  // One input per batch branch; all share a default (unspecified) tensor desc.
  for (size_t i = 0; i < input_num; ++i) {
    auto ret = desc->AddInputDesc("x" + std::to_string(i), tensor_desc);
    GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS,
                    GELOGE(INTERNAL_ERROR, "Failed to create merge node %s, failed to add input %zu, error-code %u",
                           name.c_str(), i, ret);
                    return nullptr);
  }
  auto ret = desc->AddOutputDesc("y", tensor_desc);
  GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS,
                  GELOGE(INTERNAL_ERROR, "Failed to create merge node %s, failed to add output 'y', error-code %u",
                         name.c_str(), ret);
                  return nullptr);
  // "value_index" reports which input was taken, hence int32.
  tensor_desc.SetDataType(DT_INT32);
  ret = desc->AddOutputDesc("value_index", tensor_desc);
  if (ret != GRAPH_SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Failed to create merge node %s, failed to add output 'value_index', error-code %u",
           name.c_str(), ret);
    return nullptr;
  }
  if (!AttrUtils::SetBool(desc, ATTR_INSERT_BY_MBATCH, true)) {
    GELOGE(INTERNAL_ERROR, "Failed to create merge node %s, failed to add attr", name.c_str());
    return nullptr;
  }
  return graph->AddNode(desc);
}
// Clones `node` for batch branch `n`: copies the OpDesc and its attributes,
// copies per-anchor tensor attributes explicitly, renames the copy with the
// "_ascend_mbatch_batch_<n>" suffix, labels it ATTR_NAME_BATCH_LABEL =
// "Batch_<n>", and records the origin op name for dump tooling.
// Returns the new node added to the owner graph, or nullptr on failure.
NodePtr InsertCopyNode(const NodePtr &node, size_t n) {
  const std::string &name = node->GetName() + "_ascend_mbatch_batch_" + std::to_string(n);
  auto src_op_desc = node->GetOpDesc();
  GE_IF_BOOL_EXEC(src_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "Failed to copy node %s to %s, the OpDesc is null",
                                                 node->GetName().c_str(), name.c_str());
                  return nullptr);
  auto desc = AttrUtils::CopyOpDesc(src_op_desc);
  GE_IF_BOOL_EXEC(desc == nullptr, GELOGE(OUT_OF_MEMORY, "Failed to create op desc for copy node for node %s name %s",
                                          node->GetName().c_str(), name.c_str());
                  return nullptr);
  desc->SetName(name);
  desc->CopyAttrsFrom(*src_op_desc);
  // Copy per-tensor attrs anchor by anchor. A missing input desc is tolerated
  // with a warning; a missing output desc aborts the copy.
  for (uint32_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) {
    auto input_desc = desc->MutableInputDesc(i);
    GE_IF_BOOL_EXEC(input_desc == nullptr,
                    GELOGW("Get null input desc by index %u from node %s when copy from %s", i,
                           desc->GetName().c_str(), node->GetName().c_str());
                    continue);
    input_desc->CopyAttrsFrom(src_op_desc->GetInputDesc(i));
  }
  for (uint32_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) {
    auto output_desc = desc->MutableOutputDesc(i);
    GE_IF_BOOL_EXEC(output_desc == nullptr,
                    GELOGE(INTERNAL_ERROR, "Failed to get output desc by index %u from node %s when copy from %s", i,
                           desc->GetName().c_str(), node->GetName().c_str());
                    return nullptr);
    output_desc->CopyAttrsFrom(src_op_desc->GetOutputDesc(i));
  }
  const std::string &batch_label = "Batch_" + std::to_string(n);
  if (!AttrUtils::SetStr(desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
    GELOGE(FAILED, "set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", name.c_str());
    return nullptr;
  }
  // Keep the original op name so data-dump tooling can map the copy back.
  (void)AttrUtils::SetListStr(desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, {node->GetName()});
  auto graph = node->GetOwnerComputeGraph();
  return graph->AddNode(desc);
}
  125. bool IsAllDimsPositive(const std::vector<int64_t> &dims) {
  126. for (auto dim : dims) {
  127. if (dim < 0) {
  128. return false;
  129. }
  130. }
  131. return true;
  132. }
// Creates a CONSTANT node named `name` holding a single zero byte and adds it
// to `graph`. Used by LinkNodeToMerge to give a node without data outputs
// something to feed a Merge input. Tagged ATTR_INSERT_BY_MBATCH so later
// passes recognize it. Returns nullptr on failure.
NodePtr InsertConst(const std::string &name, const ComputeGraphPtr &graph) {
  auto desc = MakeShared<OpDesc>();
  if (desc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Failed to create const op %s, out of memory", name.c_str());
    return nullptr;
  }
  desc->SetName(name);
  desc->SetType(CONSTANT);
  GeTensor tensor;
  // Minimal payload: one zero byte — the value is never consumed, only the edge.
  tensor.SetData(std::vector<uint8_t>({0}));
  if (!AttrUtils::SetTensor(desc, ATTR_NAME_WEIGHTS, tensor)) {
    GELOGE(OUT_OF_MEMORY, "Failed to init tensor value for const %s", name.c_str());
    return nullptr;
  }
  if (!AttrUtils::SetBool(desc, ATTR_INSERT_BY_MBATCH, true)) {
    GELOGE(OUT_OF_MEMORY, "Failed to set insert flag for const node %s", name.c_str());
    return nullptr;
  }
  if (desc->AddOutputDesc(GeTensorDesc()) != GRAPH_SUCCESS) {
    GELOGE(OUT_OF_MEMORY, "Failed to add output desc for const node %s", name.c_str());
    return nullptr;
  }
  return graph->AddNode(desc);
}
  157. bool IsOnlyOutputToAipp(const NodePtr &node) {
  158. for (const auto &out_node : node->GetOutDataNodes()) {
  159. if (out_node->GetType() != AIPP) {
  160. return false;
  161. }
  162. }
  163. return true;
  164. }
  165. Status CheckDataShape(const std::vector<NodePtr> &nodes) {
  166. size_t unknown_shape_count = 0;
  167. for (const auto &node : nodes) {
  168. if (node->GetType() != DATA) {
  169. continue;
  170. }
  171. for (auto dim : NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims()) {
  172. if (dim < 0) {
  173. unknown_shape_count++;
  174. break;
  175. }
  176. }
  177. }
  178. if (unknown_shape_count == 0) {
  179. ErrorManager::GetInstance().ATCReportErrMessage("E10040");
  180. GELOGE(PARAM_INVALID,
  181. "Need unknow shape data when user set --dynamic_batch_size, --dynamic_image_size or --dynamic_dims");
  182. return PARAM_INVALID;
  183. }
  184. return SUCCESS;
  185. }
  186. } // namespace
// Top-level driver of the multi-batch copy. Stages, each aborting the whole
// copy on failure:
//   Init -> LabelStatus -> CheckAndParseDynamicData -> CreateNewNodes
//   -> LinkEdges -> InsertIdentityAfterSwitchN -> prune -> CheckCopyResult.
Status MultiBatchGraphCopyer::CopyGraph() {
  auto ret = Init();
  if (ret != SUCCESS) {
    return ret;
  }
  if (LabelStatus() != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Failed to label status for all nodes.");
    return INTERNAL_ERROR;
  }
  ret = CheckAndParseDynamicData();
  if (ret != SUCCESS) {
    return ret;
  }
  ret = CreateNewNodes();
  if (ret != SUCCESS) {
    return ret;
  }
  ret = LinkEdges();
  if (ret != SUCCESS) {
    return ret;
  }
  ret = InsertIdentityAfterSwitchN();
  if (ret != SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Failed to insert identity nodes after switchn node.");
    return INTERNAL_ERROR;
  }
  // Copying leaves dead originals behind; prune them before validating.
  GELOGI("Begin to remove useless nodes by prune pass after copy process");
  PrunePass prune_pass;
  ret = prune_pass.Run(graph_);
  if (ret != SUCCESS) {
    GELOGE(ret, "Failed to prune");
    return ret;
  }
  // Every original data node must now have a fully known shape.
  return CheckCopyResult(origin_data_nodes_);
}
// Validates arguments, then snapshots every node of the graph (and, in a
// separate list, the data-like nodes) so later stages can iterate the
// original topology even while new nodes are being added to the graph.
Status MultiBatchGraphCopyer::Init() {
  auto ret = CheckArguments();
  if (ret != SUCCESS) {
    return ret;
  }
  for (auto &node : graph_->GetAllNodes()) {
    origin_all_nodes_.emplace_back(node);
    if (IsDataLikeType(node->GetType())) {
      origin_data_nodes_.emplace_back(node);
    }
  }
  return SUCCESS;
}
// Assigns a status to every original node:
//   kNodeInBatchBranch  - reachable (through data/control inputs) from a
//                         dynamic-shaped data node; copied once per batch.
//   kNodeStartNode      - data-like node that feeds more than just AIPP.
//   kNodeNotSupportNode - owns subgraphs; cannot be copied.
//   kNodeOutBatchBranch - everything else (incl. NETOUTPUT and data nodes
//                         whose only consumers are AIPP).
Status MultiBatchGraphCopyer::LabelStatus() {
  // Seed: data nodes with any unknown dim start inside the batch branch.
  for (const auto &data : origin_data_nodes_) {
    auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
    if (!IsAllDimsPositive(data_shape.GetDims())) {
      origin_nodes_status_[data.get()] = kNodeInBatchBranch;
    }
  }
  bool changed = true;
  // If anyone of in node is kNodeInBatchBranch, it is also kNodeInBatchBranch
  // (fixed-point propagation downstream through all input edges).
  while (changed) {
    changed = false;
    for (const auto &node : origin_all_nodes_) {
      auto iter = origin_nodes_status_.find(node.get());
      if (iter != origin_nodes_status_.end()) {
        continue;
      }
      for (auto &in_node : node->GetInAllNodes()) {
        bool is_in_batch = origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end() &&
                           origin_nodes_status_[in_node.get()] == kNodeInBatchBranch;
        if (is_in_batch) {
          origin_nodes_status_[node.get()] = kNodeInBatchBranch;
          changed = true;
          break;
        }
      }
    }
  }
  // Final classification. Note: subgraph owners, NETOUTPUT and data-like
  // nodes override any label assigned by the propagation above.
  for (const auto &node : origin_all_nodes_) {
    if (!(node->GetOpDesc()->GetSubgraphInstanceNames().empty())) {
      origin_nodes_status_[node.get()] = kNodeNotSupportNode;
      continue;
    }
    if (node->GetType() == NETOUTPUT) {
      origin_nodes_status_[node.get()] = kNodeOutBatchBranch;
      continue;
    }
    if (IsDataLikeType(node->GetType())) {
      if (IsOnlyOutputToAipp(node)) {
        origin_nodes_status_[node.get()] = kNodeOutBatchBranch;
      } else {
        origin_nodes_status_[node.get()] = kNodeStartNode;
      }
      continue;
    }
    // Anything still unlabeled is outside the batch branch.
    if (origin_nodes_status_.find(node.get()) == origin_nodes_status_.end()) {
      origin_nodes_status_[node.get()] = kNodeOutBatchBranch;
    }
  }
  return SUCCESS;
}
  285. Status MultiBatchGraphCopyer::CheckAndParseDynamicData(){
  286. size_t unknown_shape_count = 0;
  287. auto data_name_and_shape = GetLocalOmgContext().user_input_dims;
  288. GELOGD("raw data_name_and_shape size: %zu", data_name_and_shape.size());
  289. for (const auto &node : origin_all_nodes_) {
  290. auto data_desc = NodeUtils::GetOutputDesc(*node, kDataOutIndex);
  291. auto data_shape = data_desc.GetShape();
  292. auto data_format = data_desc.GetFormat() == Format::FORMAT_NCHW ? "NCHW" :
  293. data_desc.GetFormat() == Format::FORMAT_NHWC ? "NHWC" : "Others";
  294. auto data_name = node->GetName();
  295. auto branch_status = GetNodeStatus(node);
  296. if (branch_status != kNodeStartNode) {
  297. continue;
  298. }
  299. if (IsAllDimsPositive(data_shape.GetDims())) {
  300. continue;
  301. }
  302. ++unknown_shape_count;
  303. auto iter = find(data_name_order_.begin(), data_name_order_.end(), data_name);
  304. if (iter == data_name_order_.end()) {
  305. if (dynamic_type_ == DynamicType::kDynamicBatch) {
  306. auto ret = CheckDynamicBatchShape(data_shape.GetDims(), data_name);
  307. if (!ret) {
  308. return PARAM_INVALID;
  309. }
  310. } else if (dynamic_type_ == DynamicType::kDynamicImageSize) {
  311. auto ret = CheckDynamicImageSizeShape(data_shape.GetDims(), data_name, data_format);
  312. if (!ret) {
  313. return PARAM_INVALID;
  314. }
  315. } else if (dynamic_type_ == DynamicType::kDynamicDims) {
  316. ErrorManager::GetInstance().ATCReportErrMessage("E10001",
  317. {"parameter", "reason"},
  318. {"--input_shape",
  319. "all dynamic data must be set in --input_shape"});
  320. GELOGE(INTERNAL_ERROR, "data: %s shape:%s must be set int --input_shape",
  321. node->GetName().c_str(), data_shape.ToString().c_str());
  322. return INTERNAL_ERROR;
  323. }
  324. data_name_and_shape.emplace_back(data_name, data_shape.GetDims());
  325. }
  326. }
  327. auto ret = ParserDataToDynmaicInfo(shapes_, data_name_and_shape, data_to_dynamic_info_);
  328. if (ret != SUCCESS){
  329. return ret;
  330. }
  331. if (unknown_shape_count == 0) {
  332. ErrorManager::GetInstance().ATCReportErrMessage("E10040");
  333. GELOGE(PARAM_INVALID,
  334. "Need unknow shape data when user set --dynamic_batch_size, --dynamic_image_size or --dynamic_dims");
  335. return PARAM_INVALID;
  336. }
  337. return SUCCESS;
  338. }
// Materializes the multi-batch structure: first inserts the shared shape-data
// node, then handles every original node according to its label —
// start nodes get a SwitchN (and their shape bumped to the max batch),
// in-branch nodes are copied once per batch, out-of-branch nodes get Merge
// nodes on edges crossing out of the branch, unsupported nodes are skipped.
Status MultiBatchGraphCopyer::CreateNewNodes() {
  shape_data_ = InsertShapeDataNode();
  if (shape_data_ == nullptr) {
    GELOGE(INTERNAL_ERROR, "Failed to create the shape data node for muti-batch");
    return INTERNAL_ERROR;
  }
  for (const auto &node : origin_all_nodes_) {
    auto node_type = node->GetType();
    // Default to failure so an unexpected label falls through as an error.
    Status ret = INTERNAL_ERROR;
    auto branch_status = GetNodeStatus(node);
    GELOGD("Process node %s, status %d", node->GetName().c_str(), static_cast<int>(branch_status));
    switch (branch_status) {
      case kNodeStartNode:
        GELOGD("Name: %s, type: %s, status: kNodeStartNode.", node->GetName().c_str(), node->GetType().c_str());
        ret = InsertSwitchNForData(node);
        if (ret == SUCCESS) {
          ret = UpdateMaxShapeToData(node);
        }
        break;
      case kNodeInBatchBranch:
        GELOGD("Name: %s, type: %s, status: kNodeInBatchBranch.", node->GetName().c_str(), node->GetType().c_str());
        ret = CopyNodeInBatchBranch(node);
        break;
      case kNodeOutBatchBranch:
        GELOGD("Name: %s, type: %s, status: kNodeOutBatchBranch.", node->GetName().c_str(), node->GetType().c_str());
        ret = InsertMergeForEdgeNode(node);
        break;
      case kNodeNotSupportNode:
        // Left untouched; ret must be reset to avoid reporting an error.
        GELOGD("Name: %s, type: %s, status: kNodeNotSupportNode.", node->GetName().c_str(), node->GetType().c_str());
        break;
      default:
        GELOGE(INTERNAL_ERROR, "Unexpected status %d on node %s", static_cast<int>(branch_status),
               node->GetName().c_str());
        break;
    }
    if (ret != SUCCESS) {
      GELOGE(ret, "Failed to deal with node %s in multi-batch process", node->GetName().c_str());
      return ret;
    }
  }
  return SUCCESS;
}
// Lazily creates (and caches in nodes_to_merge_nodes_) one Merge node per
// (node, output index) pair, with one Merge input per batch shape.
// A negative index denotes a control edge, which is mapped to index 0
// because Merge requires at least one data input.
// Returns the cached or newly created Merge, or nullptr on failure.
NodePtr MultiBatchGraphCopyer::InsertMergeNode(const NodePtr &node, int index) {
  if (index < 0) {
    // the merge node must has data inputs, if origin connection is a control
    // edge, we use data edge instead
    index = 0;
  }
  auto &merge_nodes = nodes_to_merge_nodes_[node.get()];
  if (merge_nodes.empty()) {
    // Size the cache by output anchors; at least one slot for pure-control nodes.
    auto count = node->GetAllOutDataAnchorsSize();
    if (count == 0) {
      count = 1;
    }
    merge_nodes.resize(count, nullptr);
  }
  if (merge_nodes.at(index) != nullptr) {
    return merge_nodes[index];
  }
  auto merge_node_name = node->GetName() + "_ascend_mbatch_merge_" + std::to_string(index);
  auto merge_node = InsertMergeNodeToGraph(merge_node_name, shapes_.size(), node->GetOwnerComputeGraph());
  GE_IF_BOOL_EXEC(merge_node == nullptr, GELOGE(INTERNAL_ERROR, "Failed to create merge node for node %s, out index %d",
                                                node->GetName().c_str(), index);
                  return nullptr);
  merge_nodes[index] = merge_node;
  GELOGI("Create merge node %s for node %s index %d", merge_node_name.c_str(), node->GetName().c_str(), index);
  return merge_node;
}
// Re-creates the input data edges of `origin_node` on its batch-`batch_num`
// copy. Per input, the source is resolved in priority order:
//   1. a data node that got a SwitchN -> connect the SwitchN's output for
//      this batch;
//   2. a node copied per batch      -> connect that batch's copy;
//   3. otherwise                    -> connect the original source directly.
Status MultiBatchGraphCopyer::CopyInDataEdges(const NodePtr &origin_node, int batch_num, const NodePtr &copyed_node) {
  for (auto &in_anchor : origin_node->GetAllInDataAnchors()) {
    auto origin_src_anchor = in_anchor->GetPeerOutAnchor();
    if (origin_src_anchor == nullptr) {
      GELOGD("The node %s does not have input on index %d", origin_node->GetName().c_str(), in_anchor->GetIdx());
      continue;
    }
    auto origin_src_node = origin_src_anchor->GetOwnerNode();
    auto dst_anchor = copyed_node->GetInDataAnchor(in_anchor->GetIdx());
    GE_CHECK_NOTNULL(dst_anchor);
    auto switchn_iter = data_nodes_to_switchn_.find(origin_src_node.get());
    if (switchn_iter != data_nodes_to_switchn_.end()) {
      auto ret = GraphUtils::AddEdge(switchn_iter->second->GetOutDataAnchor(batch_num), dst_anchor);
      if (ret != GRAPH_SUCCESS) {
        GELOGE(INTERNAL_ERROR, "Failed to add data edge between %s(%d) to %s(%d), error-code %u",
               switchn_iter->second->GetName().c_str(), batch_num, copyed_node->GetName().c_str(), in_anchor->GetIdx(),
               ret);
        return INTERNAL_ERROR;
      }
      GELOGD("Add data edge from %s(%d) to %s(%d)", switchn_iter->second->GetName().c_str(), batch_num,
             copyed_node->GetName().c_str(), in_anchor->GetIdx());
      continue;
    }
    auto batch_branch_iter = nodes_to_batch_nodes_.find(origin_src_node.get());
    if (batch_branch_iter != nodes_to_batch_nodes_.end()) {
      // Source lives inside the batch branch: use the copy for this batch.
      auto src_batch_node = batch_branch_iter->second.at(batch_num);
      auto ret = GraphUtils::AddEdge(src_batch_node->GetOutDataAnchor(origin_src_anchor->GetIdx()), dst_anchor);
      if (ret != GRAPH_SUCCESS) {
        GELOGE(INTERNAL_ERROR, "Failed to add data edge between %s(%d) to %s(%d), error-code %u",
               src_batch_node->GetName().c_str(), batch_num, copyed_node->GetName().c_str(), in_anchor->GetIdx(), ret);
        return INTERNAL_ERROR;
      }
      GELOGD("Add data edge from %s(%d) to %s(%d)", src_batch_node->GetName().c_str(), batch_num,
             copyed_node->GetName().c_str(), in_anchor->GetIdx());
      continue;
    }
    // Source is outside the batch branch: keep the original connection.
    auto ret = GraphUtils::AddEdge(origin_src_anchor, dst_anchor);
    if (ret != GRAPH_SUCCESS) {
      GELOGE(INTERNAL_ERROR, "Failed to add data edge between origin node %s(%d) to copyed %s(%d)",
             origin_src_node->GetName().c_str(), origin_src_anchor->GetIdx(), copyed_node->GetName().c_str(),
             dst_anchor->GetIdx());
      return INTERNAL_ERROR;
    }
    GELOGD("Add data edge between branch-out %s(%d) to branch-in %s(%d)", origin_src_node->GetName().c_str(),
           origin_src_anchor->GetIdx(), copyed_node->GetName().c_str(), dst_anchor->GetIdx());
  }
  return SUCCESS;
}
  455. Status MultiBatchGraphCopyer::CopyInControlEdges(const NodePtr &node, int batch_num, const NodePtr &copyed_node) {
  456. for (auto &origin_src_node : node->GetInControlNodes()) {
  457. auto switchn_iter = data_nodes_to_switchn_.find(origin_src_node.get());
  458. if (switchn_iter != data_nodes_to_switchn_.end()) {
  459. // reconnect data node
  460. auto ret = GraphUtils::AddEdge(switchn_iter->second->GetOutControlAnchor(), copyed_node->GetInControlAnchor());
  461. if (ret != GRAPH_SUCCESS) {
  462. GELOGE(INTERNAL_ERROR, "Failed to add control edge between %s to %s, error-code %u",
  463. switchn_iter->second->GetName().c_str(), copyed_node->GetName().c_str(), ret);
  464. return INTERNAL_ERROR;
  465. }
  466. GELOGD("Add control edge from %s to %s", switchn_iter->second->GetName().c_str(), copyed_node->GetName().c_str());
  467. continue;
  468. }
  469. auto batch_branch_iter = nodes_to_batch_nodes_.find(origin_src_node.get());
  470. if (batch_branch_iter != nodes_to_batch_nodes_.end()) {
  471. // reconnect node in batch branch
  472. auto src_batch_node = batch_branch_iter->second.at(batch_num);
  473. auto ret = GraphUtils::AddEdge(src_batch_node->GetOutControlAnchor(), copyed_node->GetInControlAnchor());
  474. if (ret != GRAPH_SUCCESS) {
  475. GELOGE(INTERNAL_ERROR, "Failed to add data edge between %s to %s, error-code %u",
  476. src_batch_node->GetName().c_str(), copyed_node->GetName().c_str(), ret);
  477. return INTERNAL_ERROR;
  478. }
  479. GELOGD("Add control edge from %s to %s", src_batch_node->GetName().c_str(), copyed_node->GetName().c_str());
  480. continue;
  481. }
  482. auto ret = GraphUtils::AddEdge(origin_src_node->GetOutControlAnchor(), copyed_node->GetInControlAnchor());
  483. if (ret != GRAPH_SUCCESS) {
  484. GELOGE(INTERNAL_ERROR, "Failed to add control edge from origin %s to copyed %s",
  485. origin_src_node->GetName().c_str(), copyed_node->GetName().c_str());
  486. return INTERNAL_ERROR;
  487. }
  488. GELOGD("Add control edge between branch-out %s to branch-in %s", origin_src_node->GetName().c_str(),
  489. copyed_node->GetName().c_str());
  490. }
  491. return SUCCESS;
  492. }
// Creates the DATA node ("ascend_mbatch_shape_data", prefixed with the graph
// name inside subgraphs) that feeds the runtime batch shape to the SwitchN
// nodes. Its tensor is a 1-D int64 vector sized by the rank of the first
// configured shape. The node is appended as a graph input.
// Returns nullptr on any failure.
NodePtr MultiBatchGraphCopyer::InsertShapeDataNode() {
  auto desc = MakeShared<OpDesc>();
  if (desc == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Failed to create shape data node, out of memory");
    return nullptr;
  }
  string node_name = "ascend_mbatch_shape_data";
  // Only flush subgraph name
  if (graph_->GetParentGraph() != nullptr) {
    node_name = graph_->GetName() + "_" + node_name;
  }
  desc->SetName(node_name);
  desc->SetType(DATA);
  GeTensorDesc tensor_desc;
  tensor_desc.SetFormat(FORMAT_ND);
  // One int64 entry per dimension of the dynamic shape.
  tensor_desc.SetShape(GeShape({static_cast<int64_t>(shapes_.at(0).size())}));
  tensor_desc.SetDataType(DT_INT64);
  auto ret = desc->AddInputDesc(tensor_desc);
  if (ret != GRAPH_SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Failed to add input desc for created data");
    return nullptr;
  }
  ret = desc->AddOutputDesc(tensor_desc);
  if (ret != GRAPH_SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Failed to add output desc for created data");
    return nullptr;
  }
  if (!AttrUtils::SetBool(desc, ATTR_INSERT_BY_MBATCH, true)) {
    GELOGE(INTERNAL_ERROR, "Failed to add attr for created data");
    return nullptr;
  }
  auto data_node = graph_->AddNode(desc);
  if (data_node == nullptr) {
    GELOGE(INTERNAL_ERROR, "Failed to add shape data node to graph");
    return nullptr;
  }
  // Register the node as a formal graph input, not just a free-standing node.
  ret = GraphUtils::AppendInputNode(graph_, data_node);
  if (ret != GRAPH_SUCCESS) {
    GELOGE(INTERNAL_ERROR, "Failed to append data node %s as input to graph", data_node->GetName().c_str());
    return nullptr;
  }
  return data_node;
}
  536. Status MultiBatchGraphCopyer::CheckArguments() {
  537. if (graph_ == nullptr) {
  538. GELOGE(PARAM_INVALID, "Failed to copy graph, the graph is null");
  539. return PARAM_INVALID;
  540. }
  541. return CheckDynamicParams(shapes_);
  542. }
// Post-copy validation: every start node (except those feeding only AIPP)
// must now have a fully known output shape; any remaining unknown dim means
// the copy failed to resolve it.
Status MultiBatchGraphCopyer::CheckCopyResult(const std::vector<NodePtr> &start_nodes) {
  for (auto &node : start_nodes) {
    // AIPP-only data nodes were deliberately left out of the batch branch.
    if (IsOnlyOutputToAipp(node)) {
      continue;
    }
    auto dims = NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims();
    if (!IsAllDimsPositive(dims)) {
      GELOGE(INTERNAL_ERROR, "Failed to copy multi batch graph, the node %s still has unknown shape %s",
             node->GetName().c_str(), formats::ShapeToString(dims).c_str());
      return INTERNAL_ERROR;
    }
  }
  return SUCCESS;
}
  557. bool MultiBatchGraphCopyer::IsInBatchBranch(const NodePtr &node) {
  558. return (nodes_to_batch_nodes_.count(node.get()) > 0) || (data_nodes_to_switchn_.count(node.get()) > 0);
  559. }
// Wires a data node's SwitchN into a Merge: SwitchN output i (batch i) feeds
// Merge input i, for every configured batch shape.
// Precondition (caller-guaranteed): `data` has an entry in
// data_nodes_to_switchn_.
Status MultiBatchGraphCopyer::LinkDataToMerge(const NodePtr &data, const NodePtr &merge) {
  // The caller should make sure that the there is a SwitchN node in the map
  auto &switchn = data_nodes_to_switchn_[data.get()];
  GELOGI("Link edge between data %s to merge %s throw switchn %s", data->GetName().c_str(), merge->GetName().c_str(),
         switchn->GetName().c_str());
  for (size_t i = 0; i < shapes_.size(); ++i) {
    auto ret = GraphUtils::AddEdge(switchn->GetOutDataAnchor(i), merge->GetInDataAnchor(i));
    GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS,
                    GELOGE(INTERNAL_ERROR, "Failed to add edge between switchn %s(%zu) to merge %s(%zu), error-code %u",
                           switchn->GetName().c_str(), i, merge->GetName().c_str(), i, ret);
                    return INTERNAL_ERROR);
  }
  return SUCCESS;
}
// Wires the per-batch copies of `node` into a Merge: copy i's `out_index`
// data anchor feeds Merge input i. A copy with no data outputs at all is
// bridged through an inserted const (control edge from the copy, data edge
// from the const). Requires exactly one copy per configured batch shape.
Status MultiBatchGraphCopyer::LinkNodeToMerge(const NodePtr &node, int out_index, const NodePtr &merge) {
  auto &copyed_nodes = nodes_to_batch_nodes_[node.get()];
  if (copyed_nodes.size() != shapes_.size()) {
    GELOGE(INTERNAL_ERROR,
           "Failed to create merge node for node %s, the copyed nodes for it count %zu different with shape %zu",
           node->GetName().c_str(), copyed_nodes.size(), shapes_.size());
    return INTERNAL_ERROR;
  }
  for (size_t i = 0; i < copyed_nodes.size(); ++i) {
    auto src_node = copyed_nodes[i];
    if (src_node->GetAllOutDataAnchorsSize() == 0) {
      // if the node does not has any data output, we should create an const for it, like this:
      //       c          d
      // node ---> const ---> merge
      auto const_name = src_node->GetName() + "_merge_const";
      GELOGI("The node %s on the batch branch edge does not have any data output, create a const %s for it",
             src_node->GetName().c_str(), const_name.c_str());
      auto const_node = InsertConst(const_name, graph_);
      GE_IF_BOOL_EXEC(const_node == nullptr,
                      GELOGE(OUT_OF_MEMORY, "Failed to create const for node %s to connect to a merge node",
                             src_node->GetName().c_str());
                      return OUT_OF_MEMORY);
      auto ret = GraphUtils::AddEdge(src_node->GetOutControlAnchor(), const_node->GetInControlAnchor());
      GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add control edge from %s to %s",
                                                   src_node->GetName().c_str(), const_node->GetName().c_str());
                      return INTERNAL_ERROR);
      // From here on, the const stands in for the copy as the data source.
      src_node = const_node;
    }
    auto ret = GraphUtils::AddEdge(src_node->GetOutDataAnchor(out_index), merge->GetInDataAnchor(i));
    if (ret != GRAPH_SUCCESS) {
      GELOGE(INTERNAL_ERROR,
             "Failed to add edge between copyed node %s(%d) to inserted merge node %s(%zu), error-code %u",
             copyed_nodes[i]->GetName().c_str(), out_index, merge->GetName().c_str(), i, ret);
      return INTERNAL_ERROR;
    }
  }
  return SUCCESS;
}
  612. Status MultiBatchGraphCopyer::UpdateMaxShapeToData(const NodePtr &data) {
  613. auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
  614. auto data_name = data->GetName();
  615. if (IsAllDimsPositive(data_shape.GetDims())) {
  616. return SUCCESS;
  617. }
  618. size_t max_shape_index = 0;
  619. int64_t max_size = 0;
  620. for (size_t i = 0; i < shapes_.size(); ++i) {
  621. int64_t size = 1;
  622. for (auto dim : data_to_dynamic_info_.at(data_name).at(i)) {
  623. if (INT64_MAX / dim < size) {
  624. GELOGE(PARAM_INVALID, "The shape %s size overflow",
  625. formats::ShapeToString(data_to_dynamic_info_[data_name].at(i)).c_str());
  626. return PARAM_INVALID;
  627. }
  628. size *= dim;
  629. }
  630. if (size > max_size) {
  631. max_size = size;
  632. max_shape_index = i;
  633. }
  634. }
  635. // must not be error, the calc result has been checked in function InsertSwitchNForData
  636. (void)CalcShape(data_to_dynamic_info_.at(data_name).at(max_shape_index), data_shape);
  637. auto ret = NodeUtils::UpdateOutputShape(*data, kDataOutIndex, data_shape);
  638. if (ret != GRAPH_SUCCESS) {
  639. GELOGE(INTERNAL_ERROR, "Failed to update output shape for data %s", data->GetName().c_str());
  640. return INTERNAL_ERROR;
  641. }
  642. ret = NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape);
  643. if (ret != GRAPH_SUCCESS) {
  644. GELOGE(INTERNAL_ERROR, "Failed to update input shape for data %s", data->GetName().c_str());
  645. return INTERNAL_ERROR;
  646. }
  647. GELOGI("Update the data %s input/output shape to the max %s", data->GetName().c_str(),
  648. formats::ShapeToString(data_shape).c_str());
  649. return SUCCESS;
  650. }
  651. Status MultiBatchGraphCopyer::InsertSwitchNForData(const NodePtr &data) {
  652. auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
  653. auto data_name = data->GetName();
  654. (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims());
  655. if (IsAllDimsPositive(data_shape.GetDims())) {
  656. GELOGI("The shape of data %s are positive(%s), skip the multi batch process", data->GetName().c_str(),
  657. data_shape.ToString().c_str());
  658. return SUCCESS;
  659. }
  660. auto switchn_desc = MakeShared<OpDesc>();
  661. if (switchn_desc == nullptr) {
  662. GELOGE(OUT_OF_MEMORY, "Failed to create switchn for data %s", data->GetName().c_str());
  663. return OUT_OF_MEMORY;
  664. }
  665. switchn_desc->SetName(data->GetName() + "_ascend_mbatch_switchn");
  666. switchn_desc->SetType(SWITCHN);
  667. GeTensorDesc tensor(NodeUtils::GetOutputDesc(*data, kDataOutIndex));
  668. if (switchn_desc->AddInputDesc("data", tensor) != GRAPH_SUCCESS) { // data
  669. return OUT_OF_MEMORY;
  670. }
  671. GeTensorDesc pred_tensor;
  672. if (switchn_desc->AddInputDesc("pred_value", pred_tensor) != GRAPH_SUCCESS) { // pred
  673. return OUT_OF_MEMORY;
  674. }
  675. std::vector<std::string> input_dims_str;
  676. for (size_t i = 0; i < shapes_.size(); ++i) {
  677. auto shape = data_shape;
  678. auto ret = CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape);
  679. if (ret != SUCCESS) {
  680. GELOGE(ret, "Failed to calculate the batched shape for data node %s, the shapes may not match",
  681. data->GetName().c_str());
  682. return ret;
  683. }
  684. tensor.SetShape(shape);
  685. string input_str;
  686. int64_t tensor_size = 0;
  687. (void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size);
  688. input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" +
  689. TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + data->GetName() + ":" +
  690. std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" +
  691. formats::JoinToString(tensor.GetShape().GetDims());
  692. input_dims_str.emplace_back(input_str);
  693. if (!AttrUtils::SetListInt(tensor, ATTR_NAME_SWITCHN_PRED_VALUE, shapes_.at(i))) {
  694. GELOGE(INTERNAL_ERROR, "Failed to add attr value on output %zu tensor", i);
  695. return INTERNAL_ERROR;
  696. }
  697. (void) AttrUtils::SetListInt(tensor, ATTR_NAME_COMBINED_DYNAMIC_DIMS, shape.GetDims());
  698. if (switchn_desc->AddOutputDesc("output" + std::to_string(i), tensor) != GRAPH_SUCCESS) {
  699. GELOGE(GRAPH_FAILED, "Opdesc AddOutputDesc failed");
  700. return GRAPH_FAILED;
  701. }
  702. GELOGD("The SwitchN %s output index %zu, shape %s", switchn_desc->GetName().c_str(), i, shape.ToString().c_str());
  703. }
  704. (void)AttrUtils::SetListStr(data->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str);
  705. if (!AttrUtils::SetListStr(switchn_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, data_name_order_)) {
  706. GELOGE(INTERNAL_ERROR, "Failed to add user designate shape order attr on switchn node %s",
  707. switchn_desc->GetName().c_str());
  708. return INTERNAL_ERROR;
  709. }
  710. if (!AttrUtils::SetBool(switchn_desc, ATTR_INSERT_BY_MBATCH, true)) {
  711. GELOGE(INTERNAL_ERROR, "Failed to add insert attr on switchn node %s", switchn_desc->GetName().c_str());
  712. return INTERNAL_ERROR;
  713. }
  714. if (!AttrUtils::SetStr(data->GetOpDesc(), kMbatchSwitchnName, switchn_desc->GetName())) {
  715. GELOGE(INTERNAL_ERROR, "Failed to add switchn attr on data node %s", data->GetName().c_str());
  716. return INTERNAL_ERROR;
  717. }
  718. if (StampDynamicType(switchn_desc) != SUCCESS) {
  719. GELOGE(INTERNAL_ERROR, "Failed to add dynamic type attr on switchn node %s", switchn_desc->GetName().c_str());
  720. return INTERNAL_ERROR;
  721. }
  722. auto switchn = graph_->AddNode(switchn_desc);
  723. if (switchn == nullptr) {
  724. GELOGE(OUT_OF_MEMORY, "Failed to create switchn %s from desc", switchn_desc->GetName().c_str());
  725. return OUT_OF_MEMORY;
  726. }
  727. data_nodes_to_switchn_[data.get()] = switchn;
  728. return SUCCESS;
  729. }
  730. Status MultiBatchGraphCopyer::InsertMergeForEdgeNode(const NodePtr &node) {
  731. for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
  732. auto src_out_anchor = in_data_anchor->GetPeerOutAnchor();
  733. if (src_out_anchor == nullptr) {
  734. GELOGD("The node %s does not has input at index %d", node->GetName().c_str(), in_data_anchor->GetIdx());
  735. continue;
  736. }
  737. auto in_node = src_out_anchor->GetOwnerNode();
  738. if (!IsInBatchBranch(in_node)) {
  739. continue;
  740. }
  741. auto merge_node = InsertMergeNode(in_node, src_out_anchor->GetIdx());
  742. if (merge_node == nullptr) {
  743. return INTERNAL_ERROR;
  744. }
  745. }
  746. for (auto &in_node : node->GetInControlNodes()) {
  747. if (!IsInBatchBranch(in_node)) {
  748. continue;
  749. }
  750. auto merge_node = InsertMergeNode(in_node, -1);
  751. if (merge_node == nullptr) {
  752. return INTERNAL_ERROR;
  753. }
  754. }
  755. return SUCCESS;
  756. }
  757. Status MultiBatchGraphCopyer::CopyNodeInBatchBranch(const NodePtr &node) {
  758. auto &copyed_nodes = nodes_to_batch_nodes_[node.get()];
  759. for (size_t i = 0; i < shapes_.size(); ++i) {
  760. auto copyed_node = InsertCopyNode(node, i);
  761. if (copyed_node == nullptr) {
  762. GELOGE(INTERNAL_ERROR, "Failed to add node to graph when copy node %s", node->GetName().c_str());
  763. return INTERNAL_ERROR;
  764. }
  765. copyed_nodes.emplace_back(copyed_node);
  766. GELOGI("Copy node %s type %s for shape %s, new node name %s", node->GetName().c_str(), node->GetType().c_str(),
  767. formats::JoinToString(shapes_.at(i)).c_str(), copyed_node->GetName().c_str());
  768. }
  769. return SUCCESS;
  770. }
  771. Status MultiBatchGraphCopyer::LinkEdges() {
  772. Status ret;
  773. for (const auto &node : origin_all_nodes_) {
  774. if (data_nodes_to_switchn_.count(node.get()) > 0) {
  775. ret = LinkDataToSwitchN(node);
  776. if (ret != SUCCESS) {
  777. return ret;
  778. }
  779. }
  780. if (nodes_to_merge_nodes_.count(node.get()) > 0) {
  781. ret = LinkToMerge(node);
  782. if (ret != SUCCESS) {
  783. return ret;
  784. }
  785. }
  786. if (nodes_to_batch_nodes_.count(node.get()) > 0) {
  787. ret = LinkToNodeInBranch(node);
  788. } else {
  789. ret = LinkToNodeOutBranch(node);
  790. }
  791. if (ret != SUCCESS) {
  792. return ret;
  793. }
  794. }
  795. return SUCCESS;
  796. }
  797. Status MultiBatchGraphCopyer::LinkDataToSwitchN(const NodePtr &data) {
  798. auto switchn = data_nodes_to_switchn_[data.get()];
  799. auto ret =
  800. GraphUtils::AddEdge(shape_data_->GetOutDataAnchor(kDataOutIndex), switchn->GetInDataAnchor(kSwitchNPredIndex));
  801. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link shape data %s to switchn %s",
  802. shape_data_->GetName().c_str(), switchn->GetName().c_str());
  803. return INTERNAL_ERROR);
  804. ret = GraphUtils::AddEdge(data->GetOutDataAnchor(kDataOutIndex), switchn->GetInDataAnchor(kSwitchNDataIndex));
  805. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link data %s to switchn %s",
  806. data->GetName().c_str(), switchn->GetName().c_str());
  807. return INTERNAL_ERROR);
  808. return SUCCESS;
  809. }
  810. Status MultiBatchGraphCopyer::LinkToMerge(const NodePtr &node) {
  811. auto &merge_nodes = nodes_to_merge_nodes_[node.get()];
  812. for (size_t i = 0; i < merge_nodes.size(); ++i) {
  813. auto merge_node = merge_nodes[i];
  814. if (merge_node == nullptr) {
  815. continue;
  816. }
  817. if (nodes_to_batch_nodes_.count(node.get()) > 0) {
  818. auto ret = LinkNodeToMerge(node, i, merge_node);
  819. if (ret != SUCCESS) {
  820. return ret;
  821. }
  822. continue;
  823. }
  824. if (data_nodes_to_switchn_.count(node.get()) > 0) {
  825. auto ret = LinkDataToMerge(node, merge_node);
  826. if (ret != SUCCESS) {
  827. return ret;
  828. }
  829. continue;
  830. }
  831. GELOGE(INTERNAL_ERROR, "The merge node %s is created, index %zu, but can not find the src node",
  832. merge_node->GetName().c_str(), i);
  833. return INTERNAL_ERROR;
  834. }
  835. return SUCCESS;
  836. }
  837. Status MultiBatchGraphCopyer::LinkToNodeInBranch(const NodePtr &node) {
  838. auto &branch_nodes = nodes_to_batch_nodes_[node.get()];
  839. for (size_t i = 0; i < branch_nodes.size(); ++i) {
  840. auto ret = CopyInDataEdges(node, i, branch_nodes[i]);
  841. if (ret != SUCCESS) {
  842. return ret;
  843. }
  844. ret = CopyInControlEdges(node, i, branch_nodes[i]);
  845. if (ret != SUCCESS) {
  846. return ret;
  847. }
  848. }
  849. return SUCCESS;
  850. }
  851. Status MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) {
  852. for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
  853. auto src_out_anchor = in_data_anchor->GetPeerOutAnchor();
  854. if (src_out_anchor == nullptr) {
  855. GELOGD("The node %s does not has input at index %d", node->GetName().c_str(), in_data_anchor->GetIdx());
  856. continue;
  857. }
  858. auto in_node = src_out_anchor->GetOwnerNode();
  859. if (!IsInBatchBranch(in_node)) {
  860. continue;
  861. }
  862. auto iter = nodes_to_merge_nodes_.find(in_node.get());
  863. if (iter == nodes_to_merge_nodes_.end()) {
  864. GELOGE(INTERNAL_ERROR, "Failed to link IO data edge from %s(%d) to %s(%d), no merge node found",
  865. in_node->GetName().c_str(), src_out_anchor->GetIdx(), node->GetName().c_str(), in_data_anchor->GetIdx());
  866. return INTERNAL_ERROR;
  867. }
  868. auto merge_node = iter->second[src_out_anchor->GetIdx()];
  869. if (merge_node == nullptr) {
  870. GELOGE(INTERNAL_ERROR, "Failed to link IO data edge from %s(%d) to %s(%d), no merge node found",
  871. in_node->GetName().c_str(), src_out_anchor->GetIdx(), node->GetName().c_str(), in_data_anchor->GetIdx());
  872. return INTERNAL_ERROR;
  873. }
  874. auto ret = src_out_anchor->Unlink(in_data_anchor);
  875. if (ret != GRAPH_SUCCESS) {
  876. GELOGE(INTERNAL_ERROR, "Failed to unlink the control edge from %s(%d) to %s(%d)", in_node->GetName().c_str(),
  877. src_out_anchor->GetIdx(), node->GetName().c_str(), in_data_anchor->GetIdx());
  878. return INTERNAL_ERROR;
  879. }
  880. ret = GraphUtils::AddEdge(merge_node->GetOutDataAnchor(kMergeDataOutIndex), in_data_anchor);
  881. if (ret != GRAPH_SUCCESS) {
  882. GELOGE(INTERNAL_ERROR, "Failed to add data edge from %s(%d) to %s(%d)", merge_node->GetName().c_str(),
  883. src_out_anchor->GetIdx(), node->GetName().c_str(), in_data_anchor->GetIdx());
  884. return INTERNAL_ERROR;
  885. }
  886. GELOGI("Link data edge from merge %s(from %s(%d)) to %s(%d)", merge_node->GetName().c_str(),
  887. in_node->GetName().c_str(), src_out_anchor->GetIdx(), node->GetName().c_str(), in_data_anchor->GetIdx());
  888. }
  889. for (auto &in_node : node->GetInControlNodes()) {
  890. if (!IsInBatchBranch(in_node)) {
  891. continue;
  892. }
  893. auto iter = nodes_to_merge_nodes_.find(in_node.get());
  894. if (iter == nodes_to_merge_nodes_.end()) {
  895. GELOGE(INTERNAL_ERROR, "Failed to link IO control edge from %s to %s, no merge node found",
  896. in_node->GetName().c_str(), node->GetName().c_str());
  897. return INTERNAL_ERROR;
  898. }
  899. auto merge_node = iter->second[0];
  900. if (merge_node == nullptr) {
  901. GELOGE(INTERNAL_ERROR, "Failed to link IO control edge from %s to %s, no merge node found",
  902. in_node->GetName().c_str(), node->GetName().c_str());
  903. return INTERNAL_ERROR;
  904. }
  905. GE_IF_BOOL_EXEC(in_node->GetOutControlAnchor() == nullptr,
  906. GELOGE(INTERNAL_ERROR, "Innode outputControlAnchor is null");
  907. return INTERNAL_ERROR);
  908. auto ret = in_node->GetOutControlAnchor()->Unlink(node->GetInControlAnchor());
  909. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to unlink the control edge from %s to %s",
  910. in_node->GetName().c_str(), node->GetName().c_str());
  911. return INTERNAL_ERROR);
  912. ret = GraphUtils::AddEdge(merge_node->GetOutControlAnchor(), node->GetInControlAnchor());
  913. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add control edge from %s to %s",
  914. merge_node->GetName().c_str(), node->GetName().c_str());
  915. return INTERNAL_ERROR);
  916. GELOGI("Link control edge from merge %s(from %s) to %s", merge_node->GetName().c_str(), in_node->GetName().c_str(),
  917. node->GetName().c_str());
  918. }
  919. return SUCCESS;
  920. }
  921. Status MultiBatchGraphCopyer::InsertIdentityAfterSwitchN() {
  922. for (auto &node : graph_->GetAllNodes()) {
  923. if (node->GetType() != SWITCHN) {
  924. continue;
  925. }
  926. auto switchn_desc = node->GetOpDesc();
  927. GE_CHECK_NOTNULL(switchn_desc);
  928. size_t i = 0;
  929. for (auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  930. for (auto &in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  931. auto out_node = in_data_anchor->GetOwnerNode();
  932. auto op_desc = out_node->GetOpDesc();
  933. GE_CHECK_NOTNULL(op_desc);
  934. if ((out_node->GetType() == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) {
  935. GELOGD("No need to insert identity between %s and %s.", node->GetName().c_str(), out_node->GetName().c_str());
  936. continue;
  937. }
  938. auto identity_desc = MakeShared<OpDesc>(node->GetName() + "_identity_" + std::to_string(i), IDENTITY);
  939. GE_CHECK_NOTNULL(identity_desc);
  940. string batch_label;
  941. if (AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
  942. if (!AttrUtils::SetStr(identity_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
  943. GELOGE(FAILED, "Set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", identity_desc->GetName().c_str());
  944. return FAILED;
  945. }
  946. }
  947. auto data_desc = switchn_desc->GetOutputDesc(i);
  948. i++;
  949. GE_CHK_STATUS_RET(identity_desc->AddInputDesc("x", data_desc));
  950. GE_CHK_STATUS_RET(identity_desc->AddOutputDesc("y", data_desc));
  951. auto identity_node = graph_->AddNode(identity_desc);
  952. GE_CHECK_NOTNULL(identity_node);
  953. GE_CHK_STATUS_RET(out_data_anchor->LinkTo(identity_node->GetInDataAnchor(0)));
  954. GE_CHECK_NOTNULL(identity_node->GetOutControlAnchor());
  955. GE_CHK_STATUS_RET(identity_node->GetOutControlAnchor()->LinkTo(out_node->GetInControlAnchor()));
  956. }
  957. }
  958. }
  959. return SUCCESS;
  960. }
  961. Status ProcessMultiBatch(ComputeGraphPtr &graph) {
  962. const char *multi_batch_with_case = std::getenv("MULTI_BATCH_WITH_CASE");
  963. if (multi_batch_with_case != nullptr) {
  964. PassManager pass_manager;
  965. GE_CHK_STATUS_RET(pass_manager.AddPass("MultiBatchClonePass", new (std::nothrow) MultiBatchClonePass));
  966. return pass_manager.Run(graph);
  967. }
  968. std::vector<std::vector<int64_t>> shapes;
  969. if (!InitDynamicParams(shapes)) {
  970. GELOGD("There is no multi-batch options, no need to process multi-batch copy");
  971. return SUCCESS;
  972. }
  973. DynamicType dynamic_type = DynamicType::kDynamicUnknown;
  974. if (!GetLocalOmgContext().dynamic_batch_size.empty()) {
  975. dynamic_type = DynamicType::kDynamicBatch;
  976. } else if (!GetLocalOmgContext().dynamic_image_size.empty()) {
  977. dynamic_type = DynamicType::kDynamicImageSize;;
  978. } else if (!GetLocalOmgContext().dynamic_dims.empty()) {
  979. dynamic_type = DynamicType::kDynamicDims;
  980. }
  981. std::vector<std::pair<std::string, std::vector<int64_t>>> user_designate_shape;
  982. user_designate_shape = GetLocalOmgContext().user_input_dims;
  983. GELOGI("Begin to copy graph for multi-batch");
  984. multibatch::MultiBatchGraphCopyer copyer(graph);
  985. for (auto &shape : shapes) {
  986. copyer.AddShape(shape);
  987. }
  988. copyer.SetDynamicType(dynamic_type);
  989. copyer.SetUserDesignateShape(user_designate_shape);
  990. return copyer.CopyGraph();
  991. }
  992. // +-----------+
  993. // | Data | +-----------+ +-----------+ +-----------+
  994. // +-----------+ | Data | ----> | SoftmaxV2 | ----> | NetOutput |
  995. // \ /. +-----------+ +-----------+ +-----------+
  996. // \ /.
  997. // +-----------+ +-----------+ /. +-----------+ +-----------+ +-----------+
  998. // | Data | ----> | Case | S--- | Data | ----> | SoftmaxV2 | ----> | NetOutput |
  999. // +-----------+ +-----------+ \. +-----------+ +-----------+ +-----------+
  1000. // \ \.
  1001. // \ \. +-----------+ +-----------+ +-----------+
  1002. // +-----------+ | Data | ----> | SoftmaxV2 | ----> | NetOutput |
  1003. // | NetOutput | +-----------+ +-----------+ +-----------+
  1004. // +-----------+
  1005. // +-----------+ /
  1006. // | Data | --------------->/
  1007. // +-----------+
  1008. void GetDynamicShapeByGraph(const ComputeGraphPtr &graph, const NodePtr &node,
  1009. set<size_t> &dynamic_output_index, vector<string> &dynamic_output_dims) {
  1010. GELOGD("Try get dynamic shape info, Graph: %s, Node: %s", graph->GetName().c_str(), node->GetName().c_str());
  1011. const auto &func_desc = node->GetOpDesc();
  1012. if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) {
  1013. GELOGD("Graph: %s Not multi-batch, Node: %s", graph->GetName().c_str(), node->GetName().c_str());
  1014. return;
  1015. }
  1016. const auto &dynamic_branch_names = func_desc->GetSubgraphInstanceNames();
  1017. for (size_t i = 0; i < func_desc->GetOutputsSize(); ++i) {
  1018. for (size_t j = 0; j < dynamic_branch_names.size(); ++j) {
  1019. const auto &subgraph = graph->GetSubgraph(dynamic_branch_names[j]);
  1020. if (subgraph == nullptr) {
  1021. GELOGE(GE_GRAPH_EMPTY_SUBGRAPH, "Subgraph not found, name: %s", dynamic_branch_names[j].c_str());
  1022. dynamic_output_dims.clear();
  1023. return;
  1024. }
  1025. const auto &out_node = subgraph->FindFirstNodeMatchType(NETOUTPUT);
  1026. if (out_node == nullptr) {
  1027. GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "NetOutput not found, name: %s", dynamic_branch_names[j].c_str());
  1028. dynamic_output_dims.clear();
  1029. return;
  1030. }
  1031. GELOGI("Find the subgraph Output node %s and the index is %zu", out_node->GetName().c_str(), i);
  1032. const auto &out_desc = out_node->GetOpDesc();
  1033. if (out_desc == nullptr || out_desc->GetInputsSize() <= i) {
  1034. GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "Get Input desc failed, name: %s, index: %zu", out_node->GetName().c_str(), i);
  1035. dynamic_output_dims.clear();
  1036. return;
  1037. }
  1038. const auto &input_tensor = out_desc->GetInputDesc(i);
  1039. const auto &shape_msg = input_tensor.GetShape().ToString();
  1040. string output_shape = std::to_string(j) + "," + std::to_string(i) + "," + shape_msg;
  1041. GELOGI("The shape msg in dynamic batch is %s", output_shape.c_str());
  1042. dynamic_output_dims.emplace_back(output_shape);
  1043. uint32_t parent_index = 0;
  1044. (void)AttrUtils::GetInt(input_tensor, ATTR_NAME_PARENT_NODE_INDEX, parent_index);
  1045. dynamic_output_index.insert(parent_index);
  1046. }
  1047. }
  1048. }
  1049. // +-----------+ +-----------+ i = 0
  1050. // +----> | SoftmaxV2 | ----> |MemcpyAsync| ----> \.
  1051. // / +-----------+ +-----------+ \.
  1052. // / \.
  1053. // +-----------+ +-----------+ +-----------+ +-----------+ i = 1 +-----------+
  1054. // | Data | ----> | SwitchN | ----> | SoftmaxV2 | ----> |MemcpyAsync| ----> | Merge |
  1055. // +-----------+ +-----------+ +-----------+ +-----------+ +-----------+
  1056. // \ / \. j = 0
  1057. // \ +-----------+ +-----------+ i = 2 / \.
  1058. // +----> | SoftmaxV2 | ----> |MemcpyAsync| ----> / +-----------+
  1059. // +-----------+ +-----------+ | NetOutput |
  1060. // +-----------+
  1061. // +-----------+ /.
  1062. // | Data | --------------------------------------------------------------------------->/. j = 1
  1063. // +-----------+
  1064. void GetDynamicShapeByMerge(const ComputeGraphPtr &graph, const NodePtr &node,
  1065. set<size_t> &dynamic_output_index, vector<string> &dynamic_output_dims) {
  1066. GELOGD("Try get dynamic shape info, Graph: %s, Node: %s", graph->GetName().c_str(), node->GetName().c_str());
  1067. const auto &netoutput_desc = node->GetOpDesc();
  1068. const auto &inputnode_to_netoutput = node->GetInAllNodes();
  1069. for (size_t i = 0; i < inputnode_to_netoutput.size(); ++i) {
  1070. bool insert_by_mbatch = false;
  1071. (void)AttrUtils::GetBool(inputnode_to_netoutput.at(i)->GetOpDesc(), ATTR_INSERT_BY_MBATCH, insert_by_mbatch);
  1072. if (inputnode_to_netoutput.at(i)->GetType() == MERGE && insert_by_mbatch) {
  1073. GELOGI("Find the merge node %s with mbatch attr and the index is %zu",
  1074. inputnode_to_netoutput.at(i)->GetName().c_str(), i);
  1075. dynamic_output_index.insert(i);
  1076. for (size_t j = 0; j < inputnode_to_netoutput.at(i)->GetInNodes().size(); ++j) {
  1077. auto input_desc = inputnode_to_netoutput.at(i)->GetOpDesc();
  1078. auto input_tensor_desc = input_desc->GetInputDesc(j);
  1079. auto shape_msg = input_tensor_desc.GetShape().ToString();
  1080. string output_shape = std::to_string(j) + "," + std::to_string(i) + "," + shape_msg;
  1081. GELOGI("The shape msg in dynamic batch is %s", output_shape.c_str());
  1082. dynamic_output_dims.emplace_back(output_shape);
  1083. }
  1084. }
  1085. }
  1086. }
  1087. // Connect NetOutput directly
  1088. void GetDirectOutputShape(const ComputeGraphPtr &graph, const NodePtr &node,
  1089. const set<size_t> &dynamic_output_index, vector<string> &dynamic_output_dims) {
  1090. GELOGD("Try get directly shape info, Graph: %s, Node: %s", graph->GetName().c_str(), node->GetName().c_str());
  1091. const auto &netoutput_desc = node->GetOpDesc();
  1092. const auto &inputnode_to_netoutput = node->GetInAllNodes();
  1093. for (size_t i = 0; i < inputnode_to_netoutput.size(); ++i) {
  1094. if (dynamic_output_index.count(i) > 0) {
  1095. continue;
  1096. }
  1097. auto tensor_desc = netoutput_desc->GetInputDesc(i);
  1098. auto shape = tensor_desc.GetShape().ToString();
  1099. string static_output_shape = std::to_string(kStaticOutput) + "," + std::to_string(i) + "," + shape;
  1100. GELOGI("The static output shape msg is %s", static_output_shape.c_str());
  1101. dynamic_output_dims.emplace_back(static_output_shape);
  1102. }
  1103. }
  1104. Status GetDynamicOutputShape(ComputeGraphPtr &graph) {
  1105. GE_CHECK_NOTNULL(graph);
  1106. GELOGI("Start to get output dynamic batch shape message");
  1107. NodePtr net_output;
  1108. set<size_t> dynamic_output_index;
  1109. vector<string> dynamic_output_dims;
  1110. for (auto &node : graph->GetDirectNode()) {
  1111. if (node->GetType() == NETOUTPUT) {
  1112. net_output = node;
  1113. GetDynamicShapeByMerge(graph, node, dynamic_output_index, dynamic_output_dims);
  1114. } else if (node->GetType() == CASE) {
  1115. GetDynamicShapeByGraph(graph, node, dynamic_output_index, dynamic_output_dims);
  1116. }
  1117. }
  1118. if ((net_output != nullptr) && !dynamic_output_dims.empty()) {
  1119. GetDirectOutputShape(graph, net_output, dynamic_output_index, dynamic_output_dims);
  1120. if (!AttrUtils::SetListStr(net_output->GetOpDesc(), ATTR_NAME_DYNAMIC_OUTPUT_DIMS, dynamic_output_dims)) {
  1121. GELOGE(FAILED, "Set dynamic output dims attr failed");
  1122. return FAILED;
  1123. }
  1124. }
  1125. return SUCCESS;
  1126. }
  1127. } // namespace multibatch
  1128. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示