You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

dynamic_shape_partition.cc 44 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/partition/dynamic_shape_partition.h"
  17. #include <algorithm>
  18. #include <iostream>
  19. #include <memory>
  20. #include <queue>
  21. #include <sstream>
  22. #include <string>
  23. #include <unordered_set>
  24. #include <vector>
  25. #include "common/ge/ge_util.h"
  26. #include "framework/common/debug/ge_log.h"
  27. #include "framework/common/debug/log.h"
  28. #include "framework/common/types.h"
  29. #include "graph/debug/ge_attr_define.h"
  30. #include "graph/utils/graph_utils.h"
  31. #include "graph/utils/op_desc_utils.h"
  32. #include "common/omg_util.h"
// Guard macro: when `cond` is false, report an inner error, log it with the
// [Dynamic shape partition] prefix, and return FAILED from the enclosing
// Status-returning function. Only usable inside such functions.
#define REQUIRE(cond, ...) \
  do { \
    if (!(cond)) { \
      REPORT_INNER_ERROR("E19999", __VA_ARGS__); \
      GELOGE(FAILED, "[Dynamic shape partition]" __VA_ARGS__); \
      return FAILED; \
    } \
  } while (0)
// Convenience wrappers over REQUIRE for the three common checks:
// non-null pointer, SUCCESS status, and GRAPH_SUCCESS graph status.
#define REQUIRE_NOT_NULL(cond, ...) REQUIRE(((cond) != nullptr), __VA_ARGS__)
#define REQUIRE_SUCCESS(cond, ...) REQUIRE(((cond) == SUCCESS), __VA_ARGS__)
#define REQUIRE_GRAPH_SUCCESS(cond, ...) REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__)
  44. namespace ge {
  45. using Cluster = DynamicShapePartitioner::Cluster;
  46. using ClusterPtr = std::shared_ptr<Cluster>;
  47. static bool IsSingleOpScene(const ComputeGraphPtr &root_graph) {
  48. for (const auto &node : root_graph->GetAllNodes()) {
  49. GE_CHECK_NOTNULL(node->GetOpDesc());
  50. // not do partition in single op scene.
  51. bool is_singleop = false;
  52. (void)AttrUtils::GetBool(node->GetOpDesc(), ATTR_SINGLE_OP_SCENE, is_singleop);
  53. if (is_singleop) {
  54. return true;
  55. }
  56. }
  57. return false;
  58. }
  59. Status DynamicShapePartitioner::Partition() {
  60. REQUIRE_NOT_NULL(root_graph_, "[Check][Param] Graph is nullptr.");
  61. if (IsSingleOpScene(root_graph_)) {
  62. GELOGD("Skip dynamic shape partition as in single op scene.");
  63. REQUIRE(AttrUtils::SetBool(*root_graph_, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, false),
  64. "[Set][Attr] dynamic shape partitioned flag on root graph:%s failed.", root_graph_->GetName().c_str());
  65. return SUCCESS;
  66. }
  67. GELOGD("Start dynamic shape partition graph %s.", root_graph_->GetName().c_str());
  68. REQUIRE_SUCCESS(MarkUnknownShapeNodes(), "[Call][MarkUnknownShapeNodes] failed, root grah name:%s.",
  69. root_graph_->GetName().c_str());
  70. if (unknown_shape_nodes_.empty()) {
  71. GELOGD("Skip dynamic shape partition of graph %s as all nodes are known shape.", root_graph_->GetName().c_str());
  72. REQUIRE(AttrUtils::SetBool(*root_graph_, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, false),
  73. "[Set][Attr] dynamic shape partitioned flag on root graph %s failed.", root_graph_->GetName().c_str());
  74. return SUCCESS;
  75. }
  76. REQUIRE(AttrUtils::SetBool(*root_graph_, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, true),
  77. "[Set][Attr] dynamic shape partitioned flag on root graph %s failed.", root_graph_->GetName().c_str());
  78. REQUIRE_SUCCESS(CtrlEdgeTransfer(), "[Call][CtrlEdgeTransfer] failed, graph:%s.", root_graph_->GetName().c_str());
  79. DumpGraph("_Before_DSP");
  80. auto status = PartitionImpl();
  81. GELOGD("%s.", DebugString().c_str());
  82. if (status != SUCCESS) {
  83. GELOGE(status, "[Call][PartitionImpl] Failed dynamic shape partition graph:%s, ret:%s",
  84. root_graph_->GetName().c_str(), DebugString().c_str());
  85. }
  86. DumpGraph("_After_DSP");
  87. GELOGD("Finish dynamic shape partition graph %s.", root_graph_->GetName().c_str());
  88. ClearResource();
  89. return status;
  90. }
// Moves control dependencies off of Const/Constant nodes inside subgraphs:
// every incoming control edge of a const node is removed and re-attached from
// the controlling node directly to each output node of the const, so ordering
// constraints survive while the const itself loses its control inputs.
// Runs only when the root graph carries ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED.
Status DynamicShapePartitioner::CtrlEdgeTransfer() {
  GELOGD("Do ctrl edge transfer start!");
  GE_CHECK_NOTNULL(root_graph_);
  bool is_dynamic_shape = false;
  (void)AttrUtils::GetBool(root_graph_, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape);
  if (!is_dynamic_shape) {
    return SUCCESS;
  }
  for (auto &subgraph : root_graph_->GetAllSubgraphs()) {
    for (ge::NodePtr &n : subgraph->GetDirectNode()) {
      auto op_desc = n->GetOpDesc();
      if (op_desc == nullptr) {
        continue;
      }
      auto op_type = op_desc->GetType();
      if (op_type == CONSTANT || op_type == CONSTANTOP) {
        if (n->GetInAllNodes().empty()) {
          // No inputs at all: nothing to transfer for this const.
          GELOGD("[CtrlEdgeTransferPass] node [%s] in nodes is empty", n->GetName().c_str());
          continue;
        }
        GELOGD("start to tranfer ctrl edge for const node [%s]", n->GetName().c_str());
        for (auto &in_control_node : n->GetInControlNodes()) {
          GE_CHECK_NOTNULL(in_control_node);
          // Detach the control edge feeding the const node...
          GE_CHK_STATUS_RET(ge::GraphUtils::RemoveEdge(in_control_node->GetOutControlAnchor(),
                                                       n->GetInControlAnchor()),
                            "[Remove][Edge] between %s and %s failed",
                            in_control_node->GetOutControlAnchor()->GetOwnerNode()->GetName().c_str(),
                            n->GetName().c_str());
          // ...and re-anchor it on each consumer of the const, preserving the
          // original execution-order constraint.
          for (auto &out_node : n->GetOutNodes()) {
            if (out_node == nullptr) {
              continue;
            }
            GE_CHK_STATUS_RET(ge::GraphUtils::AddEdge(in_control_node->GetOutControlAnchor(),
                                                      out_node->GetInControlAnchor()),
                              "[Add][Edge] between %s and %s failed.",
                              in_control_node->GetOutControlAnchor()->GetOwnerNode()->GetName().c_str(),
                              out_node->GetName().c_str());
          }
        }
      }
    }
  }
  GELOGD("Do ctrl edge transfer end!");
  return SUCCESS;
}
// Core pipeline: topo-sort the root graph, build per-node clusters, merge
// them, prune to the unique set, then build/combine the partition frames and
// materialize the partition subgraphs. Each stage aborts the pipeline on
// failure via REQUIRE_SUCCESS.
Status DynamicShapePartitioner::PartitionImpl() {
  REQUIRE_SUCCESS(root_graph_->TopologicalSorting(),
                  "[Call][TopologicalSorting] failed, graph:%s.", root_graph_->GetName().c_str());
  REQUIRE_SUCCESS(InitClusters(), "[Init][Clusters] failed, graph:%s.", root_graph_->GetName().c_str());
  REQUIRE_SUCCESS(MergeClusters(), "[Merge][Clusters] failed, graph:%s.", root_graph_->GetName().c_str());
  PruneUniqueClusters();
  REQUIRE_SUCCESS(BuildPartitionFrame(), "[Build][PartitionFrame] failed, graph:%s.", root_graph_->GetName().c_str());
  REQUIRE_SUCCESS(CombinePartitionFrame(),
                  "[Combine][PartitionFrame] failed, graph:%s.", root_graph_->GetName().c_str());
  REQUIRE_SUCCESS(BuildPartitionSubgraph(),
                  "[Build][PartitionSubgraph] failed, graph:%s.", root_graph_->GetName().c_str());
  return SUCCESS;
}
  149. void DynamicShapePartitioner::PruneUniqueClusters() {
  150. for (auto &node : root_graph_->GetDirectNode()) {
  151. auto cluster = node_2_cluster_[node];
  152. if (unique_clusters_.count(cluster) != 0) {
  153. continue;
  154. }
  155. if (unique_clusters_.insert(cluster).second) {
  156. sorted_unique_clusters_.emplace_back(cluster);
  157. }
  158. }
  159. auto comp_func = [](std::shared_ptr<Cluster> clu_a, std::shared_ptr<Cluster> clu_b) -> bool {
  160. return clu_a->Id() < clu_b->Id();
  161. };
  162. std::sort(sorted_unique_clusters_.begin(), sorted_unique_clusters_.end(), comp_func);
  163. }
  164. Status DynamicShapePartitioner::BuildPartitionFrame() {
  165. for (const auto &cluster : sorted_unique_clusters_) {
  166. REQUIRE_SUCCESS(cluster->BuildFrame(), "[Build][Frame] of cluster[%lu] failed.", cluster->Id());
  167. }
  168. return SUCCESS;
  169. }
  170. Status DynamicShapePartitioner::CombinePartitionFrame() {
  171. for (const auto &cluster : sorted_unique_clusters_) {
  172. REQUIRE_SUCCESS(cluster->CombinePartitionFrame(), "[Combine][Frame] of cluster[%lu] failed.", cluster->Id());
  173. }
  174. return SUCCESS;
  175. }
  176. Status DynamicShapePartitioner::BuildPartitionSubgraph() {
  177. for (const auto &cluster : sorted_unique_clusters_) {
  178. REQUIRE_SUCCESS(cluster->BuildPartitionSubgraph(), "[Build][SubGraph] of cluster[%lu] failed.", cluster->Id());
  179. }
  180. return SUCCESS;
  181. }
  182. std::string DynamicShapePartitioner::DebugString() const {
  183. size_t unknown = 0;
  184. size_t known = 0;
  185. size_t data = 0;
  186. size_t netoutput = 0;
  187. size_t is_inputnode = 0;
  188. size_t stage = 0;
  189. std::stringstream ss;
  190. ss << "All unknown shape nodes:" << std::endl;
  191. for (const auto &node : unknown_shape_nodes_) {
  192. ss << " [" << node->GetName() << "](" << node->GetType() << ")" << std::endl;
  193. }
  194. for (const auto &cluster : unique_clusters_) {
  195. if (cluster->IsUnknownShape()) {
  196. unknown++;
  197. } else if (cluster->IsKnownShape()) {
  198. known++;
  199. } else if (cluster->IsData()) {
  200. data++;
  201. } else if (cluster->IsNetOutput()) {
  202. netoutput++;
  203. } else if (cluster->IsInputNode()) {
  204. is_inputnode++;
  205. } else if (cluster->IsIndependent()) {
  206. stage++;
  207. }
  208. }
  209. ss << "All clusters:" << unique_clusters_.size() << ", data:" << data << ", known:" << known
  210. << ", unknown:" << unknown << ", netoutput:" << netoutput << ", is_inputnode:" << is_inputnode
  211. << ", stage:" << stage << std::endl;
  212. for (const auto &cluster : unique_clusters_) {
  213. ss << " " << cluster->DebugString() << std::endl;
  214. }
  215. return ss.str();
  216. }
  217. void DynamicShapePartitioner::DumpGraph(const std::string &suffix) {
  218. GraphUtils::DumpGEGraphToOnnx(*root_graph_, root_graph_->GetName() + suffix);
  219. for (const auto &sub_graph : root_graph_->GetAllSubgraphs()) {
  220. GraphUtils::DumpGEGraphToOnnx(*sub_graph, sub_graph->GetName() + suffix);
  221. }
  222. }
  223. void DynamicShapePartitioner::ClearResource() {
  224. for (const auto &cluster : unique_clusters_) {
  225. cluster->Clear();
  226. }
  227. node_2_cluster_.clear();
  228. ordered_cluster_.clear();
  229. unique_clusters_.clear();
  230. sorted_unique_clusters_.clear();
  231. unknown_shape_nodes_.clear();
  232. root_graph_.reset();
  233. }
  234. Status DynamicShapePartitioner::MarkUnknownShapeNodes() {
  235. for (auto &node : root_graph_->GetDirectNode()) {
  236. REQUIRE_SUCCESS(CollectSpreadUnknownShapeNodes(node),
  237. "[Call][CollectSpreadUnknownShapeNodes] for node:%s failed.", node->GetName().c_str());
  238. }
  239. return SUCCESS;
  240. }
// Creates one single-node Cluster per direct node of the root graph, typed by
// the node's role (DATA / INPUT_NODE / NETOUTPUT / STAGE / UNKNOWN_SHAPE /
// KNOWN_SHAPE), records control-flow groups, and wires cluster input edges.
Status DynamicShapePartitioner::InitClusters() {
  auto graph = root_graph_;
  size_t rank = 0;  // cluster ids follow the (already topological) node order
  for (const auto &node : graph->GetDirectNode()) {
    Cluster::Type type = Cluster::DATA;
    // A Const/Constant with no inputs at all counts as a graph input node.
    bool is_input = ((node->GetType() == CONSTANT) || (node->GetType() == CONSTANTOP)) && node->GetInNodes().empty();
    REQUIRE_NOT_NULL(node->GetOpDesc(), "[Get][OpDesc] op_desc is null, graph:%s", graph->GetName().c_str());
    if (node->GetType() == DATA) {
      type = Cluster::DATA;
    } else if (is_input) {
      type = Cluster::INPUT_NODE;
    } else if (node->GetType() == NETOUTPUT) {
      type = Cluster::NETOUTPUT;
    } else if ((node->GetType() == PARTITIONEDCALL) && (node->GetOpDesc()->HasAttr(ATTR_STAGE_LEVEL))) {
      // Staged partitioned calls become independent (STAGE) clusters.
      type = Cluster::STAGE;
    } else if (unknown_shape_nodes_.count(node) > 0) {
      type = Cluster::UNKNOWN_SHAPE;
    } else {
      type = Cluster::KNOWN_SHAPE;
    }
    auto cluster = MakeShared<Cluster>(rank++, type, node, this);
    REQUIRE_NOT_NULL(cluster, "[New][Memory] for cluster failed.");
    node_2_cluster_[node] = cluster;
    int64_t group_index = -1;
    // Control-flow ops tagged with a group index are collected so the whole
    // group can later be merged into one cluster (MergeClustersControlFlow).
    if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) {
      GELOGD("[%s] is rts control flow Op, group index: %ld", node->GetName().c_str(), group_index);
      auto &control_cluster = control_clusters_[group_index];
      control_cluster.emplace_back(cluster);
    }
    // Already sorted topologically, so access to the parent cluster is safe
    for (const auto &parent : node->GetInAllNodes()) {
      cluster->AddInput(node_2_cluster_[parent]);
    }
  }
  for (const auto &node : graph->GetDirectNode()) {
    GELOGD("Make cluster for node %s : %s.", node->GetName().c_str(), node_2_cluster_[node]->DebugString().c_str());
  }
  return SUCCESS;
}
  280. Status DynamicShapePartitioner::TopologicalSortClusters(const OrderedFilter &ordered_filter) {
  281. ordered_cluster_.clear();
  282. // BFS topological sort clusters for known shape cluster
  283. std::queue<ClusterPtr> ready_clusters;
  284. std::unordered_map<ClusterPtr, size_t> cluster_pending_count;
  285. std::unordered_set<ClusterPtr> seen_clusters;
  286. for (auto &node : root_graph_->GetDirectNode()) {
  287. auto &cluster = node_2_cluster_[node];
  288. if (seen_clusters.count(cluster) != 0) {
  289. continue;
  290. }
  291. seen_clusters.insert(cluster);
  292. auto pending_count = cluster->Inputs().size();
  293. if (pending_count == 0) {
  294. ready_clusters.push(cluster);
  295. } else {
  296. cluster_pending_count[cluster] = pending_count;
  297. }
  298. }
  299. size_t rank = 0;
  300. while (!ready_clusters.empty()) {
  301. auto cluster = ready_clusters.front();
  302. ready_clusters.pop();
  303. cluster->UpdateRank(rank++);
  304. if (ordered_filter == nullptr || ordered_filter(cluster)) {
  305. ordered_cluster_.push_back(cluster);
  306. }
  307. for (const auto &out_cluster : cluster->Outputs()) {
  308. if (cluster_pending_count[out_cluster] > 0 && --cluster_pending_count[out_cluster] == 0) {
  309. ready_clusters.push(out_cluster);
  310. }
  311. }
  312. }
  313. if (rank != seen_clusters.size()) {
  314. return FAILED;
  315. }
  316. return SUCCESS;
  317. }
  318. namespace {
  319. static std::string ToString(const std::vector<ClusterPtr> &clusters) {
  320. if (clusters.empty()) {
  321. return "()";
  322. }
  323. std::stringstream ss;
  324. ss << "(";
  325. auto iter = clusters.begin();
  326. for (size_t i = 0; i < clusters.size() - 1; i++) {
  327. ss << (*iter)->Id() << ",";
  328. iter++;
  329. }
  330. ss << (*iter)->Id() << ").";
  331. return ss.str();
  332. }
  333. }
// Merges every control-flow group (collected in InitClusters) into a single
// cluster: iterating each group in reverse, the group's last cluster absorbs
// all earlier ones via MergeAllPathFrom, so every cluster on a path between
// them lands in the same cluster. node_2_cluster_ is re-pointed accordingly.
void DynamicShapePartitioner::MergeClustersControlFlow() {
  std::unordered_set<ClusterPtr> all_merged_clusters;
  for (const auto &item : control_clusters_) {
    const auto &control_cluster = item.second;
    auto rit = control_cluster.rbegin();
    if (rit == control_cluster.rend()) {
      GELOGW("Invalid empty control flow cluster.");
      continue;
    }
    // The last cluster in the group is the merge target.
    const auto &cluster = *rit;
    if (all_merged_clusters.count(cluster) > 0) {
      // Already absorbed into another group's target; skip the whole group.
      continue;
    }
    for (++rit; rit != control_cluster.rend(); ++rit) {
      const auto &cluster_from = *rit;
      if (all_merged_clusters.count(cluster_from) > 0) {
        continue;
      }
      auto merged_clusters = cluster->MergeAllPathFrom(cluster_from);
      GELOGD("Merge all path cluster from %lu to %lu %s.", cluster_from->Id(), cluster->Id(),
             ToString(merged_clusters).c_str());
      // Every absorbed cluster's nodes now belong to the merge target.
      for (const auto &merged_cluster : merged_clusters) {
        all_merged_clusters.emplace(merged_cluster);
        for (const auto &node : merged_cluster->Nodes()) {
          node_2_cluster_[node] = cluster;
        }
      }
    }
  }
}
// Merges adjoining unknown-shape clusters (walking ordered_cluster_, which
// at this point holds unknown-shape clusters in topological order) so that
// connected unknown regions collapse into one cluster. Independent (STAGE)
// clusters are never merged.
void DynamicShapePartitioner::MergeClustersUnknownShape() {
  // Merge unknown shape clusters
  for (const auto &cluster : ordered_cluster_) {
    if (cluster->IsIndependent()) {
      continue;
    }
    for (const auto &in_cluster : cluster->Inputs()) {
      if (!in_cluster->IsUnknownShape()) {
        continue;
      }
      // Only merge clusters whose nodes are actually adjacent in the graph.
      if (!cluster->IsAdjoinNodes(in_cluster)) {
        continue;
      }
      // Absorb in_cluster plus every cluster on a path between the two.
      auto merged_clusters = cluster->MergeAllPathFrom(in_cluster);
      GELOGD("Merge all path cluster from %lu to %lu %s.", in_cluster->Id(), cluster->Id(),
             ToString(merged_clusters).c_str());
      for (const auto &merged_cluster : merged_clusters) {
        for (const auto &node : merged_cluster->Nodes()) {
          node_2_cluster_[node] = cluster;
        }
      }
    }
  }
}
// Merges known-shape clusters into their consumers where that cannot create
// a cycle (TryMerge). Ref-variable clusters with exactly one input are folded
// into that input unconditionally. Independent (STAGE) clusters are skipped.
void DynamicShapePartitioner::MergeClustersKnownShape() {
  // Merge known shape clusters
  for (const auto &cluster : ordered_cluster_) {
    if (cluster->IsIndependent()) {
      continue;
    }
    if (cluster->IsRefVariable() && cluster->Inputs().size() == 1) {
      // A ref variable must live with its single producer.
      auto in_cluster = *(cluster->Inputs().begin());
      in_cluster->Merge(cluster);
      node_2_cluster_[*(cluster->Nodes().begin())] = in_cluster;
      continue;
    }
    for (const auto &in_cluster : cluster->Inputs()) {
      if (!in_cluster->IsKnownShape()) {
        continue;
      }
      // TryMerge refuses merges that would make the cluster graph cyclic.
      if (cluster->TryMerge(in_cluster)) {
        GELOGD("Success merge known shape cluster from %lu to %lu.", in_cluster->Id(), cluster->Id());
        for (const auto &node : in_cluster->Nodes()) {
          node_2_cluster_[node] = cluster;
        }
      }
    }
  }
}
  413. void DynamicShapePartitioner::MergeClustersInputData() {
  414. // Merge input clusters
  415. std::shared_ptr<Cluster> cluster_pre = nullptr;
  416. for (const auto &cluster : ordered_cluster_) {
  417. if (!cluster->IsInputNode()) {
  418. continue;
  419. }
  420. if (cluster_pre != nullptr) {
  421. cluster_pre->Merge(cluster);
  422. } else {
  423. cluster_pre = cluster;
  424. }
  425. GELOGD("Success merge input node cluster from %lu to %lu.", cluster->Id(), cluster->Id());
  426. for (const auto &node : cluster->Nodes()) {
  427. node_2_cluster_[node] = cluster_pre;
  428. }
  429. }
  430. }
  431. Status DynamicShapePartitioner::MergeClusters() {
  432. const auto filter_known = [](const ClusterPtr &cluster) {
  433. return cluster->IsKnownShape() || cluster->IsInputNode();
  434. };
  435. const auto filter_unknown = [](const ClusterPtr &cluster) {
  436. return cluster->IsUnknownShape();
  437. };
  438. MergeClustersControlFlow();
  439. REQUIRE_SUCCESS(TopologicalSortClusters(filter_unknown),
  440. "[TopologicalSort][Clusters] after merge control flow clusters failed.");
  441. MergeClustersUnknownShape();
  442. REQUIRE_SUCCESS(TopologicalSortClusters(filter_known),
  443. "[TopologicalSort][Clusters] after merge unknown shape clusters failed.");
  444. MergeClustersKnownShape();
  445. MergeClustersInputData();
  446. return SUCCESS;
  447. }
  448. bool DynamicShapePartitioner::JudgeUnknowShapeWithAttr(const OpDescPtr &opdesc) {
  449. bool is_forced_unknown = false;
  450. if (AttrUtils::GetBool(opdesc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_forced_unknown) && is_forced_unknown) {
  451. GELOGD("Collect node %s as unknown as it was marked unknown forcibly.", opdesc->GetName().c_str());
  452. return true;
  453. }
  454. bool forced_unknown = false;
  455. if (AttrUtils::GetBool(opdesc, ATTR_NAME_FORCE_UNKNOWN_SHAPE, forced_unknown) && forced_unknown) {
  456. GELOGD("Collect node %s as unknown as it was marked force unknown node forcibly.", opdesc->GetName().c_str());
  457. return true;
  458. }
  459. return false;
  460. }
// Marks `node` as unknown shape when it carries a force-unknown attribute,
// when any input/output tensor has unknown shape, or when any of its
// subgraphs is unknown shape. Peers connected through an unknown tensor are
// spread into unknown_shape_nodes_ as well, since unknown-ness propagates
// across the data edge.
Status DynamicShapePartitioner::CollectSpreadUnknownShapeNodes(NodePtr node) {
  if (unknown_shape_nodes_.count(node) > 0) {
    return SUCCESS;  // already collected
  }
  auto opdesc = node->GetOpDesc();
  REQUIRE_NOT_NULL(opdesc, "[Get][OpDesc] Opdesc is nullptr.");
  // One can set 'ATTR_NAME_IS_UNKNOWN_SHAPE=true' on node so as to forcing the node flow into the unknown subgraph,
  // ignore the actual shape.
  if (JudgeUnknowShapeWithAttr(opdesc)) {
    unknown_shape_nodes_.insert(node);
    return SUCCESS;
  }
  size_t anchor_index = 0;
  bool is_unknown = false;
  // Outputs: an unknown output also taints every downstream consumer.
  for (auto &out_tensor : opdesc->GetAllOutputsDesc()) {
    if (IsUnknownShapeTensor(out_tensor)) {
      GELOGD("Collect node %s as unknown as output %lu is unknown.", node->GetName().c_str(), anchor_index);
      is_unknown = true;
      // NOTE(review): GetOutDataAnchor could return null if the anchor count
      // ever differs from the tensor-desc count — confirm that cannot happen.
      auto anchor = node->GetOutDataAnchor(static_cast<int>(anchor_index));
      for (const auto peer_anchor : anchor->GetPeerInDataAnchors()) {
        if (peer_anchor != nullptr) {
          GELOGD("Collect node %s as has unknown input from %s:%lu.", peer_anchor->GetOwnerNode()->GetName().c_str(),
                 node->GetName().c_str(), anchor_index);
          unknown_shape_nodes_.insert(peer_anchor->GetOwnerNode());
        }
      }
    }
    anchor_index++;
  }
  anchor_index = 0;
  // Inputs: an unknown input taints the upstream producer.
  for (auto &in_tensor : opdesc->GetAllInputsDesc()) {
    if (IsUnknownShapeTensor(in_tensor)) {
      GELOGD("Collect node %s as unknown as input %lu is unknown.", node->GetName().c_str(), anchor_index);
      is_unknown = true;
      auto anchor = node->GetInDataAnchor(static_cast<int>(anchor_index));
      const auto peer_anchor = anchor->GetPeerOutAnchor();
      if (peer_anchor != nullptr) {
        GELOGD("Collect node %s as has unknown output to %s:%lu.", peer_anchor->GetOwnerNode()->GetName().c_str(),
               node->GetName().c_str(), anchor_index);
        unknown_shape_nodes_.insert(peer_anchor->GetOwnerNode());
      }
    }
    anchor_index++;
  }
  if (is_unknown) {
    unknown_shape_nodes_.insert(node);
  } else {
    // Shapes look known: the node is still unknown if any subgraph it owns is.
    auto graph = root_graph_;
    for (const auto &subgraph_name : opdesc->GetSubgraphInstanceNames()) {
      auto subgraph = graph->GetSubgraph(subgraph_name);
      REQUIRE_NOT_NULL(subgraph, "[Get][Subgraph] %s of node %s on root graph failed.", subgraph_name.c_str(),
                       node->GetName().c_str());
      bool is_graph_unknow = false;
      REQUIRE_SUCCESS(IsUnknownShapeGraph(subgraph, is_graph_unknow),
                      "[Call][IsUnknownShapeGraph] Failed check subgraph %s shape of node %s.",
                      subgraph_name.c_str(), node->GetName().c_str());
      if (is_graph_unknow) {
        GELOGD("Collect node %s as its subgraph %s is unknown.", node->GetName().c_str(), subgraph->GetName().c_str());
        unknown_shape_nodes_.insert(node);
        break;
      }
    }
  }
  return SUCCESS;
}
  526. Status DynamicShapePartitioner::IsUnknownShapeNode(NodePtr node, bool &is_unknown) {
  527. auto opdesc = node->GetOpDesc();
  528. auto graph = root_graph_;
  529. for (auto &out_tensor : opdesc->GetAllOutputsDesc()) {
  530. if (IsUnknownShapeTensor(out_tensor)) {
  531. GELOGD("Mark node %s unknown as unknown output.", node->GetName().c_str());
  532. is_unknown = true;
  533. return SUCCESS;
  534. }
  535. }
  536. for (auto &in_tensor : opdesc->GetAllInputsDesc()) {
  537. if (IsUnknownShapeTensor(in_tensor)) {
  538. GELOGD("Mark node %s unknown as unknown intput.", node->GetName().c_str());
  539. is_unknown = true;
  540. return SUCCESS;
  541. }
  542. }
  543. for (auto &subgraph_name : opdesc->GetSubgraphInstanceNames()) {
  544. auto subgraph = graph->GetSubgraph(subgraph_name);
  545. REQUIRE_NOT_NULL(subgraph, "[Get][Subgraph] %s of node %s on root graph failed.", subgraph_name.c_str(),
  546. node->GetName().c_str());
  547. REQUIRE_SUCCESS(IsUnknownShapeGraph(subgraph, is_unknown),
  548. "[Call][IsUnknownShapeGraph] Failed check subgraph %s shape of node %s.",
  549. subgraph_name.c_str(), node->GetName().c_str());
  550. if (is_unknown) {
  551. GELOGD("Mark node %s unknown as unknown subgraph.", node->GetName().c_str());
  552. return SUCCESS;
  553. }
  554. }
  555. is_unknown = false;
  556. return SUCCESS;
  557. }
  558. Status DynamicShapePartitioner::IsUnknownShapeGraph(ComputeGraphPtr graph, bool &is_unknown) {
  559. for (auto &node : graph->GetDirectNode()) {
  560. REQUIRE_SUCCESS(IsUnknownShapeNode(node, is_unknown),
  561. "[Call][IsUnknownShapeNode]Failed check node %s shape on graph %s.",
  562. node->GetName().c_str(), graph->GetName().c_str());
  563. if (is_unknown) {
  564. GELOGD("Mark graph %s unknown as contains unknown node %s.", graph->GetName().c_str(), node->GetName().c_str());
  565. return SUCCESS;
  566. }
  567. }
  568. return SUCCESS;
  569. }
  570. std::string Cluster::DebugString() const {
  571. std::stringstream ss;
  572. switch (type_) {
  573. case DATA:
  574. ss << "DATA";
  575. break;
  576. case INPUT_NODE:
  577. ss << "INPUT_NODE";
  578. break;
  579. case NETOUTPUT:
  580. ss << "NETOUTPUT";
  581. break;
  582. case UNKNOWN_SHAPE:
  583. ss << "UNKNOW";
  584. break;
  585. case KNOWN_SHAPE:
  586. ss << "KNOW";
  587. break;
  588. default:
  589. break;
  590. }
  591. ss << "[" << id_ << "](size:" << nodes_.size() << ")";
  592. ss << "(" << min_ << "," << max_ << ")(";
  593. for (const auto &cluster : in_clusters_) {
  594. ss << cluster->id_ << ",";
  595. }
  596. ss << ")->(";
  597. for (const auto &cluster : out_clusters_) {
  598. ss << cluster->id_ << ",";
  599. }
  600. ss << ")|";
  601. for (const auto &node : nodes_) {
  602. ss << (node->GetName() + "|");
  603. }
  604. return ss.str();
  605. }
  606. size_t Cluster::Id() const { return id_; }
  607. void Cluster::UpdateRank(size_t rank) {
  608. max_ = rank;
  609. min_ = rank;
  610. };
  611. bool Cluster::IsData() const { return type_ == DATA; };
  612. bool Cluster::IsKnownShape() const { return type_ == KNOWN_SHAPE; };
  613. bool Cluster::IsUnknownShape() const { return type_ == UNKNOWN_SHAPE; };
  614. bool Cluster::IsIndependent() const { return type_ == STAGE; };
  615. bool Cluster::IsNetOutput() const { return type_ == NETOUTPUT; };
  616. bool Cluster::IsInputNode() const { return type_ == INPUT_NODE; };
  617. bool Cluster::IsRefVariable() const {
  618. if ((nodes_.size() == 1) && ((nodes_[0]->GetType() == VARIABLE) || (nodes_[0]->GetType() == VARIABLEV2))) {
  619. std::string ref_variable_name;
  620. return (AttrUtils::GetStr(nodes_[0]->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_variable_name) &&
  621. !ref_variable_name.empty());
  622. }
  623. return false;
  624. }
  625. void Cluster::AddInput(ClusterPtr in) {
  626. if (std::find(in_clusters_.begin(), in_clusters_.end(), in) != in_clusters_.end()) return;
  627. in_clusters_.insert(in_clusters_.end(), in);
  628. if (std::find(in->out_clusters_.begin(), in->out_clusters_.end(), shared_from_this()) != in->out_clusters_.end())
  629. return;
  630. in->out_clusters_.insert(in->out_clusters_.end(), shared_from_this());
  631. };
  632. void Cluster::RemoveInput(ClusterPtr in) {
  633. in_clusters_.erase(std::remove(in_clusters_.begin(), in_clusters_.end(), in), in_clusters_.end());
  634. in->out_clusters_.erase(std::remove(in->out_clusters_.begin(), in->out_clusters_.end(), shared_from_this()),
  635. in->out_clusters_.end());
  636. };
  637. void Cluster::AddOutput(ClusterPtr out) {
  638. if (std::find(out_clusters_.begin(), out_clusters_.end(), out) != out_clusters_.end()) return;
  639. out_clusters_.insert(out_clusters_.end(), out);
  640. if (std::find(out->in_clusters_.begin(), out->in_clusters_.end(), shared_from_this()) != out->in_clusters_.end())
  641. return;
  642. out->in_clusters_.insert(out->in_clusters_.end(), shared_from_this());
  643. };
  644. void Cluster::RemoveOutput(ClusterPtr out) {
  645. out_clusters_.erase(std::remove(out_clusters_.begin(), out_clusters_.end(), out), out_clusters_.end());
  646. out->in_clusters_.erase(std::remove(out->in_clusters_.begin(), out->in_clusters_.end(), shared_from_this()),
  647. out->in_clusters_.end());
  648. };
// Absorbs `other` into this cluster: takes over its nodes, rewires every
// in/out adjacency of `other` to this cluster, widens the rank window, and
// upgrades type_ to UNKNOWN_SHAPE if `other` was unknown. Independent
// (STAGE) clusters are never merged.
void Cluster::Merge(ClusterPtr other) {
  if (other->IsIndependent()) {
    return;
  }
  nodes_.insert(nodes_.end(), other->nodes_.begin(), other->nodes_.end());
  // Drop direct edges between the two clusters first so the rewiring below
  // cannot create self references.
  other->in_clusters_.erase(std::remove(other->in_clusters_.begin(), other->in_clusters_.end(), shared_from_this()),
                            other->in_clusters_.end());
  other->out_clusters_.erase(std::remove(other->out_clusters_.begin(), other->out_clusters_.end(), shared_from_this()),
                             other->out_clusters_.end());
  in_clusters_.erase(std::remove(in_clusters_.begin(), in_clusters_.end(), other), in_clusters_.end());
  out_clusters_.erase(std::remove(out_clusters_.begin(), out_clusters_.end(), other), out_clusters_.end());
  // Copy the adjacency list: RemoveOutput mutates other's lists while we
  // iterate over the neighbors.
  auto in_clusters = other->in_clusters_;
  for (const auto &cluster : in_clusters) {
    cluster->RemoveOutput(other);
    cluster->AddOutput(shared_from_this());
  }
  auto out_clusters = other->out_clusters_;
  for (const auto &cluster : out_clusters) {
    cluster->RemoveInput(other);
    cluster->AddInput(shared_from_this());
  }
  // The merged cluster spans both rank windows.
  if (other->max_ > max_) {
    max_ = other->max_;
  }
  if (other->min_ < min_) {
    min_ = other->min_;
  }
  // Unknown shape is contagious: absorbing an unknown cluster makes the
  // result unknown.
  if (!IsUnknownShape() && other->IsUnknownShape()) {
    type_ = UNKNOWN_SHAPE;
  }
}
  680. bool Cluster::TryMerge(ClusterPtr other) {
  681. std::queue<ClusterPtr> forward_reached;
  682. forward_reached.push(other);
  683. while (!forward_reached.empty()) {
  684. auto current_cluster = forward_reached.front();
  685. forward_reached.pop();
  686. for (const auto &cluster : current_cluster->out_clusters_) {
  687. if (cluster->max_ == max_ && current_cluster != other) {
  688. return false;
  689. } else if (cluster->min_ < max_) {
  690. forward_reached.push(cluster);
  691. }
  692. }
  693. }
  694. Merge(other);
  695. return true;
  696. };
// Merge `other` and every cluster lying on a path between this cluster and
// `other` into this cluster. A cluster is "on a path" when it is reachable
// forward from `other` AND backward from this cluster (the intersection of
// the two BFS passes below). Returns the list of merged clusters (always
// including `other`), or an empty list when `other` is independent and
// nothing is merged.
std::vector<ClusterPtr> Cluster::MergeAllPathFrom(ClusterPtr other) {
  std::queue<ClusterPtr> forward_reached_queue;
  std::queue<ClusterPtr> backward_reached_queue;
  std::unordered_set<ClusterPtr> forward_reached_clusters;
  std::unordered_set<ClusterPtr> backward_reached_clusters;
  std::vector<ClusterPtr> path_clusters;
  if (other->IsIndependent()) {
    return path_clusters;
  }
  path_clusters.push_back(other);
  forward_reached_queue.push(other);
  backward_reached_queue.push(shared_from_this());
  // Forward BFS from `other`, pruned by topological rank: only clusters whose
  // interval can overlap this cluster's are explored.
  while (!forward_reached_queue.empty()) {
    auto current_cluster = forward_reached_queue.front();
    forward_reached_queue.pop();
    for (const auto &cluster : current_cluster->out_clusters_) {
      if (cluster->min_ < max_ && cluster->max_ != max_ && forward_reached_clusters.count(cluster) == 0) {
        forward_reached_clusters.insert(cluster);
        forward_reached_queue.push(cluster);
      }
    }
  }
  // Backward BFS from this cluster; a cluster also seen by the forward pass
  // lies on a path from `other` to this cluster and is collected for merging.
  while (!backward_reached_queue.empty()) {
    auto current_cluster = backward_reached_queue.front();
    backward_reached_queue.pop();
    for (const auto &cluster : current_cluster->in_clusters_) {
      if (cluster->max_ > other->min_ && cluster->max_ != other->max_ &&
          backward_reached_clusters.count(cluster) == 0) {
        backward_reached_clusters.insert(cluster);
        backward_reached_queue.push(cluster);
        if (forward_reached_clusters.count(cluster) != 0) {
          path_clusters.push_back(cluster);
        }
      }
    }
  }
  // Absorb every path cluster; Merge() rewires edges and widens [min_, max_].
  for (const auto &cluster : path_clusters) {
    Merge(cluster);
  }
  return path_clusters;
}
// Adjacency/content accessors. Each returns a copy of the underlying vector,
// so callers may safely mutate the cluster while iterating the snapshot.
std::vector<ClusterPtr> Cluster::Inputs() const { return in_clusters_; };
std::vector<ClusterPtr> Cluster::Outputs() const { return out_clusters_; };
std::vector<NodePtr> Cluster::Nodes() const { return nodes_; };
  741. void Cluster::AddFrameInput(InDataAnchorPtr anchor) {
  742. if (anchor != nullptr && anchor->GetPeerOutAnchor() != nullptr) {
  743. inputs_index_[anchor] = inputs_.size();
  744. inputs_.push_back(anchor);
  745. }
  746. }
  747. void Cluster::AddFrameOutput(OutDataAnchorPtr anchor) {
  748. if (anchor != nullptr) {
  749. outputs_index_[anchor] = outputs_.size();
  750. outputs_.push_back(anchor);
  751. }
  752. }
// Map an anchor of an inner node to the corresponding anchor on the frame
// (PartitionedCall) node, using the index recorded by AddFrameInput/Output.
// NOTE(review): the lookup uses operator[], so an anchor that was never
// registered silently maps to index 0 — callers must only pass anchors
// previously recorded via AddFrameInput/AddFrameOutput.
InDataAnchorPtr Cluster::GetFrameInDataAnchor(InDataAnchorPtr anchor) {
  return partition_node_->GetInDataAnchor(static_cast<int>(inputs_index_[anchor]));
}
OutDataAnchorPtr Cluster::GetFrameOutDataAnchor(OutDataAnchorPtr anchor) {
  return partition_node_->GetOutDataAnchor(static_cast<int>(outputs_index_[anchor]));
}
// Control anchors of the frame node are returned directly (no index mapping).
InControlAnchorPtr Cluster::GetFrameInControlAnchor() { return partition_node_->GetInControlAnchor(); };
OutControlAnchorPtr Cluster::GetFrameOutControlAnchor() { return partition_node_->GetOutControlAnchor(); };
// Build the frame: the node representing this cluster in the root graph.
// Known/unknown-shape and input clusters get a PartitionedCall frame via
// BuildPartitionFrame(); the remaining cluster kinds (data, netoutput,
// independent) reuse their single node directly, after converting
// cross-cluster control edges into cluster-level control dependencies.
Status Cluster::BuildFrame() {
  if (IsUnknownShape() || IsKnownShape() || IsInputNode()) {
    return BuildPartitionFrame();
  } else {
    // Non-partition clusters are handled through their first node only;
    // presumably such clusters hold exactly one node — TODO confirm.
    auto node = nodes_.front();
    auto in_control_anchor = node->GetInControlAnchor();
    if (in_control_anchor != nullptr) {
      for (const auto &peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) {
        auto src_cluster = partitioner_->node_2_cluster_[peer_out_control_anchor->GetOwnerNode()];
        if (src_cluster->id_ != id_) {
          // Cross-cluster control edge: remove it here and record it so
          // CombinePartitionFrame() can re-create it between frame nodes.
          REQUIRE_GRAPH_SUCCESS(
              GraphUtils::RemoveEdge(peer_out_control_anchor, in_control_anchor),
              "[Remove][Edge] from node %s index %d to node %s failed, index %d.",
              peer_out_control_anchor->GetOwnerNode()->GetName().c_str(), AnchorUtils::GetIdx(peer_out_control_anchor),
              in_control_anchor->GetOwnerNode()->GetName().c_str(), AnchorUtils::GetIdx(in_control_anchor));
          control_inputs_.insert(src_cluster);
          src_cluster->control_outputs_.insert(peer_out_control_anchor);
        }
      }
    }
    if (IsData() || IsIndependent()) {
      // Data-like clusters expose all output anchors as frame outputs.
      for (const auto &anchor : node->GetAllOutDataAnchors()) {
        AddFrameOutput(anchor);
      }
    } else {
      // Netoutput-like clusters expose all input anchors as frame inputs.
      for (const auto &anchor : node->GetAllInDataAnchors()) {
        AddFrameInput(anchor);
      }
    }
    partition_node_ = node;
  }
  return SUCCESS;
}
// Build a PartitionedCall frame for this cluster: create a fresh subgraph,
// move every cluster node from the root graph into it, collect cross-cluster
// data edges as frame inputs/outputs, convert cross-cluster control edges
// into cluster-level control dependencies, then insert the PartitionedCall
// node into the root graph and register the subgraph on it.
Status Cluster::BuildPartitionFrame() {
  auto graph = partitioner_->root_graph_;
  bool is_unknown_shape = IsUnknownShape();
  bool is_input = IsInputNode();
  // NOTE(review): "_unknow"/"_know" are misspelled but historical name
  // suffixes; fixing the spelling would change generated graph names, so they
  // are kept as-is.
  string known_name = (is_unknown_shape ? "_unknow" : "_know");
  string sub_graph_name_patten = (is_input ? "_input" : known_name);
  std::string sub_graph_name = graph->GetName() + "_sub_" + std::to_string(unique_id_) + sub_graph_name_patten;
  subgraph_ = MakeShared<ComputeGraph>(sub_graph_name);
  REQUIRE_NOT_NULL(subgraph_, "[New][Memory] for subgraph failed, name:%s.", sub_graph_name.c_str());
  auto partition_op = MakeShared<OpDesc>("PartitionedCall_" + std::to_string(unique_id_++), "PartitionedCall");
  REQUIRE_NOT_NULL(partition_op, "[New][Memory] for partition op failed.");
  REQUIRE(AttrUtils::SetBool(partition_op, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape),
          "[Set][Attr] _is_unknown_shape flag on partitioned op %s failed.", partition_op->GetName().c_str());
  REQUIRE_GRAPH_SUCCESS(partition_op->AddSubgraphName(subgraph_->GetName()),
                        "[Add][SubgraphName] %s for op:%s.",
                        subgraph_->GetName().c_str(), partition_op->GetName().c_str());
  REQUIRE_GRAPH_SUCCESS(partition_op->SetSubgraphInstanceName(0, subgraph_->GetName()),
                        "[Call][SetSubgraphInstanceName] for op:%s failed, index:0, name:%s.",
                        partition_op->GetName().c_str(), subgraph_->GetName().c_str());
  for (auto &node : nodes_) {
    // Move the node from the root graph into the subgraph.
    REQUIRE_NOT_NULL(subgraph_->AddNode(node),
                     "[Add][Node] %s to subgraph:%s failed.", node->GetName().c_str(), subgraph_->GetName().c_str());
    REQUIRE(AttrUtils::SetBool(node->GetOpDesc(), ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape),
            "[Set][Attr] %s to op:%s failed.", ATTR_NAME_IS_UNKNOWN_SHAPE.c_str(), node->GetName().c_str());
    REQUIRE_GRAPH_SUCCESS(GraphUtils::RemoveJustNode(graph, node),
                          "[Remove][JustNode] failed, graph:%s, node:%s.",
                          graph->GetName().c_str(), node->GetName().c_str());
    REQUIRE_GRAPH_SUCCESS(node->SetOwnerComputeGraph(subgraph_),
                          "[Set][OwnerComputeGraph] %s for node:%s failed.",
                          subgraph_->GetName().c_str(), node->GetName().c_str());
    // Every data input fed from another cluster becomes a frame input and an
    // input desc on the PartitionedCall op.
    for (const auto &anchor : node->GetAllInDataAnchors()) {
      auto peer_out_anchor = anchor->GetPeerOutAnchor();
      if (peer_out_anchor == nullptr) {
        continue;  // Skip overhang input.
      }
      auto src_cluster = partitioner_->node_2_cluster_[peer_out_anchor->GetOwnerNode()];
      if (src_cluster->id_ != id_) {
        AddFrameInput(anchor);
        REQUIRE_GRAPH_SUCCESS(partition_op->AddInputDesc(node->GetOpDesc()->GetInputDesc(anchor->GetIdx())),
                              "[Add][InputDesc] to op:%s failed.", partition_op->GetName().c_str());
      }
    }
    // Cross-cluster control edges are removed and recorded; they are restored
    // between frame nodes by CombinePartitionFrame().
    auto in_control_anchor = node->GetInControlAnchor();
    if (in_control_anchor != nullptr) {
      for (const auto &peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) {
        if (peer_out_control_anchor == nullptr) {
          continue;
        }
        auto src_cluster = partitioner_->node_2_cluster_[peer_out_control_anchor->GetOwnerNode()];
        if (src_cluster->id_ != id_) {
          REQUIRE_GRAPH_SUCCESS(
              GraphUtils::RemoveEdge(peer_out_control_anchor, in_control_anchor),
              "[Remove][Edge] from %s:%d to %s:%d failed.", peer_out_control_anchor->GetOwnerNode()->GetName().c_str(),
              peer_out_control_anchor->GetIdx(), node->GetName().c_str(), in_control_anchor->GetIdx());
          control_inputs_.insert(src_cluster);
          src_cluster->control_outputs_.insert(peer_out_control_anchor);
        }
      }
    }
    // An output anchor consumed by any other cluster becomes a frame output;
    // the break ensures each anchor is added at most once.
    for (const auto &anchor : node->GetAllOutDataAnchors()) {
      auto peer_in_anchors = anchor->GetPeerInDataAnchors();
      for (const auto &peer_in_anchor : peer_in_anchors) {
        auto src_cluster = partitioner_->node_2_cluster_[peer_in_anchor->GetOwnerNode()];
        if (src_cluster->id_ != id_) {
          AddFrameOutput(anchor);
          REQUIRE_GRAPH_SUCCESS(partition_op->AddOutputDesc(node->GetOpDesc()->GetOutputDesc(anchor->GetIdx())),
                                "[Add][OutputDesc] to op:%s failed.", partition_op->GetName().c_str());
          break;
        }
      }
    }
  }
  // Insert the PartitionedCall node into the root graph and attach the
  // subgraph to it as a child graph.
  partition_node_ = graph->AddNode(partition_op);
  REQUIRE_NOT_NULL(partition_node_,
                   "[Add][Node] %s to graph:%s failed.", partition_op->GetName().c_str(), graph->GetName().c_str());
  REQUIRE_GRAPH_SUCCESS(partition_node_->SetOwnerComputeGraph(graph),
                        "[Set][OwnerComputeGraph] %s for node:%s failed.",
                        graph->GetName().c_str(), partition_op->GetName().c_str());
  subgraph_->SetParentNode(partition_node_);
  subgraph_->SetParentGraph(graph);
  REQUIRE_GRAPH_SUCCESS(graph->AddSubgraph(subgraph_),
                        "[Add][Subgraph] %s to root graph:%s failed.",
                        subgraph_->GetName().c_str(), graph->GetName().c_str());
  // Propagate the session graph id from the root graph onto the subgraph.
  std::string session_graph_id;
  REQUIRE(AttrUtils::GetStr(*graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id),
          "[Get][Attr] %s on root graph:%s failed.", ATTR_NAME_SESSION_GRAPH_ID.c_str(), graph->GetName().c_str());
  REQUIRE(AttrUtils::SetStr(*subgraph_, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id),
          "[Set][Attr] %s on subgraph:%s failed.", ATTR_NAME_SESSION_GRAPH_ID.c_str(), subgraph_->GetName().c_str());
  return SUCCESS;
}
// Re-connect the frames built by BuildFrame(): for every recorded frame
// input, replace the original inner-node edge with an edge from the source
// cluster's frame output anchor to this cluster's frame input anchor, then
// restore the recorded cross-cluster control dependencies as control edges
// between frame nodes.
Status Cluster::CombinePartitionFrame() {
  for (const auto &anchor : inputs_) {
    auto peer_out_anchor = anchor->GetPeerOutAnchor();
    auto src_cluster = partitioner_->node_2_cluster_[peer_out_anchor->GetOwnerNode()];
    auto src_anchor = src_cluster->GetFrameOutDataAnchor(peer_out_anchor);
    auto dst_anchor = GetFrameInDataAnchor(anchor);
    // Drop the old inner-node edge before wiring the frame-level edge.
    REQUIRE_GRAPH_SUCCESS(GraphUtils::RemoveEdge(peer_out_anchor, anchor), "[Remove][Edge] from %s:%d to %s:%d fail.",
                          peer_out_anchor->GetOwnerNode()->GetName().c_str(), peer_out_anchor->GetIdx(),
                          anchor->GetOwnerNode()->GetName().c_str(), anchor->GetIdx());
    REQUIRE_GRAPH_SUCCESS(GraphUtils::AddEdge(src_anchor, dst_anchor), "[Add][Edge] from %s:%d to %s:%d failed.",
                          src_anchor->GetOwnerNode()->GetName().c_str(), src_anchor->GetIdx(),
                          dst_anchor->GetOwnerNode()->GetName().c_str(), dst_anchor->GetIdx());
  }
  // Control dependencies recorded by BuildFrame()/BuildPartitionFrame().
  for (const auto &src_cluster : control_inputs_) {
    auto src_anchor = src_cluster->GetFrameOutControlAnchor();
    auto dst_anchor = GetFrameInControlAnchor();
    REQUIRE_GRAPH_SUCCESS(GraphUtils::AddEdge(src_anchor, dst_anchor), "[Add][Edge] from %s:%d to %s:%d failed.",
                          src_anchor->GetOwnerNode()->GetName().c_str(), src_anchor->GetIdx(),
                          dst_anchor->GetOwnerNode()->GetName().c_str(), dst_anchor->GetIdx());
  }
  return SUCCESS;
}
// Populate the subgraph created by BuildPartitionFrame(): add one Data node
// per frame input (tagged with its parent-node index) and one NetOutput node
// wired to every frame data/control output. Data, NetOutput, and independent
// clusters own no subgraph and return immediately.
Status Cluster::BuildPartitionSubgraph() {
  if (IsData() || IsNetOutput() || IsIndependent()) {
    return SUCCESS;
  }
  int64_t parent_node_index = 0;
  for (auto anchor : inputs_) {
    auto data_op =
        MakeShared<OpDesc>(subgraph_->GetName() + std::string("Data_") + std::to_string(parent_node_index), ge::DATA);
    REQUIRE_NOT_NULL(data_op, "[New][Memory] for data op failed.");
    // The Data node mirrors the consumer's input desc on both its input and
    // output side.
    auto input_desc = anchor->GetOwnerNode()->GetOpDesc()->GetInputDesc(anchor->GetIdx());
    REQUIRE_GRAPH_SUCCESS(data_op->AddInputDesc(input_desc),
                          "[Add][InputDesc] to op:%s failed.", data_op->GetName().c_str());
    REQUIRE_GRAPH_SUCCESS(data_op->AddOutputDesc(input_desc),
                          "[Add][OutputDesc] to op:%s failed.", data_op->GetName().c_str());
    // parent_node_index ties this Data node to the frame input of the same
    // rank on the PartitionedCall node.
    REQUIRE(AttrUtils::SetInt(data_op, ATTR_NAME_PARENT_NODE_INDEX, parent_node_index),
            "[Set][Attr] %s on subgraph data node:%s failed.",
            ATTR_NAME_PARENT_NODE_INDEX.c_str(), data_op->GetName().c_str());
    bool is_unknown_shape = IsUnknownShape();
    REQUIRE(AttrUtils::SetBool(data_op, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape),
            "[Set][Attr] %s on data op %s failed.", ATTR_NAME_IS_UNKNOWN_SHAPE.c_str(), data_op->GetName().c_str());
    auto data_node = subgraph_->AddNode(data_op);
    REQUIRE_NOT_NULL(data_node,
                     "[Add][Node] %s to subgraph:%s failed.", data_op->GetName().c_str(), subgraph_->GetName().c_str());
    REQUIRE_GRAPH_SUCCESS(data_node->SetOwnerComputeGraph(subgraph_),
                          "[Set][OwnerGraph] %s of data node:%s failed.",
                          subgraph_->GetName().c_str(), data_op->GetName().c_str());
    REQUIRE_GRAPH_SUCCESS(GraphUtils::AddEdge(data_node->GetOutDataAnchor(0), anchor),
                          "[Call][AddEdge] Failed add data input edge to %s:%d",
                          anchor->GetOwnerNode()->GetName().c_str(), anchor->GetIdx());
    parent_node_index++;
  }
  // A NetOutput node is only needed when the cluster produces outputs.
  if (outputs_.empty() && control_outputs_.empty()) {
    return SUCCESS;
  }
  auto net_output_op = MakeShared<OpDesc>(subgraph_->GetName() + "_" + NODE_NAME_NET_OUTPUT, ge::NETOUTPUT);
  REQUIRE_NOT_NULL(net_output_op, "[New][Memory] for netoutput op failed.");
  bool is_unknown_shape = IsUnknownShape();
  REQUIRE(AttrUtils::SetBool(net_output_op, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape),
          "[Set][Attr] %s on op:%s failed.", ATTR_NAME_IS_UNKNOWN_SHAPE.c_str(), net_output_op->GetName().c_str());
  // Reserve one placeholder input desc per frame output; the real descs are
  // filled in via UpdateInputDesc below.
  for (size_t i = 0; i < outputs_.size(); ++i) {
    GeTensorDesc input_desc;
    REQUIRE_GRAPH_SUCCESS(net_output_op->AddInputDesc(input_desc),
                          "[Add][InputDesc] to op:%s failed.", net_output_op->GetName().c_str());
  }
  auto net_output_node = subgraph_->AddNode(net_output_op);
  REQUIRE_NOT_NULL(net_output_node,
                   "[Call][AddNode] Failed add netoutput node:%s to subgraph:%s.",
                   net_output_op->GetName().c_str(), subgraph_->GetName().c_str());
  REQUIRE_GRAPH_SUCCESS(net_output_node->SetOwnerComputeGraph(subgraph_),
                        "[Set][OwnerGraph] %s of netoutput node:%s failed.",
                        subgraph_->GetName().c_str(), net_output_node->GetName().c_str());
  // Wire every frame data output into the NetOutput, tagging each input desc
  // with the index of the corresponding PartitionedCall output.
  parent_node_index = 0;
  for (const auto &anchor : outputs_) {
    auto output_desc = anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(static_cast<uint32_t>(anchor->GetIdx()));
    REQUIRE(AttrUtils::SetInt(output_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_node_index),
            "[Set][Attr] parent_node_index on subgraph node:%s netoutput's input failed.",
            anchor->GetOwnerNode()->GetName().c_str());
    REQUIRE_GRAPH_SUCCESS(net_output_op->UpdateInputDesc(parent_node_index, output_desc),
                          "[Update][InputDesc] of netoutput node:%s failed.", net_output_op->GetName().c_str());
    REQUIRE_GRAPH_SUCCESS(GraphUtils::AddEdge(anchor, net_output_node->GetInDataAnchor(parent_node_index)),
                          "[Add][Edge] from %s:%d to netoutput node:%s failed.",
                          anchor->GetOwnerNode()->GetName().c_str(), anchor->GetIdx(),
                          net_output_op->GetName().c_str());
    parent_node_index++;
  }
  // Control outputs only need a control edge into the NetOutput.
  for (const auto &anchor : control_outputs_) {
    REQUIRE_GRAPH_SUCCESS(GraphUtils::AddEdge(anchor, net_output_node->GetInControlAnchor()),
                          "[Add][ControlEdge] from %s:%d to netoutput node:%s failed.",
                          anchor->GetOwnerNode()->GetName().c_str(), anchor->GetIdx(),
                          net_output_op->GetName().c_str());
  }
  return SUCCESS;
}
// Drop every reference this cluster holds (neighbor clusters, nodes, frame
// anchors, subgraph) — presumably to break shared_ptr cycles between
// clusters — and reset the thread-local name counter for the next run.
void Cluster::Clear() {
  in_clusters_.clear();
  out_clusters_.clear();
  nodes_.clear();
  partitioner_ = nullptr;
  inputs_index_.clear();
  outputs_index_.clear();
  inputs_.clear();
  outputs_.clear();
  control_inputs_.clear();
  control_outputs_.clear();
  partition_node_.reset();
  subgraph_.reset();
  unique_id_ = 0;  // static thread_local: resets naming for all clusters on this thread
}
// Thread-local counter used to generate unique subgraph / PartitionedCall
// names in BuildPartitionFrame(); reset by Cluster::Clear().
thread_local size_t Cluster::unique_id_ = 0;
  995. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图可参见项目文档。