You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

multi_batch_copy_graph.cc 51 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
4 years ago
4 years ago
5 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/preprocess/multi_batch_copy_graph.h"
  17. #include <queue>
  18. #include <set>
  19. #include <string>
  20. #include "common/formats/utils/formats_trans_utils.h"
  21. #include "common/ge/ge_util.h"
  22. #include "common/util/error_manager/error_manager.h"
  23. #include "framework/common/debug/ge_log.h"
  24. #include "framework/common/ge_inner_error_codes.h"
  25. #include "framework/common/string_util.h"
  26. #include "framework/common/types.h"
  27. #include "framework/omg/omg_inner_types.h"
  28. #include "graph/debug/ge_attr_define.h"
  29. #include "graph/ge_context.h"
  30. #include "graph/passes/multi_batch_clone_pass.h"
  31. #include "graph/passes/prune_pass.h"
  32. #include "graph/preprocess/multi_batch_options.h"
  33. #include "graph/utils/attr_utils.h"
  34. #include "graph/utils/graph_utils.h"
  35. #include "graph/utils/node_utils.h"
  36. #include "graph/utils/tensor_utils.h"
  37. #include "graph/utils/type_utils.h"
  38. #include "inc/pass_manager.h"
  39. #include "graph/common/local_context.h"
  40. using std::set;
  41. using std::string;
  42. using std::vector;
  43. namespace ge {
  44. namespace multibatch {
  45. namespace {
  // File-local constants for the multi-batch copy.
  // kMbatchSwitchnName is not referenced in the visible part of this file.
  constexpr const char *kMbatchSwitchnName = "mbatch-switch-name";
  // SwitchN input slots (presumably data tensor at 0, predicate at 1 — confirm against SwitchN definition).
  constexpr int kSwitchNDataIndex = 0;
  constexpr int kSwitchNPredIndex = 1;
  // Data nodes are read/updated through output 0 and input 0 in this file.
  constexpr int kDataOutIndex = 0;
  constexpr int kDataInIndex = 0;
  // Output slot of an inserted Merge node that carries the merged data ("y").
  constexpr int kMergeDataOutIndex = 0;
  // Negative sentinel index; InsertMergeNode maps negative indices to output 0.
  constexpr int kStaticOutput = -1;
  53. inline bool IsDataLikeType(const std::string &node_type) { return (node_type == DATA) || (node_type == AIPP); }
  54. NodePtr InsertMergeNodeToGraph(const std::string &name, size_t input_num, const ComputeGraphPtr &graph) {
  55. OpDescPtr desc = MakeShared<OpDesc>();
  56. if (desc == nullptr) {
  57. GELOGE(OUT_OF_MEMORY, "Failed to insert merge node, name %s", name.c_str());
  58. return nullptr;
  59. }
  60. desc->SetName(name);
  61. desc->SetType(MERGE);
  62. GeTensorDesc tensor_desc;
  63. for (size_t i = 0; i < input_num; ++i) {
  64. auto ret = desc->AddInputDesc("x" + std::to_string(i), tensor_desc);
  65. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS,
  66. GELOGE(INTERNAL_ERROR, "Failed to create merge node %s, failed to add input %zu, error-code %u",
  67. name.c_str(), i, ret);
  68. return nullptr);
  69. }
  70. auto ret = desc->AddOutputDesc("y", tensor_desc);
  71. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS,
  72. GELOGE(INTERNAL_ERROR, "Failed to create merge node %s, failed to add output 'y', error-code %u",
  73. name.c_str(), ret);
  74. return nullptr);
  75. tensor_desc.SetDataType(DT_INT32);
  76. ret = desc->AddOutputDesc("value_index", tensor_desc);
  77. if (ret != GRAPH_SUCCESS) {
  78. GELOGE(INTERNAL_ERROR, "Failed to create merge node %s, failed to add output 'value_index', error-code %u",
  79. name.c_str(), ret);
  80. return nullptr;
  81. }
  82. if (!AttrUtils::SetBool(desc, ATTR_INSERT_BY_MBATCH, true)) {
  83. GELOGE(INTERNAL_ERROR, "Failed to create merge node %s, failed to add attr", name.c_str());
  84. return nullptr;
  85. }
  86. return graph->AddNode(desc);
  87. }
  88. NodePtr InsertCopyNode(const NodePtr &node, size_t n) {
  89. const std::string &name = node->GetName() + "_ascend_mbatch_batch_" + std::to_string(n);
  90. auto src_op_desc = node->GetOpDesc();
  91. GE_IF_BOOL_EXEC(src_op_desc == nullptr, GELOGE(INTERNAL_ERROR, "Failed to copy node %s to %s, the OpDesc is null",
  92. node->GetName().c_str(), name.c_str());
  93. return nullptr);
  94. auto desc = AttrUtils::CopyOpDesc(src_op_desc);
  95. GE_IF_BOOL_EXEC(desc == nullptr, GELOGE(OUT_OF_MEMORY, "Failed to create op desc for copy node for node %s name %s",
  96. node->GetName().c_str(), name.c_str());
  97. return nullptr);
  98. desc->SetName(name);
  99. desc->CopyAttrsFrom(*src_op_desc);
  100. for (uint32_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) {
  101. auto input_desc = desc->MutableInputDesc(i);
  102. GE_IF_BOOL_EXEC(input_desc == nullptr,
  103. GELOGW("Get null input desc by index %u from node %s when copy from %s", i,
  104. desc->GetName().c_str(), node->GetName().c_str());
  105. continue);
  106. input_desc->CopyAttrsFrom(src_op_desc->GetInputDesc(i));
  107. }
  108. for (uint32_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) {
  109. auto output_desc = desc->MutableOutputDesc(i);
  110. GE_IF_BOOL_EXEC(output_desc == nullptr,
  111. GELOGE(INTERNAL_ERROR, "Failed to get output desc by index %u from node %s when copy from %s", i,
  112. desc->GetName().c_str(), node->GetName().c_str());
  113. return nullptr);
  114. output_desc->CopyAttrsFrom(src_op_desc->GetOutputDesc(i));
  115. }
  116. const std::string &batch_label = "Batch_" + std::to_string(n);
  117. if (!AttrUtils::SetStr(desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
  118. GELOGE(FAILED, "set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", name.c_str());
  119. return nullptr;
  120. }
  121. (void)AttrUtils::SetListStr(desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, {node->GetName()});
  122. auto graph = node->GetOwnerComputeGraph();
  123. return graph->AddNode(desc);
  124. }
  125. bool IsAllDimsPositive(const std::vector<int64_t> &dims) {
  126. for (auto dim : dims) {
  127. if (dim < 0) {
  128. return false;
  129. }
  130. }
  131. return true;
  132. }
  133. NodePtr InsertConst(const std::string &name, const ComputeGraphPtr &graph) {
  134. auto desc = MakeShared<OpDesc>();
  135. if (desc == nullptr) {
  136. GELOGE(OUT_OF_MEMORY, "Failed to create const op %s, out of memory", name.c_str());
  137. return nullptr;
  138. }
  139. desc->SetName(name);
  140. desc->SetType(CONSTANT);
  141. GeTensor tensor;
  142. tensor.SetData(std::vector<uint8_t>({0}));
  143. if (!AttrUtils::SetTensor(desc, ATTR_NAME_WEIGHTS, tensor)) {
  144. GELOGE(OUT_OF_MEMORY, "Failed to init tensor value for const %s", name.c_str());
  145. return nullptr;
  146. }
  147. if (!AttrUtils::SetBool(desc, ATTR_INSERT_BY_MBATCH, true)) {
  148. GELOGE(OUT_OF_MEMORY, "Failed to set insert flag for const node %s", name.c_str());
  149. return nullptr;
  150. }
  151. if (desc->AddOutputDesc(GeTensorDesc()) != GRAPH_SUCCESS) {
  152. GELOGE(OUT_OF_MEMORY, "Failed to add output desc for const node %s", name.c_str());
  153. return nullptr;
  154. }
  155. return graph->AddNode(desc);
  156. }
  157. bool IsOnlyOutputToAipp(const NodePtr &node) {
  158. for (const auto &out_node : node->GetOutDataNodes()) {
  159. if (out_node->GetType() != AIPP) {
  160. return false;
  161. }
  162. }
  163. return true;
  164. }
  165. Status CheckDataShape(const std::vector<NodePtr> &nodes) {
  166. size_t unknown_shape_count = 0;
  167. for (const auto &node : nodes) {
  168. if (node->GetType() != DATA) {
  169. continue;
  170. }
  171. for (auto dim : NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims()) {
  172. if (dim < 0) {
  173. unknown_shape_count++;
  174. break;
  175. }
  176. }
  177. }
  178. if (unknown_shape_count == 0) {
  179. ErrorManager::GetInstance().ATCReportErrMessage("E10040");
  180. GELOGE(PARAM_INVALID,
  181. "Need unknow shape data when user set --dynamic_batch_size, --dynamic_image_size or --dynamic_dims");
  182. return PARAM_INVALID;
  183. }
  184. return SUCCESS;
  185. }
  186. } // namespace
  187. Status MultiBatchGraphCopyer::CopyGraph() {
  188. auto ret = Init();
  189. if (ret != SUCCESS) {
  190. return ret;
  191. }
  192. if (LabelStatus() != SUCCESS) {
  193. GELOGE(INTERNAL_ERROR, "Failed to label status for all nodes.");
  194. return INTERNAL_ERROR;
  195. }
  196. ret = CheckAndParseDynamicData();
  197. if (ret != SUCCESS) {
  198. return ret;
  199. }
  200. ret = CreateNewNodes();
  201. if (ret != SUCCESS) {
  202. return ret;
  203. }
  204. ret = LinkEdges();
  205. if (ret != SUCCESS) {
  206. return ret;
  207. }
  208. ret = InsertIdentityAfterSwitchN();
  209. if (ret != SUCCESS) {
  210. GELOGE(INTERNAL_ERROR, "Failed to insert identity nodes after switchn node.");
  211. return INTERNAL_ERROR;
  212. }
  213. GELOGI("Begin to remove useless nodes by prune pass after copy process");
  214. PrunePass prune_pass;
  215. ret = prune_pass.Run(graph_);
  216. if (ret != SUCCESS) {
  217. GELOGE(ret, "Failed to prune");
  218. return ret;
  219. }
  220. return CheckCopyResult(origin_data_nodes_);
  221. }
  222. Status MultiBatchGraphCopyer::Init() {
  223. auto ret = CheckArguments();
  224. if (ret != SUCCESS) {
  225. return ret;
  226. }
  227. for (auto &node : graph_->GetAllNodes()) {
  228. origin_all_nodes_.emplace_back(node);
  229. if (IsDataLikeType(node->GetType())) {
  230. origin_data_nodes_.emplace_back(node);
  231. }
  232. }
  233. return SUCCESS;
  234. }
  235. Status MultiBatchGraphCopyer::LabelStatus() {
  236. for (const auto &data : origin_data_nodes_) {
  237. auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
  238. if (!IsAllDimsPositive(data_shape.GetDims())) {
  239. origin_nodes_status_[data.get()] = kNodeInBatchBranch;
  240. }
  241. }
  242. bool changed = true;
  243. // If anyone of in node is kNodeInBatchBranch, it is also kNodeInBatchBranch
  244. while (changed) {
  245. changed = false;
  246. for (const auto &node : origin_all_nodes_) {
  247. auto iter = origin_nodes_status_.find(node.get());
  248. if (iter != origin_nodes_status_.end()) {
  249. continue;
  250. }
  251. for (auto &in_node : node->GetInAllNodes()) {
  252. bool is_in_batch = origin_nodes_status_.find(in_node.get()) != origin_nodes_status_.end() &&
  253. origin_nodes_status_[in_node.get()] == kNodeInBatchBranch;
  254. if (is_in_batch) {
  255. origin_nodes_status_[node.get()] = kNodeInBatchBranch;
  256. changed = true;
  257. break;
  258. }
  259. }
  260. }
  261. }
  262. for (const auto &node : origin_all_nodes_) {
  263. if (!(node->GetOpDesc()->GetSubgraphInstanceNames().empty())) {
  264. origin_nodes_status_[node.get()] = kNodeNotSupportNode;
  265. continue;
  266. }
  267. if (node->GetType() == NETOUTPUT) {
  268. origin_nodes_status_[node.get()] = kNodeOutBatchBranch;
  269. continue;
  270. }
  271. if (IsDataLikeType(node->GetType())) {
  272. if (IsOnlyOutputToAipp(node)) {
  273. origin_nodes_status_[node.get()] = kNodeOutBatchBranch;
  274. } else {
  275. origin_nodes_status_[node.get()] = kNodeStartNode;
  276. }
  277. continue;
  278. }
  279. if (origin_nodes_status_.find(node.get()) == origin_nodes_status_.end()) {
  280. origin_nodes_status_[node.get()] = kNodeOutBatchBranch;
  281. }
  282. }
  283. return SUCCESS;
  284. }
  285. Status MultiBatchGraphCopyer::CheckAndParseDynamicData(){
  286. size_t unknown_shape_count = 0;
  287. auto data_name_and_shape = GetLocalOmgContext().user_input_dims;
  288. GELOGD("raw data_name_and_shape size: %zu", data_name_and_shape.size());
  289. for (const auto &node : origin_all_nodes_) {
  290. auto data_desc = NodeUtils::GetOutputDesc(*node, kDataOutIndex);
  291. auto data_shape = data_desc.GetShape();
  292. auto data_format = data_desc.GetFormat() == Format::FORMAT_NCHW ? "NCHW" :
  293. data_desc.GetFormat() == Format::FORMAT_NHWC ? "NHWC" : "Others";
  294. auto data_name = node->GetName();
  295. auto branch_status = GetNodeStatus(node);
  296. if (branch_status != kNodeStartNode) {
  297. continue;
  298. }
  299. if (IsAllDimsPositive(data_shape.GetDims())) {
  300. continue;
  301. }
  302. ++unknown_shape_count;
  303. auto iter = find(data_name_order_.begin(), data_name_order_.end(), data_name);
  304. if (iter == data_name_order_.end()) {
  305. if (dynamic_type_ == DynamicType::kDynamicBatch) {
  306. auto ret = CheckDynamicBatchShape(data_shape.GetDims(), data_name);
  307. if (!ret) {
  308. return PARAM_INVALID;
  309. }
  310. } else if (dynamic_type_ == DynamicType::kDynamicImageSize) {
  311. auto ret = CheckDynamicImageSizeShape(data_shape.GetDims(), data_name, data_format);
  312. if (!ret) {
  313. return PARAM_INVALID;
  314. }
  315. } else if (dynamic_type_ == DynamicType::kDynamicDims) {
  316. ErrorManager::GetInstance().ATCReportErrMessage("E10001",
  317. {"parameter", "reason"},
  318. {"--input_shape",
  319. "all dynamic data must be set in --input_shape"});
  320. GELOGE(INTERNAL_ERROR, "data: %s shape:%s must be set int --input_shape",
  321. node->GetName().c_str(), data_shape.ToString().c_str());
  322. return INTERNAL_ERROR;
  323. }
  324. data_name_and_shape.emplace_back(data_name, data_shape.GetDims());
  325. }
  326. }
  327. auto ret = ParserDataToDynmaicInfo(shapes_, data_name_and_shape, data_to_dynamic_info_);
  328. if (ret != SUCCESS){
  329. return ret;
  330. }
  331. if (unknown_shape_count == 0) {
  332. ErrorManager::GetInstance().ATCReportErrMessage("E10040");
  333. GELOGE(PARAM_INVALID,
  334. "Need unknow shape data when user set --dynamic_batch_size, --dynamic_image_size or --dynamic_dims");
  335. return PARAM_INVALID;
  336. }
  337. return SUCCESS;
  338. }
  339. Status MultiBatchGraphCopyer::CreateNewNodes() {
  340. shape_data_ = InsertShapeDataNode();
  341. if (shape_data_ == nullptr) {
  342. GELOGE(INTERNAL_ERROR, "Failed to create the shape data node for muti-batch");
  343. return INTERNAL_ERROR;
  344. }
  345. for (const auto &node : origin_all_nodes_) {
  346. auto node_type = node->GetType();
  347. Status ret = INTERNAL_ERROR;
  348. auto branch_status = GetNodeStatus(node);
  349. GELOGD("Process node %s, status %d", node->GetName().c_str(), static_cast<int>(branch_status));
  350. switch (branch_status) {
  351. case kNodeStartNode:
  352. GELOGD("Name: %s, type: %s, status: kNodeStartNode.", node->GetName().c_str(), node->GetType().c_str());
  353. ret = InsertSwitchNForData(node);
  354. if (ret == SUCCESS) {
  355. ret = UpdateMaxShapeToData(node);
  356. }
  357. break;
  358. case kNodeInBatchBranch:
  359. GELOGD("Name: %s, type: %s, status: kNodeInBatchBranch.", node->GetName().c_str(), node->GetType().c_str());
  360. ret = CopyNodeInBatchBranch(node);
  361. break;
  362. case kNodeOutBatchBranch:
  363. GELOGD("Name: %s, type: %s, status: kNodeOutBatchBranch.", node->GetName().c_str(), node->GetType().c_str());
  364. ret = InsertMergeForEdgeNode(node);
  365. break;
  366. case kNodeNotSupportNode:
  367. GELOGD("Name: %s, type: %s, status: kNodeNotSupportNode.", node->GetName().c_str(), node->GetType().c_str());
  368. break;
  369. default:
  370. GELOGE(INTERNAL_ERROR, "Unexpected status %d on node %s", static_cast<int>(branch_status),
  371. node->GetName().c_str());
  372. break;
  373. }
  374. if (ret != SUCCESS) {
  375. GELOGE(ret, "Failed to deal with node %s in multi-batch process", node->GetName().c_str());
  376. return ret;
  377. }
  378. }
  379. return SUCCESS;
  380. }
  381. NodePtr MultiBatchGraphCopyer::InsertMergeNode(const NodePtr &node, int index) {
  382. if (index < 0) {
  383. // the merge node must has data inputs, if origin connection is a control
  384. // edge, we use data edge instead
  385. index = 0;
  386. }
  387. auto &merge_nodes = nodes_to_merge_nodes_[node.get()];
  388. if (merge_nodes.empty()) {
  389. auto count = node->GetAllOutDataAnchorsSize();
  390. if (count == 0) {
  391. count = 1;
  392. }
  393. merge_nodes.resize(count, nullptr);
  394. }
  395. if (merge_nodes.at(index) != nullptr) {
  396. return merge_nodes[index];
  397. }
  398. auto merge_node_name = node->GetName() + "_ascend_mbatch_merge_" + std::to_string(index);
  399. auto merge_node = InsertMergeNodeToGraph(merge_node_name, shapes_.size(), node->GetOwnerComputeGraph());
  400. GE_IF_BOOL_EXEC(merge_node == nullptr, GELOGE(INTERNAL_ERROR, "Failed to create merge node for node %s, out index %d",
  401. node->GetName().c_str(), index);
  402. return nullptr);
  403. merge_nodes[index] = merge_node;
  404. GELOGI("Create merge node %s for node %s index %d", merge_node_name.c_str(), node->GetName().c_str(), index);
  405. return merge_node;
  406. }
  407. Status MultiBatchGraphCopyer::CopyInDataEdges(const NodePtr &origin_node, int batch_num, const NodePtr &copyed_node) {
  408. for (auto &in_anchor : origin_node->GetAllInDataAnchors()) {
  409. auto origin_src_anchor = in_anchor->GetPeerOutAnchor();
  410. if (origin_src_anchor == nullptr) {
  411. GELOGD("The node %s does not have input on index %d", origin_node->GetName().c_str(), in_anchor->GetIdx());
  412. continue;
  413. }
  414. auto origin_src_node = origin_src_anchor->GetOwnerNode();
  415. auto dst_anchor = copyed_node->GetInDataAnchor(in_anchor->GetIdx());
  416. GE_CHECK_NOTNULL(dst_anchor);
  417. auto switchn_iter = data_nodes_to_switchn_.find(origin_src_node.get());
  418. if (switchn_iter != data_nodes_to_switchn_.end()) {
  419. auto ret = GraphUtils::AddEdge(switchn_iter->second->GetOutDataAnchor(batch_num), dst_anchor);
  420. if (ret != GRAPH_SUCCESS) {
  421. GELOGE(INTERNAL_ERROR, "Failed to add data edge between %s(%d) to %s(%d), error-code %u",
  422. switchn_iter->second->GetName().c_str(), batch_num, copyed_node->GetName().c_str(), in_anchor->GetIdx(),
  423. ret);
  424. return INTERNAL_ERROR;
  425. }
  426. GELOGD("Add data edge from %s(%d) to %s(%d)", switchn_iter->second->GetName().c_str(), batch_num,
  427. copyed_node->GetName().c_str(), in_anchor->GetIdx());
  428. continue;
  429. }
  430. auto batch_branch_iter = nodes_to_batch_nodes_.find(origin_src_node.get());
  431. if (batch_branch_iter != nodes_to_batch_nodes_.end()) {
  432. auto src_batch_node = batch_branch_iter->second.at(batch_num);
  433. auto ret = GraphUtils::AddEdge(src_batch_node->GetOutDataAnchor(origin_src_anchor->GetIdx()), dst_anchor);
  434. if (ret != GRAPH_SUCCESS) {
  435. GELOGE(INTERNAL_ERROR, "Failed to add data edge between %s(%d) to %s(%d), error-code %u",
  436. src_batch_node->GetName().c_str(), batch_num, copyed_node->GetName().c_str(), in_anchor->GetIdx(), ret);
  437. return INTERNAL_ERROR;
  438. }
  439. GELOGD("Add data edge from %s(%d) to %s(%d)", src_batch_node->GetName().c_str(), batch_num,
  440. copyed_node->GetName().c_str(), in_anchor->GetIdx());
  441. continue;
  442. }
  443. auto ret = GraphUtils::AddEdge(origin_src_anchor, dst_anchor);
  444. if (ret != GRAPH_SUCCESS) {
  445. GELOGE(INTERNAL_ERROR, "Failed to add data edge between origin node %s(%d) to copyed %s(%d)",
  446. origin_src_node->GetName().c_str(), origin_src_anchor->GetIdx(), copyed_node->GetName().c_str(),
  447. dst_anchor->GetIdx());
  448. return INTERNAL_ERROR;
  449. }
  450. GELOGD("Add data edge between branch-out %s(%d) to branch-in %s(%d)", origin_src_node->GetName().c_str(),
  451. origin_src_anchor->GetIdx(), copyed_node->GetName().c_str(), dst_anchor->GetIdx());
  452. }
  453. return SUCCESS;
  454. }
  455. Status MultiBatchGraphCopyer::CopyInControlEdges(const NodePtr &node, int batch_num, const NodePtr &copyed_node) {
  456. for (auto &origin_src_node : node->GetInControlNodes()) {
  457. auto switchn_iter = data_nodes_to_switchn_.find(origin_src_node.get());
  458. if (switchn_iter != data_nodes_to_switchn_.end()) {
  459. // reconnect data node
  460. auto ret = GraphUtils::AddEdge(switchn_iter->second->GetOutControlAnchor(), copyed_node->GetInControlAnchor());
  461. if (ret != GRAPH_SUCCESS) {
  462. GELOGE(INTERNAL_ERROR, "Failed to add control edge between %s to %s, error-code %u",
  463. switchn_iter->second->GetName().c_str(), copyed_node->GetName().c_str(), ret);
  464. return INTERNAL_ERROR;
  465. }
  466. GELOGD("Add control edge from %s to %s", switchn_iter->second->GetName().c_str(), copyed_node->GetName().c_str());
  467. continue;
  468. }
  469. auto batch_branch_iter = nodes_to_batch_nodes_.find(origin_src_node.get());
  470. if (batch_branch_iter != nodes_to_batch_nodes_.end()) {
  471. // reconnect node in batch branch
  472. auto src_batch_node = batch_branch_iter->second.at(batch_num);
  473. auto ret = GraphUtils::AddEdge(src_batch_node->GetOutControlAnchor(), copyed_node->GetInControlAnchor());
  474. if (ret != GRAPH_SUCCESS) {
  475. GELOGE(INTERNAL_ERROR, "Failed to add data edge between %s to %s, error-code %u",
  476. src_batch_node->GetName().c_str(), copyed_node->GetName().c_str(), ret);
  477. return INTERNAL_ERROR;
  478. }
  479. GELOGD("Add control edge from %s to %s", src_batch_node->GetName().c_str(), copyed_node->GetName().c_str());
  480. continue;
  481. }
  482. auto ret = GraphUtils::AddEdge(origin_src_node->GetOutControlAnchor(), copyed_node->GetInControlAnchor());
  483. if (ret != GRAPH_SUCCESS) {
  484. GELOGE(INTERNAL_ERROR, "Failed to add control edge from origin %s to copyed %s",
  485. origin_src_node->GetName().c_str(), copyed_node->GetName().c_str());
  486. return INTERNAL_ERROR;
  487. }
  488. GELOGD("Add control edge between branch-out %s to branch-in %s", origin_src_node->GetName().c_str(),
  489. copyed_node->GetName().c_str());
  490. }
  491. return SUCCESS;
  492. }
  493. NodePtr MultiBatchGraphCopyer::InsertShapeDataNode() {
  494. auto desc = MakeShared<OpDesc>();
  495. if (desc == nullptr) {
  496. GELOGE(OUT_OF_MEMORY, "Failed to create shape data node, out of memory");
  497. return nullptr;
  498. }
  499. string node_name = "ascend_mbatch_shape_data";
  500. // Only flush subgraph name
  501. if (graph_->GetParentGraph() != nullptr) {
  502. node_name = graph_->GetName() + "_" + node_name;
  503. }
  504. desc->SetName(node_name);
  505. desc->SetType(DATA);
  506. GeTensorDesc tensor_desc;
  507. tensor_desc.SetFormat(FORMAT_ND);
  508. tensor_desc.SetShape(GeShape({static_cast<int64_t>(shapes_.at(0).size())}));
  509. tensor_desc.SetDataType(DT_INT64);
  510. auto ret = desc->AddInputDesc(tensor_desc);
  511. if (ret != GRAPH_SUCCESS) {
  512. GELOGE(INTERNAL_ERROR, "Failed to add input desc for created data");
  513. return nullptr;
  514. }
  515. ret = desc->AddOutputDesc(tensor_desc);
  516. if (ret != GRAPH_SUCCESS) {
  517. GELOGE(INTERNAL_ERROR, "Failed to add output desc for created data");
  518. return nullptr;
  519. }
  520. if (!AttrUtils::SetBool(desc, ATTR_INSERT_BY_MBATCH, true)) {
  521. GELOGE(INTERNAL_ERROR, "Failed to add attr for created data");
  522. return nullptr;
  523. }
  524. auto data_node = graph_->AddNode(desc);
  525. if (data_node == nullptr) {
  526. GELOGE(INTERNAL_ERROR, "Failed to add shape data node to graph");
  527. return nullptr;
  528. }
  529. ret = GraphUtils::AppendInputNode(graph_, data_node);
  530. if (ret != GRAPH_SUCCESS) {
  531. GELOGE(INTERNAL_ERROR, "Failed to append data node %s as input to graph", data_node->GetName().c_str());
  532. return nullptr;
  533. }
  534. return data_node;
  535. }
  536. Status MultiBatchGraphCopyer::CheckArguments() {
  537. if (graph_ == nullptr) {
  538. GELOGE(PARAM_INVALID, "Failed to copy graph, the graph is null");
  539. return PARAM_INVALID;
  540. }
  541. return CheckDynamicParams(shapes_);
  542. }
  543. Status MultiBatchGraphCopyer::CheckCopyResult(const std::vector<NodePtr> &start_nodes) {
  544. for (auto &node : start_nodes) {
  545. if (IsOnlyOutputToAipp(node)) {
  546. continue;
  547. }
  548. auto dims = NodeUtils::GetOutputDesc(*node, kDataOutIndex).GetShape().GetDims();
  549. if (!IsAllDimsPositive(dims)) {
  550. GELOGE(INTERNAL_ERROR, "Failed to copy multi batch graph, the node %s still has unknown shape %s",
  551. node->GetName().c_str(), formats::ShapeToString(dims).c_str());
  552. return INTERNAL_ERROR;
  553. }
  554. }
  555. return SUCCESS;
  556. }
  557. bool MultiBatchGraphCopyer::IsInBatchBranch(const NodePtr &node) {
  558. return (nodes_to_batch_nodes_.count(node.get()) > 0) || (data_nodes_to_switchn_.count(node.get()) > 0);
  559. }
  560. Status MultiBatchGraphCopyer::LinkDataToMerge(const NodePtr &data, const NodePtr &merge) {
  561. // The caller should make sure that the there is a SwitchN node in the map
  562. auto &switchn = data_nodes_to_switchn_[data.get()];
  563. GELOGI("Link edge between data %s to merge %s throw switchn %s", data->GetName().c_str(), merge->GetName().c_str(),
  564. switchn->GetName().c_str());
  565. for (size_t i = 0; i < shapes_.size(); ++i) {
  566. auto ret = GraphUtils::AddEdge(switchn->GetOutDataAnchor(i), merge->GetInDataAnchor(i));
  567. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS,
  568. GELOGE(INTERNAL_ERROR, "Failed to add edge between switchn %s(%zu) to merge %s(%zu), error-code %u",
  569. switchn->GetName().c_str(), i, merge->GetName().c_str(), i, ret);
  570. return INTERNAL_ERROR);
  571. }
  572. return SUCCESS;
  573. }
  574. Status MultiBatchGraphCopyer::LinkNodeToMerge(const NodePtr &node, int out_index, const NodePtr &merge) {
  575. auto &copyed_nodes = nodes_to_batch_nodes_[node.get()];
  576. if (copyed_nodes.size() != shapes_.size()) {
  577. GELOGE(INTERNAL_ERROR,
  578. "Failed to create merge node for node %s, the copyed nodes for it count %zu different with shape %zu",
  579. node->GetName().c_str(), copyed_nodes.size(), shapes_.size());
  580. return INTERNAL_ERROR;
  581. }
  582. for (size_t i = 0; i < copyed_nodes.size(); ++i) {
  583. auto src_node = copyed_nodes[i];
  584. if (src_node->GetAllOutDataAnchorsSize() == 0) {
  585. // if the node does not has any data output, we should create an const for it, like this:
  586. // c d
  587. // node ---> const ---> merge
  588. auto const_name = src_node->GetName() + "_merge_const";
  589. GELOGI("The node %s on the batch branch edge does not have any data output, create a const %s for it",
  590. src_node->GetName().c_str(), const_name.c_str());
  591. auto const_node = InsertConst(const_name, graph_);
  592. GE_IF_BOOL_EXEC(const_node == nullptr,
  593. GELOGE(OUT_OF_MEMORY, "Failed to create const for node %s to connect to a merge node",
  594. src_node->GetName().c_str());
  595. return OUT_OF_MEMORY);
  596. auto ret = GraphUtils::AddEdge(src_node->GetOutControlAnchor(), const_node->GetInControlAnchor());
  597. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add control edge from %s to %s",
  598. src_node->GetName().c_str(), const_node->GetName().c_str());
  599. return INTERNAL_ERROR);
  600. src_node = const_node;
  601. }
  602. auto ret = GraphUtils::AddEdge(src_node->GetOutDataAnchor(out_index), merge->GetInDataAnchor(i));
  603. if (ret != GRAPH_SUCCESS) {
  604. GELOGE(INTERNAL_ERROR,
  605. "Failed to add edge between copyed node %s(%d) to inserted merge node %s(%zu), error-code %u",
  606. copyed_nodes[i]->GetName().c_str(), out_index, merge->GetName().c_str(), i, ret);
  607. return INTERNAL_ERROR;
  608. }
  609. }
  610. return SUCCESS;
  611. }
  612. Status MultiBatchGraphCopyer::UpdateMaxShapeToData(const NodePtr &data) {
  613. auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
  614. auto data_name = data->GetName();
  615. if (IsAllDimsPositive(data_shape.GetDims())) {
  616. return SUCCESS;
  617. }
  618. size_t max_shape_index = 0;
  619. int64_t max_size = 0;
  620. for (size_t i = 0; i < shapes_.size(); ++i) {
  621. int64_t size = 1;
  622. for (auto dim : data_to_dynamic_info_.at(data_name).at(i)) {
  623. if (INT64_MAX / dim < size) {
  624. GELOGE(PARAM_INVALID, "The shape %s size overflow",
  625. formats::ShapeToString(data_to_dynamic_info_[data_name].at(i)).c_str());
  626. return PARAM_INVALID;
  627. }
  628. size *= dim;
  629. }
  630. if (size > max_size) {
  631. max_size = size;
  632. max_shape_index = i;
  633. }
  634. }
  635. // must not be error, the calc result has been checked in function InsertSwitchNForData
  636. (void)CalcShape(data_to_dynamic_info_.at(data_name).at(max_shape_index), data_shape);
  637. auto ret = NodeUtils::UpdateOutputShape(*data, kDataOutIndex, data_shape);
  638. if (ret != GRAPH_SUCCESS) {
  639. GELOGE(INTERNAL_ERROR, "Failed to update output shape for data %s", data->GetName().c_str());
  640. return INTERNAL_ERROR;
  641. }
  642. ret = NodeUtils::UpdateInputShape(*data, kDataInIndex, data_shape);
  643. if (ret != GRAPH_SUCCESS) {
  644. GELOGE(INTERNAL_ERROR, "Failed to update input shape for data %s", data->GetName().c_str());
  645. return INTERNAL_ERROR;
  646. }
  647. GELOGI("Update the data %s input/output shape to the max %s", data->GetName().c_str(),
  648. formats::ShapeToString(data_shape).c_str());
  649. return SUCCESS;
  650. }
  651. Status MultiBatchGraphCopyer::InsertSwitchNForData(const NodePtr &data) {
  652. auto data_shape = NodeUtils::GetOutputDesc(*data, kDataOutIndex).GetShape();
  653. auto data_name = data->GetName();
  654. (void)AttrUtils::SetListInt(data->GetOpDesc(), ATTR_MBATCH_ORIGIN_INPUT_DIMS, data_shape.GetDims());
  655. if (IsAllDimsPositive(data_shape.GetDims())) {
  656. GELOGI("The shape of data %s are positive(%s), skip the multi batch process", data->GetName().c_str(),
  657. data_shape.ToString().c_str());
  658. return SUCCESS;
  659. }
  660. auto switchn_desc = MakeShared<OpDesc>();
  661. if (switchn_desc == nullptr) {
  662. GELOGE(OUT_OF_MEMORY, "Failed to create switchn for data %s", data->GetName().c_str());
  663. return OUT_OF_MEMORY;
  664. }
  665. switchn_desc->SetName(data->GetName() + "_ascend_mbatch_switchn");
  666. switchn_desc->SetType(SWITCHN);
  667. GeTensorDesc tensor(NodeUtils::GetOutputDesc(*data, kDataOutIndex));
  668. if (switchn_desc->AddInputDesc("data", tensor) != GRAPH_SUCCESS) { // data
  669. return OUT_OF_MEMORY;
  670. }
  671. GeTensorDesc pred_tensor;
  672. if (switchn_desc->AddInputDesc("pred_value", pred_tensor) != GRAPH_SUCCESS) { // pred
  673. return OUT_OF_MEMORY;
  674. }
  675. std::vector<std::string> input_dims_str;
  676. for (size_t i = 0; i < shapes_.size(); ++i) {
  677. auto shape = data_shape;
  678. auto ret = CalcShape(data_to_dynamic_info_.at(data_name).at(i), shape);
  679. if (ret != SUCCESS) {
  680. GELOGE(ret, "Failed to calculate the batched shape for data node %s, the shapes may not match",
  681. data->GetName().c_str());
  682. return ret;
  683. }
  684. tensor.SetShape(shape);
  685. string input_str;
  686. int64_t tensor_size = 0;
  687. (void)TensorUtils::GetTensorSizeInBytes(tensor, tensor_size);
  688. input_str = TypeUtils::FormatToSerialString(tensor.GetFormat()) + ":" +
  689. TypeUtils::DataTypeToSerialString(tensor.GetDataType()) + ":" + data->GetName() + ":" +
  690. std::to_string(tensor_size) + ":" + std::to_string(tensor.GetShape().GetDimNum()) + ":" +
  691. formats::JoinToString(tensor.GetShape().GetDims());
  692. input_dims_str.emplace_back(input_str);
  693. if (!AttrUtils::SetListInt(tensor, ATTR_NAME_SWITCHN_PRED_VALUE, shapes_.at(i))) {
  694. GELOGE(INTERNAL_ERROR, "Failed to add attr value on output %zu tensor", i);
  695. return INTERNAL_ERROR;
  696. }
  697. (void) AttrUtils::SetListInt(tensor, ATTR_NAME_COMBINED_DYNAMIC_DIMS, shape.GetDims());
  698. if (switchn_desc->AddOutputDesc("output" + std::to_string(i), tensor) != GRAPH_SUCCESS) {
  699. GELOGE(GRAPH_FAILED, "Opdesc AddOutputDesc failed");
  700. return GRAPH_FAILED;
  701. }
  702. GELOGD("The SwitchN %s output index %zu, shape %s", switchn_desc->GetName().c_str(), i, shape.ToString().c_str());
  703. }
  704. (void)AttrUtils::SetListStr(data->GetOpDesc(), "_all_origin_gears_inputs", input_dims_str);
  705. if (!AttrUtils::SetListStr(switchn_desc, ATTR_USER_DESIGNEATE_SHAPE_ORDER, data_name_order_)) {
  706. GELOGE(INTERNAL_ERROR, "Failed to add user designate shape order attr on switchn node %s",
  707. switchn_desc->GetName().c_str());
  708. return INTERNAL_ERROR;
  709. }
  710. if (!AttrUtils::SetBool(switchn_desc, ATTR_INSERT_BY_MBATCH, true)) {
  711. GELOGE(INTERNAL_ERROR, "Failed to add insert attr on switchn node %s", switchn_desc->GetName().c_str());
  712. return INTERNAL_ERROR;
  713. }
  714. if (!AttrUtils::SetStr(data->GetOpDesc(), kMbatchSwitchnName, switchn_desc->GetName())) {
  715. GELOGE(INTERNAL_ERROR, "Failed to add switchn attr on data node %s", data->GetName().c_str());
  716. return INTERNAL_ERROR;
  717. }
  718. if (StampDynamicType(switchn_desc) != SUCCESS) {
  719. GELOGE(INTERNAL_ERROR, "Failed to add dynamic type attr on switchn node %s", switchn_desc->GetName().c_str());
  720. return INTERNAL_ERROR;
  721. }
  722. auto switchn = graph_->AddNode(switchn_desc);
  723. if (switchn == nullptr) {
  724. GELOGE(OUT_OF_MEMORY, "Failed to create switchn %s from desc", switchn_desc->GetName().c_str());
  725. return OUT_OF_MEMORY;
  726. }
  727. data_nodes_to_switchn_[data.get()] = switchn;
  728. return SUCCESS;
  729. }
  730. Status MultiBatchGraphCopyer::InsertMergeForEdgeNode(const NodePtr &node) {
  731. for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
  732. auto src_out_anchor = in_data_anchor->GetPeerOutAnchor();
  733. if (src_out_anchor == nullptr) {
  734. GELOGD("The node %s does not has input at index %d", node->GetName().c_str(), in_data_anchor->GetIdx());
  735. continue;
  736. }
  737. auto in_node = src_out_anchor->GetOwnerNode();
  738. if (!IsInBatchBranch(in_node)) {
  739. continue;
  740. }
  741. auto merge_node = InsertMergeNode(in_node, src_out_anchor->GetIdx());
  742. if (merge_node == nullptr) {
  743. return INTERNAL_ERROR;
  744. }
  745. }
  746. for (auto &in_node : node->GetInControlNodes()) {
  747. if (!IsInBatchBranch(in_node)) {
  748. continue;
  749. }
  750. auto merge_node = InsertMergeNode(in_node, -1);
  751. if (merge_node == nullptr) {
  752. return INTERNAL_ERROR;
  753. }
  754. }
  755. return SUCCESS;
  756. }
  757. Status MultiBatchGraphCopyer::CopyNodeInBatchBranch(const NodePtr &node) {
  758. auto &copyed_nodes = nodes_to_batch_nodes_[node.get()];
  759. for (size_t i = 0; i < shapes_.size(); ++i) {
  760. auto copyed_node = InsertCopyNode(node, i);
  761. if (copyed_node == nullptr) {
  762. GELOGE(INTERNAL_ERROR, "Failed to add node to graph when copy node %s", node->GetName().c_str());
  763. return INTERNAL_ERROR;
  764. }
  765. copyed_nodes.emplace_back(copyed_node);
  766. GELOGI("Copy node %s type %s for shape %s, new node name %s", node->GetName().c_str(), node->GetType().c_str(),
  767. formats::JoinToString(shapes_.at(i)).c_str(), copyed_node->GetName().c_str());
  768. }
  769. return SUCCESS;
  770. }
  771. Status MultiBatchGraphCopyer::LinkEdges() {
  772. Status ret;
  773. for (const auto &node : origin_all_nodes_) {
  774. if (data_nodes_to_switchn_.count(node.get()) > 0) {
  775. ret = LinkDataToSwitchN(node);
  776. if (ret != SUCCESS) {
  777. return ret;
  778. }
  779. }
  780. if (nodes_to_merge_nodes_.count(node.get()) > 0) {
  781. ret = LinkToMerge(node);
  782. if (ret != SUCCESS) {
  783. return ret;
  784. }
  785. }
  786. if (nodes_to_batch_nodes_.count(node.get()) > 0) {
  787. ret = LinkToNodeInBranch(node);
  788. } else {
  789. ret = LinkToNodeOutBranch(node);
  790. }
  791. if (ret != SUCCESS) {
  792. return ret;
  793. }
  794. }
  795. return SUCCESS;
  796. }
  797. Status MultiBatchGraphCopyer::LinkDataToSwitchN(const NodePtr &data) {
  798. auto switchn = data_nodes_to_switchn_[data.get()];
  799. auto ret =
  800. GraphUtils::AddEdge(shape_data_->GetOutDataAnchor(kDataOutIndex), switchn->GetInDataAnchor(kSwitchNPredIndex));
  801. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link shape data %s to switchn %s",
  802. shape_data_->GetName().c_str(), switchn->GetName().c_str());
  803. return INTERNAL_ERROR);
  804. ret = GraphUtils::AddEdge(data->GetOutDataAnchor(kDataOutIndex), switchn->GetInDataAnchor(kSwitchNDataIndex));
  805. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to link data %s to switchn %s",
  806. data->GetName().c_str(), switchn->GetName().c_str());
  807. return INTERNAL_ERROR);
  808. return SUCCESS;
  809. }
  810. Status MultiBatchGraphCopyer::LinkToMerge(const NodePtr &node) {
  811. auto &merge_nodes = nodes_to_merge_nodes_[node.get()];
  812. for (size_t i = 0; i < merge_nodes.size(); ++i) {
  813. auto merge_node = merge_nodes[i];
  814. if (merge_node == nullptr) {
  815. continue;
  816. }
  817. if (nodes_to_batch_nodes_.count(node.get()) > 0) {
  818. auto ret = LinkNodeToMerge(node, i, merge_node);
  819. if (ret != SUCCESS) {
  820. return ret;
  821. }
  822. continue;
  823. }
  824. if (data_nodes_to_switchn_.count(node.get()) > 0) {
  825. auto ret = LinkDataToMerge(node, merge_node);
  826. if (ret != SUCCESS) {
  827. return ret;
  828. }
  829. continue;
  830. }
  831. GELOGE(INTERNAL_ERROR, "The merge node %s is created, index %zu, but can not find the src node",
  832. merge_node->GetName().c_str(), i);
  833. return INTERNAL_ERROR;
  834. }
  835. return SUCCESS;
  836. }
  837. Status MultiBatchGraphCopyer::LinkToNodeInBranch(const NodePtr &node) {
  838. auto &branch_nodes = nodes_to_batch_nodes_[node.get()];
  839. for (size_t i = 0; i < branch_nodes.size(); ++i) {
  840. auto ret = CopyInDataEdges(node, i, branch_nodes[i]);
  841. if (ret != SUCCESS) {
  842. return ret;
  843. }
  844. ret = CopyInControlEdges(node, i, branch_nodes[i]);
  845. if (ret != SUCCESS) {
  846. return ret;
  847. }
  848. }
  849. return SUCCESS;
  850. }
  851. Status MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) {
  852. for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
  853. auto src_out_anchor = in_data_anchor->GetPeerOutAnchor();
  854. if (src_out_anchor == nullptr) {
  855. GELOGD("The node %s does not has input at index %d", node->GetName().c_str(), in_data_anchor->GetIdx());
  856. continue;
  857. }
  858. auto in_node = src_out_anchor->GetOwnerNode();
  859. if (!IsInBatchBranch(in_node)) {
  860. continue;
  861. }
  862. auto iter = nodes_to_merge_nodes_.find(in_node.get());
  863. if (iter == nodes_to_merge_nodes_.end()) {
  864. GELOGE(INTERNAL_ERROR, "Failed to link IO data edge from %s(%d) to %s(%d), no merge node found",
  865. in_node->GetName().c_str(), src_out_anchor->GetIdx(), node->GetName().c_str(), in_data_anchor->GetIdx());
  866. return INTERNAL_ERROR;
  867. }
  868. auto merge_node = iter->second[src_out_anchor->GetIdx()];
  869. if (merge_node == nullptr) {
  870. GELOGE(INTERNAL_ERROR, "Failed to link IO data edge from %s(%d) to %s(%d), no merge node found",
  871. in_node->GetName().c_str(), src_out_anchor->GetIdx(), node->GetName().c_str(), in_data_anchor->GetIdx());
  872. return INTERNAL_ERROR;
  873. }
  874. auto ret = src_out_anchor->Unlink(in_data_anchor);
  875. if (ret != GRAPH_SUCCESS) {
  876. GELOGE(INTERNAL_ERROR, "Failed to unlink the control edge from %s(%d) to %s(%d)", in_node->GetName().c_str(),
  877. src_out_anchor->GetIdx(), node->GetName().c_str(), in_data_anchor->GetIdx());
  878. return INTERNAL_ERROR;
  879. }
  880. ret = GraphUtils::AddEdge(merge_node->GetOutDataAnchor(kMergeDataOutIndex), in_data_anchor);
  881. if (ret != GRAPH_SUCCESS) {
  882. GELOGE(INTERNAL_ERROR, "Failed to add data edge from %s(%d) to %s(%d)", merge_node->GetName().c_str(),
  883. src_out_anchor->GetIdx(), node->GetName().c_str(), in_data_anchor->GetIdx());
  884. return INTERNAL_ERROR;
  885. }
  886. GELOGI("Link data edge from merge %s(from %s(%d)) to %s(%d)", merge_node->GetName().c_str(),
  887. in_node->GetName().c_str(), src_out_anchor->GetIdx(), node->GetName().c_str(), in_data_anchor->GetIdx());
  888. }
  889. for (auto &in_node : node->GetInControlNodes()) {
  890. if (!IsInBatchBranch(in_node)) {
  891. continue;
  892. }
  893. auto iter = nodes_to_merge_nodes_.find(in_node.get());
  894. if (iter == nodes_to_merge_nodes_.end()) {
  895. GELOGE(INTERNAL_ERROR, "Failed to link IO control edge from %s to %s, no merge node found",
  896. in_node->GetName().c_str(), node->GetName().c_str());
  897. return INTERNAL_ERROR;
  898. }
  899. auto merge_node = iter->second[0];
  900. if (merge_node == nullptr) {
  901. GELOGE(INTERNAL_ERROR, "Failed to link IO control edge from %s to %s, no merge node found",
  902. in_node->GetName().c_str(), node->GetName().c_str());
  903. return INTERNAL_ERROR;
  904. }
  905. GE_IF_BOOL_EXEC(in_node->GetOutControlAnchor() == nullptr,
  906. GELOGE(INTERNAL_ERROR, "Innode outputControlAnchor is null");
  907. return INTERNAL_ERROR);
  908. auto ret = in_node->GetOutControlAnchor()->Unlink(node->GetInControlAnchor());
  909. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to unlink the control edge from %s to %s",
  910. in_node->GetName().c_str(), node->GetName().c_str());
  911. return INTERNAL_ERROR);
  912. ret = GraphUtils::AddEdge(merge_node->GetOutControlAnchor(), node->GetInControlAnchor());
  913. GE_IF_BOOL_EXEC(ret != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "Failed to add control edge from %s to %s",
  914. merge_node->GetName().c_str(), node->GetName().c_str());
  915. return INTERNAL_ERROR);
  916. GELOGI("Link control edge from merge %s(from %s) to %s", merge_node->GetName().c_str(), in_node->GetName().c_str(),
  917. node->GetName().c_str());
  918. }
  919. return SUCCESS;
  920. }
  921. Status MultiBatchGraphCopyer::InsertIdentityAfterSwitchN() {
  922. for (auto &node : graph_->GetAllNodes()) {
  923. if (node->GetType() != SWITCHN) {
  924. continue;
  925. }
  926. auto switchn_desc = node->GetOpDesc();
  927. GE_CHECK_NOTNULL(switchn_desc);
  928. size_t i = 0;
  929. for (auto &out_data_anchor : node->GetAllOutDataAnchors()) {
  930. for (auto &in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) {
  931. auto out_node = in_data_anchor->GetOwnerNode();
  932. auto op_desc = out_node->GetOpDesc();
  933. GE_CHECK_NOTNULL(op_desc);
  934. if ((out_node->GetType() == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) {
  935. GELOGD("No need to insert identity between %s and %s.", node->GetName().c_str(), out_node->GetName().c_str());
  936. continue;
  937. }
  938. auto identity_desc = MakeShared<OpDesc>(node->GetName() + "_identity_" + std::to_string(i), IDENTITY);
  939. GE_CHECK_NOTNULL(identity_desc);
  940. string batch_label;
  941. if (AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
  942. if (!AttrUtils::SetStr(identity_desc, ATTR_NAME_BATCH_LABEL, batch_label)) {
  943. GELOGE(FAILED, "Set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", identity_desc->GetName().c_str());
  944. return FAILED;
  945. }
  946. }
  947. auto data_desc = switchn_desc->GetOutputDesc(i);
  948. i++;
  949. GE_CHK_STATUS_RET(identity_desc->AddInputDesc("x", data_desc));
  950. GE_CHK_STATUS_RET(identity_desc->AddOutputDesc("y", data_desc));
  951. auto identity_node = graph_->AddNode(identity_desc);
  952. GE_CHECK_NOTNULL(identity_node);
  953. GE_CHK_STATUS_RET(out_data_anchor->LinkTo(identity_node->GetInDataAnchor(0)));
  954. GE_CHECK_NOTNULL(identity_node->GetOutControlAnchor());
  955. GE_CHK_STATUS_RET(identity_node->GetOutControlAnchor()->LinkTo(out_node->GetInControlAnchor()));
  956. }
  957. }
  958. }
  959. return SUCCESS;
  960. }
  961. Status ProcessMultiBatch(ComputeGraphPtr &graph) {
  962. std::vector<std::vector<int64_t>> shapes;
  963. if (!InitDynamicParams(shapes)) {
  964. GELOGD("There is no multi-batch options, no need to process multi-batch copy");
  965. return SUCCESS;
  966. }
  967. DynamicType dynamic_type = DynamicType::kDynamicUnknown;
  968. if (!GetLocalOmgContext().dynamic_batch_size.empty()) {
  969. dynamic_type = DynamicType::kDynamicBatch;
  970. } else if (!GetLocalOmgContext().dynamic_image_size.empty()) {
  971. dynamic_type = DynamicType::kDynamicImageSize;;
  972. } else if (!GetLocalOmgContext().dynamic_dims.empty()) {
  973. dynamic_type = DynamicType::kDynamicDims;
  974. }
  975. std::vector<std::pair<std::string, std::vector<int64_t>>> user_designate_shape;
  976. user_designate_shape = GetLocalOmgContext().user_input_dims;
  977. GELOGI("Begin to copy graph for multi-batch");
  978. multibatch::MultiBatchGraphCopyer copyer(graph);
  979. for (auto &shape : shapes) {
  980. copyer.AddShape(shape);
  981. }
  982. copyer.SetDynamicType(dynamic_type);
  983. copyer.SetUserDesignateShape(user_designate_shape);
  984. return copyer.CopyGraph();
  985. }
  986. // +-----------+
  987. // | Data | +-----------+ +-----------+ +-----------+
  988. // +-----------+ | Data | ----> | SoftmaxV2 | ----> | NetOutput |
  989. // \ /. +-----------+ +-----------+ +-----------+
  990. // \ /.
  991. // +-----------+ +-----------+ /. +-----------+ +-----------+ +-----------+
  992. // | Data | ----> | Case | S--- | Data | ----> | SoftmaxV2 | ----> | NetOutput |
  993. // +-----------+ +-----------+ \. +-----------+ +-----------+ +-----------+
  994. // \ \.
  995. // \ \. +-----------+ +-----------+ +-----------+
  996. // +-----------+ | Data | ----> | SoftmaxV2 | ----> | NetOutput |
  997. // | NetOutput | +-----------+ +-----------+ +-----------+
  998. // +-----------+
  999. // +-----------+ /
  1000. // | Data | --------------->/
  1001. // +-----------+
  1002. void GetDynamicShapeByGraph(const ComputeGraphPtr &graph, const NodePtr &node,
  1003. set<size_t> &dynamic_output_index, vector<string> &dynamic_output_dims) {
  1004. GELOGD("Try get dynamic shape info, Graph: %s, Node: %s", graph->GetName().c_str(), node->GetName().c_str());
  1005. const auto &func_desc = node->GetOpDesc();
  1006. if (!func_desc->HasAttr(ATTR_NAME_BATCH_NUM)) {
  1007. GELOGD("Graph: %s Not multi-batch, Node: %s", graph->GetName().c_str(), node->GetName().c_str());
  1008. return;
  1009. }
  1010. const auto &dynamic_branch_names = func_desc->GetSubgraphInstanceNames();
  1011. for (size_t i = 0; i < func_desc->GetOutputsSize(); ++i) {
  1012. for (size_t j = 0; j < dynamic_branch_names.size(); ++j) {
  1013. const auto &subgraph = graph->GetSubgraph(dynamic_branch_names[j]);
  1014. if (subgraph == nullptr) {
  1015. GELOGE(GE_GRAPH_EMPTY_SUBGRAPH, "Subgraph not found, name: %s", dynamic_branch_names[j].c_str());
  1016. dynamic_output_dims.clear();
  1017. return;
  1018. }
  1019. const auto &out_node = subgraph->FindFirstNodeMatchType(NETOUTPUT);
  1020. if (out_node == nullptr) {
  1021. GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "NetOutput not found, name: %s", dynamic_branch_names[j].c_str());
  1022. dynamic_output_dims.clear();
  1023. return;
  1024. }
  1025. GELOGI("Find the subgraph Output node %s and the index is %zu", out_node->GetName().c_str(), i);
  1026. const auto &out_desc = out_node->GetOpDesc();
  1027. if (out_desc == nullptr || out_desc->GetInputsSize() <= i) {
  1028. GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "Get Input desc failed, name: %s, index: %zu", out_node->GetName().c_str(), i);
  1029. dynamic_output_dims.clear();
  1030. return;
  1031. }
  1032. const auto &input_tensor = out_desc->GetInputDesc(i);
  1033. const auto &shape_msg = input_tensor.GetShape().ToString();
  1034. string output_shape = std::to_string(j) + "," + std::to_string(i) + "," + shape_msg;
  1035. GELOGI("The shape msg in dynamic batch is %s", output_shape.c_str());
  1036. dynamic_output_dims.emplace_back(output_shape);
  1037. uint32_t parent_index = 0;
  1038. (void)AttrUtils::GetInt(input_tensor, ATTR_NAME_PARENT_NODE_INDEX, parent_index);
  1039. dynamic_output_index.insert(parent_index);
  1040. }
  1041. }
  1042. }
  1043. // +-----------+ +-----------+ i = 0
  1044. // +----> | SoftmaxV2 | ----> |MemcpyAsync| ----> \.
  1045. // / +-----------+ +-----------+ \.
  1046. // / \.
  1047. // +-----------+ +-----------+ +-----------+ +-----------+ i = 1 +-----------+
  1048. // | Data | ----> | SwitchN | ----> | SoftmaxV2 | ----> |MemcpyAsync| ----> | Merge |
  1049. // +-----------+ +-----------+ +-----------+ +-----------+ +-----------+
  1050. // \ / \. j = 0
  1051. // \ +-----------+ +-----------+ i = 2 / \.
  1052. // +----> | SoftmaxV2 | ----> |MemcpyAsync| ----> / +-----------+
  1053. // +-----------+ +-----------+ | NetOutput |
  1054. // +-----------+
  1055. // +-----------+ /.
  1056. // | Data | --------------------------------------------------------------------------->/. j = 1
  1057. // +-----------+
  1058. void GetDynamicShapeByMerge(const ComputeGraphPtr &graph, const NodePtr &node,
  1059. set<size_t> &dynamic_output_index, vector<string> &dynamic_output_dims) {
  1060. GELOGD("Try get dynamic shape info, Graph: %s, Node: %s", graph->GetName().c_str(), node->GetName().c_str());
  1061. const auto &netoutput_desc = node->GetOpDesc();
  1062. const auto &inputnode_to_netoutput = node->GetInAllNodes();
  1063. for (size_t i = 0; i < inputnode_to_netoutput.size(); ++i) {
  1064. bool insert_by_mbatch = false;
  1065. (void)AttrUtils::GetBool(inputnode_to_netoutput.at(i)->GetOpDesc(), ATTR_INSERT_BY_MBATCH, insert_by_mbatch);
  1066. if (inputnode_to_netoutput.at(i)->GetType() == MERGE && insert_by_mbatch) {
  1067. GELOGI("Find the merge node %s with mbatch attr and the index is %zu",
  1068. inputnode_to_netoutput.at(i)->GetName().c_str(), i);
  1069. dynamic_output_index.insert(i);
  1070. for (size_t j = 0; j < inputnode_to_netoutput.at(i)->GetInNodes().size(); ++j) {
  1071. auto input_desc = inputnode_to_netoutput.at(i)->GetOpDesc();
  1072. auto input_tensor_desc = input_desc->GetInputDesc(j);
  1073. auto shape_msg = input_tensor_desc.GetShape().ToString();
  1074. string output_shape = std::to_string(j) + "," + std::to_string(i) + "," + shape_msg;
  1075. GELOGI("The shape msg in dynamic batch is %s", output_shape.c_str());
  1076. dynamic_output_dims.emplace_back(output_shape);
  1077. }
  1078. }
  1079. }
  1080. }
  1081. // Connect NetOutput directly
  1082. void GetDirectOutputShape(const ComputeGraphPtr &graph, const NodePtr &node,
  1083. const set<size_t> &dynamic_output_index, vector<string> &dynamic_output_dims) {
  1084. GELOGD("Try get directly shape info, Graph: %s, Node: %s", graph->GetName().c_str(), node->GetName().c_str());
  1085. const auto &netoutput_desc = node->GetOpDesc();
  1086. const auto &inputnode_to_netoutput = node->GetInAllNodes();
  1087. for (size_t i = 0; i < inputnode_to_netoutput.size(); ++i) {
  1088. if (dynamic_output_index.count(i) > 0) {
  1089. continue;
  1090. }
  1091. auto tensor_desc = netoutput_desc->GetInputDesc(i);
  1092. auto shape = tensor_desc.GetShape().ToString();
  1093. string static_output_shape = std::to_string(kStaticOutput) + "," + std::to_string(i) + "," + shape;
  1094. GELOGI("The static output shape msg is %s", static_output_shape.c_str());
  1095. dynamic_output_dims.emplace_back(static_output_shape);
  1096. }
  1097. }
  1098. Status GetDynamicOutputShape(ComputeGraphPtr &graph) {
  1099. GE_CHECK_NOTNULL(graph);
  1100. GELOGI("Start to get output dynamic batch shape message");
  1101. NodePtr net_output;
  1102. set<size_t> dynamic_output_index;
  1103. vector<string> dynamic_output_dims;
  1104. for (auto &node : graph->GetDirectNode()) {
  1105. if (node->GetType() == NETOUTPUT) {
  1106. net_output = node;
  1107. GetDynamicShapeByMerge(graph, node, dynamic_output_index, dynamic_output_dims);
  1108. } else if (node->GetType() == CASE) {
  1109. GetDynamicShapeByGraph(graph, node, dynamic_output_index, dynamic_output_dims);
  1110. }
  1111. }
  1112. if ((net_output != nullptr) && !dynamic_output_dims.empty()) {
  1113. GetDirectOutputShape(graph, net_output, dynamic_output_index, dynamic_output_dims);
  1114. if (!AttrUtils::SetListStr(net_output->GetOpDesc(), ATTR_NAME_DYNAMIC_OUTPUT_DIMS, dynamic_output_dims)) {
  1115. GELOGE(FAILED, "Set dynamic output dims attr failed");
  1116. return FAILED;
  1117. }
  1118. }
  1119. return SUCCESS;
  1120. }
  1121. } // namespace multibatch
  1122. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起承上启下的作用。图引擎模块以ME下发的图作为输入,进行一系列深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了专门的优化,以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用,用户无需感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示: