You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

allreduce_fusion_pass.cc 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/optimize/optimizer/allreduce_fusion_pass.h"
  17. #include <string>
  18. #include "common/debug/log.h"
  19. #include "framework/common/debug/ge_log.h"
  20. #include "common/types.h"
  21. #include "common/util.h"
  22. #include "graph/anchor.h"
  23. #include "graph/node.h"
  24. #include "graph/op_desc.h"
  25. #include "graph/utils/attr_utils.h"
  26. #include "graph/utils/graph_utils.h"
  27. #include "graph/utils/tensor_utils.h"
  28. #include "graph/debug/ge_attr_define.h"
  29. #include "hccl/base.h"
  30. #include "hccl/hcom.h"
  31. namespace ge {
  32. Status AllReducePass::Run(ge::ComputeGraphPtr graph) {
  33. GELOGI("FusionAllReducePass: start");
  34. std::vector<NodePtr> fusionOps;
  35. std::vector<float> inputGradientSize;
  36. std::vector<float> inputGradientTime;
  37. static const float inputGradientSizeTemp = 0.0;
  38. static const float inputGradientTimeTemp = 0.0;
  39. // Get all nodes
  40. for (auto nodePtr : graph->GetDirectNode()) {
  41. GE_IF_BOOL_EXEC(nullptr == nodePtr, GELOGW("FusionAllReducePass: null node exists"); continue;);
  42. ge::OpDescPtr opDescPtr = nodePtr->GetOpDesc();
  43. GE_IF_BOOL_EXEC(nullptr == opDescPtr,
  44. GELOGW("FusionAllReducePass: desc of node %s is null", nodePtr->GetName().c_str());
  45. continue;)
  46. GE_IF_BOOL_EXEC(HCOMALLREDUCE == opDescPtr->GetType(),
  47. // the op is allreduce and fusion > 0, then run fusion
  48. std::int64_t hcom_fusion = 1;
  49. GE_IF_BOOL_EXEC(!ge::AttrUtils::GetInt(opDescPtr, HCOM_ATTR_FUSION, hcom_fusion),
  50. GELOGW("FusionAllReducePass: not get hcom_fusion from opDescPtr "
  51. "by HCOM_ATTR_FUSION"));
  52. GELOGI("after GetInt, hcom_fusion is :%ld", hcom_fusion); GE_IF_BOOL_EXEC(
  53. hcom_fusion > 0, fusionOps.push_back(nodePtr); inputGradientSize.push_back(inputGradientSizeTemp);
  54. inputGradientTime.push_back(inputGradientTimeTemp);))
  55. }
  56. // The number of allredecue operator must be more than 1
  57. GE_IF_BOOL_EXEC(1 >= fusionOps.size(), GELOGW("FusionAllReducePass NOT_CHANGED: the graph has "
  58. "%lu allreduce operator",
  59. fusionOps.size());
  60. return NOT_CHANGED;);
  61. string group = "group";
  62. u32 gradientNum = fusionOps.size();
  63. string model_name_str = graph->GetName();
  64. const char *model_name = model_name_str.c_str();
  65. model_feature modelFeature{model_name, gradientNum, inputGradientSize.data(), inputGradientTime.data()};
  66. u32 segmentNum = 0;
  67. u32 segmentIndex[HCCL_MAX_SEGMENT_NUM] = {};
  68. // Call HCCL function: hcom_gradient_segment
  69. GELOGI("FusionAllReducePass: invoking hcom_get_split_strategy");
  70. GE_IF_BOOL_EXEC(HCCL_SUCCESS != hcom_get_split_strategy(group.c_str(), &modelFeature, HCCL_MAX_SEGMENT_NUM,
  71. &segmentNum, segmentIndex),
  72. GELOGE(FAILED, "FusionAllReducePass FAILED: the graph has %lu allreduce operator", fusionOps.size());
  73. return FAILED;)
  74. GELOGI("FusionAllReducePass: invoke hcom_get_split_strategy successfully");
  75. // check whether segmentNum is legal or not
  76. GE_IF_BOOL_EXEC((HCCL_MAX_SEGMENT_NUM < segmentNum || 1 > segmentNum || segmentNum > gradientNum),
  77. GELOGE(FAILED,
  78. "FusionAllReducePass FAILED: illegal segmentNum=%u, "
  79. "HCCL_MAX_SEGMENT_NUM=%u, gradientNum=%u",
  80. segmentNum, HCCL_MAX_SEGMENT_NUM, gradientNum);
  81. return FAILED;);
  82. // check whether segmentIndex is legal or not
  83. GE_IF_BOOL_EXEC((segmentIndex[segmentNum - 1] != gradientNum - 1),
  84. GELOGE(FAILED,
  85. "FusionAllReducePass FAILED: illegal segmentIndex[0]=%u, "
  86. "segmentIndex[segmentNum-1]=%u, gradientNum=%u",
  87. segmentIndex[0], segmentIndex[(segmentNum)-1], gradientNum);
  88. return FAILED;);
  89. for (uint32_t i = 0; i < segmentNum - 1; i++) {
  90. GE_IF_BOOL_EXEC(segmentIndex[i] >= segmentIndex[i + 1], GELOGE(FAILED,
  91. "FusionAllReducePass FAILED: illegal "
  92. "segmentIndex[%u]=%u, segmentIndex[%u]=%u",
  93. i, segmentIndex[i], i + 1, segmentIndex[i + 1]);
  94. return FAILED;);
  95. }
  96. // check whether fusion is needed or not
  97. GE_IF_BOOL_EXEC(
  98. segmentNum == gradientNum,
  99. GELOGE(NOT_CHANGED, "FusionAllReducePass NOT_CHANGED: segmentNum=%u, gradientNum=%u", segmentNum, gradientNum);
  100. return NOT_CHANGED;)
  101. std::unordered_set<void *> anchorPtrSet;
  102. std::vector<ge::OutDataAnchorPtr> fusionOpPeerOutDataAnchor;
  103. std::vector<ge::OutDataAnchorPtr> fusionOpPeerOutDataToInControl;
  104. std::vector<ge::OutControlAnchorPtr> fusionOpPeerOutControlAnchor;
  105. std::vector<std::pair<int, ge::InDataAnchorPtr>> fusionOpPeerInDataAnchor;
  106. std::vector<std::pair<int, ge::InControlAnchorPtr>> fusionOpPeerInControlFromOutData;
  107. std::vector<ge::InControlAnchorPtr> fusionOpPeerInControlAnchor;
  108. ge::OutControlAnchorPtr previousNewAllreduceOutControlAnchor = nullptr;
  109. // Traversing the segmentNum
  110. uint32_t start = 0;
  111. uint32_t end = 0;
  112. for (uint32_t segmentIdx = 0; segmentIdx < segmentNum; segmentIdx++) {
  113. end = segmentIndex[segmentIdx];
  114. GE_IF_BOOL_EXEC(end - start < 1,
  115. GELOGI("FusionAllReducePass: segmentIndex[%u]=%u", segmentIdx, segmentIndex[segmentIdx]);
  116. start = end + 1; continue;);
  117. ge::OpDescPtr originDescPtr = fusionOps[start]->GetOpDesc();
  118. GE_CHECK_NOTNULL(originDescPtr);
  119. ge::OpDescPtr newAllreduceDesc = AttrUtils::CloneOpDesc(originDescPtr);
  120. GE_CHECK_NOTNULL(newAllreduceDesc);
  121. // Cleat buffer
  122. anchorPtrSet.clear();
  123. fusionOpPeerOutDataAnchor.clear();
  124. fusionOpPeerOutDataToInControl.clear();
  125. fusionOpPeerOutControlAnchor.clear();
  126. fusionOpPeerInDataAnchor.clear();
  127. fusionOpPeerInControlFromOutData.clear();
  128. fusionOpPeerInControlAnchor.clear();
  129. // Traversing the Allreduce operators of each group
  130. int outDataAnchorIndex = 0;
  131. GE_CHK_STATUS_RET(GetPeerOutDataToInData(anchorPtrSet, fusionOpPeerOutDataAnchor, fusionOps[start]),
  132. "Get peer outDataAnchor to inDataAnchor failed");
  133. GE_CHK_STATUS_RET(GetPeerInAnchorToOutData(anchorPtrSet, fusionOpPeerInDataAnchor, fusionOpPeerInControlFromOutData,
  134. fusionOps[start]),
  135. "Get peer inDataAnchor and inControlAnchor to outDataAnchor failed");
  136. GE_CHK_STATUS_RET(GetPeerOutDataToInControl(anchorPtrSet, fusionOpPeerOutDataToInControl, fusionOps[start]),
  137. "Get peer outDataAnchor to inControlAnchor failed");
  138. GE_CHK_STATUS_RET(GetPeerOutControlToInControl(anchorPtrSet, fusionOpPeerOutControlAnchor, fusionOps[start]),
  139. "Get peer outControlAnchor to inControlAnchor failed");
  140. GE_CHK_STATUS_RET(GetPeerInControlFromOutControl(anchorPtrSet, fusionOpPeerInControlAnchor, fusionOps[start]),
  141. "Get peer outControlAnchor from inControlAnchor failed");
  142. GE_CHK_STATUS_RET(graph->RemoveNode(fusionOps[start]), "FusionAllReducePass FAILED: remove node %s\n.",
  143. fusionOps[start]->GetName().c_str());
  144. for (uint32_t idx = start + 1; idx <= end; idx++) {
  145. GE_CHK_STATUS_RET(
  146. GetPeerOutDataToInData(anchorPtrSet, fusionOpPeerOutDataAnchor, fusionOps[idx], newAllreduceDesc),
  147. "Get peer outDataAnchor to inDataAnchor failed");
  148. GE_CHK_STATUS_RET(GetPeerOutDataToInControl(anchorPtrSet, fusionOpPeerOutDataToInControl, fusionOps[idx]),
  149. "Get peer outDataAnchor to inControlAnchor failed");
  150. GE_CHK_STATUS_RET(GetPeerOutControlToInControl(anchorPtrSet, fusionOpPeerOutControlAnchor, fusionOps[idx]),
  151. "Get peer outControlAnchor to inControlAnchor failed");
  152. GE_CHK_STATUS_RET(
  153. GetPeerAnchorFromOutData(anchorPtrSet, fusionOpPeerInDataAnchor, fusionOpPeerInControlFromOutData,
  154. fusionOps[idx], newAllreduceDesc, outDataAnchorIndex),
  155. "Get peerAnchor from outDataAnchor failed");
  156. GE_CHK_STATUS_RET(GetPeerInControlFromOutControl(anchorPtrSet, fusionOpPeerInControlAnchor, fusionOps[idx]),
  157. "Get peer outControlAnchor from inControlAnchor failed");
  158. // Delete the node
  159. GE_CHK_STATUS_RET(graph->RemoveNode(fusionOps[idx]), "FusionAllReducePass FAILED: remove node %s\n.",
  160. fusionOps[idx]->GetName().c_str());
  161. }
  162. NodePtr newAllReducePtr = graph->AddNode(newAllreduceDesc);
  163. GE_CHECK_NOTNULL(newAllReducePtr);
  164. // Link the inputDataAnchor
  165. for (uint32_t i = 0; i < fusionOpPeerOutDataAnchor.size(); i++) {
  166. GE_CHK_STATUS_RET(
  167. GraphUtils::AddEdge(fusionOpPeerOutDataAnchor[i], newAllReducePtr->GetInDataAnchor(static_cast<int>(i))),
  168. "FusionAllReducePass FAILED: add input data edge failed");
  169. }
  170. // Link the inputControlAnchor
  171. for (uint32_t i = 0; i < fusionOpPeerOutControlAnchor.size(); i++) {
  172. GE_CHK_STATUS_RET(GraphUtils::AddEdge(fusionOpPeerOutControlAnchor[i], newAllReducePtr->GetInControlAnchor()),
  173. "FusionAllReducePass FAILED: add input control edge failed");
  174. }
  175. for (uint32_t i = 0; i < fusionOpPeerOutDataToInControl.size(); i++) {
  176. GE_CHK_STATUS_RET(GraphUtils::AddEdge(fusionOpPeerOutDataToInControl[i], newAllReducePtr->GetInControlAnchor()),
  177. "FusionAllReducePass FAILED: add edge from out data to incontrol "
  178. "failed");
  179. }
  180. // Link the outputDataAnchor
  181. for (uint32_t i = 0; i < fusionOpPeerInDataAnchor.size(); i++) {
  182. auto peerInDataAnchor = fusionOpPeerInDataAnchor[i].second;
  183. GE_CHK_STATUS_RET(
  184. GraphUtils::AddEdge(newAllReducePtr->GetOutDataAnchor(fusionOpPeerInDataAnchor[i].first), peerInDataAnchor),
  185. "FusionAllReducePass FAILED: add output data edge failed");
  186. }
  187. for (uint32_t i = 0; i < fusionOpPeerInControlFromOutData.size(); i++) {
  188. auto peerInControlAnchor = fusionOpPeerInControlFromOutData[i].second;
  189. GE_CHK_STATUS_RET(
  190. GraphUtils::AddEdge(newAllReducePtr->GetOutDataAnchor(fusionOpPeerInControlFromOutData[i].first),
  191. peerInControlAnchor),
  192. "FusionAllReducePass FAILED: add edge from out data to in control "
  193. "failed");
  194. }
  195. // Link the outputControlAnchor
  196. for (uint32_t i = 0; i < fusionOpPeerInControlAnchor.size(); i++) {
  197. GE_CHK_STATUS_RET(GraphUtils::AddEdge(newAllReducePtr->GetOutControlAnchor(), fusionOpPeerInControlAnchor[i]),
  198. "FusionAllReducePass FAILED: add output control edge failed");
  199. }
  200. // Link the newAllreduce
  201. if (segmentIdx > 0 && previousNewAllreduceOutControlAnchor != nullptr) {
  202. GE_CHK_STATUS_RET(
  203. GraphUtils::AddEdge(previousNewAllreduceOutControlAnchor, newAllReducePtr->GetInControlAnchor()),
  204. "FusionAllReducePass FAILED: add input previous control edge failed");
  205. }
  206. previousNewAllreduceOutControlAnchor = newAllReducePtr->GetOutControlAnchor();
  207. start = end + 1;
  208. }
  209. return SUCCESS;
  210. }
  211. Status AllReducePass::GetPeerOutDataToInData(std::unordered_set<void *> &anchorSet,
  212. vector<ge::OutDataAnchorPtr> &peerOutDataAnchorVec,
  213. ge::NodePtr &srcNodePtr) {
  214. for (auto inDataAnchor : srcNodePtr->GetAllInDataAnchors()) {
  215. GE_IF_BOOL_EXEC(inDataAnchor == nullptr, continue;);
  216. OutDataAnchorPtr peerOutDataAnchor = inDataAnchor->GetPeerOutAnchor();
  217. GE_IF_BOOL_EXEC(peerOutDataAnchor == nullptr, continue;);
  218. if (anchorSet.count(peerOutDataAnchor.get()) == 0) {
  219. peerOutDataAnchorVec.push_back(peerOutDataAnchor);
  220. anchorSet.insert(peerOutDataAnchor.get());
  221. GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(peerOutDataAnchor, inDataAnchor));
  222. }
  223. }
  224. return SUCCESS;
  225. }
  226. Status AllReducePass::GetPeerInAnchorToOutData(
  227. std::unordered_set<void *> &anchorSet, std::vector<std::pair<int, ge::InDataAnchorPtr>> &fusionOpPeerInDataAnchor,
  228. std::vector<std::pair<int, ge::InControlAnchorPtr>> &fusionOpPeerInControlFromOutData, ge::NodePtr &srcNodePtr) {
  229. for (auto outDataAnchor : srcNodePtr->GetAllOutDataAnchors()) {
  230. GE_IF_BOOL_EXEC(outDataAnchor == nullptr, continue;);
  231. for (auto peerInDataAnchor : outDataAnchor->GetPeerInDataAnchors()) {
  232. GE_IF_BOOL_EXEC(peerInDataAnchor == nullptr, continue;);
  233. if (anchorSet.count(peerInDataAnchor.get()) == 0) {
  234. std::pair<int, ge::InDataAnchorPtr> pairPeerInDataAnchor;
  235. pairPeerInDataAnchor.first = 0;
  236. pairPeerInDataAnchor.second = peerInDataAnchor;
  237. fusionOpPeerInDataAnchor.push_back(pairPeerInDataAnchor);
  238. anchorSet.insert(peerInDataAnchor.get());
  239. GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(outDataAnchor, peerInDataAnchor));
  240. }
  241. }
  242. for (auto peerInControlAnchorFromData : outDataAnchor->GetPeerInControlAnchors()) {
  243. GE_IF_BOOL_EXEC(peerInControlAnchorFromData == nullptr, continue;);
  244. if (anchorSet.count(peerInControlAnchorFromData.get()) == 0) {
  245. std::pair<uint32_t, ge::InControlAnchorPtr> pairPeerInControlAnchorFromData;
  246. pairPeerInControlAnchorFromData.first = 0;
  247. pairPeerInControlAnchorFromData.second = peerInControlAnchorFromData;
  248. fusionOpPeerInControlFromOutData.push_back(pairPeerInControlAnchorFromData);
  249. anchorSet.insert(peerInControlAnchorFromData.get());
  250. GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(outDataAnchor, peerInControlAnchorFromData));
  251. }
  252. }
  253. }
  254. return SUCCESS;
  255. }
  256. Status AllReducePass::GetPeerOutDataToInData(std::unordered_set<void *> &anchorSet,
  257. vector<ge::OutDataAnchorPtr> &peerOutDataAnchorVec,
  258. ge::NodePtr &srcNodePtr, ge::OpDescPtr &dstOpDescPtr) {
  259. for (auto inDataAnchor : srcNodePtr->GetAllInDataAnchors()) {
  260. GE_IF_BOOL_EXEC(inDataAnchor == nullptr, continue;);
  261. OutDataAnchorPtr peerOutDataAnchor = inDataAnchor->GetPeerOutAnchor();
  262. GE_IF_BOOL_EXEC(peerOutDataAnchor == nullptr, continue;);
  263. if (anchorSet.count(peerOutDataAnchor.get()) == 0) {
  264. peerOutDataAnchorVec.push_back(peerOutDataAnchor);
  265. anchorSet.insert(peerOutDataAnchor.get());
  266. if (dstOpDescPtr->AddInputDesc(inDataAnchor->GetOwnerNode()->GetOpDesc()->GetInputDesc(inDataAnchor->GetIdx())) !=
  267. ge::GRAPH_SUCCESS) {
  268. GELOGW("GetPeerOutDataToInData: AddInputDesc failed");
  269. }
  270. GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(peerOutDataAnchor, inDataAnchor));
  271. }
  272. }
  273. return SUCCESS;
  274. }
  275. Status AllReducePass::GetPeerOutDataToInControl(std::unordered_set<void *> &anchorSet,
  276. vector<ge::OutDataAnchorPtr> &peerOutDataToInControlVec,
  277. ge::NodePtr &srcNodePtr) {
  278. InControlAnchorPtr inControlAnchor = srcNodePtr->GetInControlAnchor();
  279. GE_CHECK_NOTNULL(inControlAnchor);
  280. for (auto peerOutDataToInControl : inControlAnchor->GetPeerOutDataAnchors()) {
  281. GE_IF_BOOL_EXEC(peerOutDataToInControl == nullptr, continue;);
  282. if (anchorSet.count(peerOutDataToInControl.get()) == 0) {
  283. peerOutDataToInControlVec.push_back(peerOutDataToInControl);
  284. anchorSet.insert(peerOutDataToInControl.get());
  285. GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(peerOutDataToInControl, inControlAnchor));
  286. }
  287. }
  288. return SUCCESS;
  289. }
  290. Status AllReducePass::GetPeerOutControlToInControl(std::unordered_set<void *> &anchorSet,
  291. vector<ge::OutControlAnchorPtr> &peerOutControlToInControlVec,
  292. ge::NodePtr &srcNodePtr) {
  293. InControlAnchorPtr inControlAnchor = srcNodePtr->GetInControlAnchor();
  294. GE_CHECK_NOTNULL(inControlAnchor);
  295. for (auto peerOutControlAnchor : inControlAnchor->GetPeerOutControlAnchors()) {
  296. GE_IF_BOOL_EXEC(peerOutControlAnchor == nullptr, continue;);
  297. if (anchorSet.count(peerOutControlAnchor.get()) == 0) {
  298. peerOutControlToInControlVec.push_back(peerOutControlAnchor);
  299. anchorSet.insert(peerOutControlAnchor.get());
  300. GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(peerOutControlAnchor, inControlAnchor));
  301. }
  302. }
  303. return SUCCESS;
  304. }
  305. Status AllReducePass::GetPeerAnchorFromOutData(
  306. std::unordered_set<void *> &anchorSet, vector<std::pair<int, ge::InDataAnchorPtr>> &peerInDataFromOutDataVec,
  307. vector<std::pair<int, ge::InControlAnchorPtr>> &peerInControlFromOutDataVec, ge::NodePtr &srcNodePtr,
  308. ge::OpDescPtr &dstOpDescPtr, int &index) {
  309. for (auto outDataAnchor : srcNodePtr->GetAllOutDataAnchors()) {
  310. GE_IF_BOOL_EXEC(outDataAnchor == nullptr, continue;)
  311. if (outDataAnchor->GetPeerInDataAnchors().size() > 0 || outDataAnchor->GetPeerInControlAnchors().size() > 0) {
  312. if (dstOpDescPtr->AddOutputDesc(
  313. outDataAnchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(outDataAnchor->GetIdx())) != ge::GRAPH_SUCCESS) {
  314. GELOGW("GetPeerAnchorFromOutData: AddOutputDesc failed");
  315. }
  316. index++;
  317. }
  318. for (auto peerInDataAnchor : outDataAnchor->GetPeerInDataAnchors()) {
  319. GE_IF_BOOL_EXEC(peerInDataAnchor == nullptr, continue;)
  320. if (anchorSet.count(peerInDataAnchor.get()) == 0) {
  321. std::pair<int, ge::InDataAnchorPtr> pairPeerInDataAnchor;
  322. pairPeerInDataAnchor.first = index;
  323. pairPeerInDataAnchor.second = peerInDataAnchor;
  324. peerInDataFromOutDataVec.push_back(pairPeerInDataAnchor);
  325. anchorSet.insert(peerInDataAnchor.get());
  326. GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(outDataAnchor, peerInDataAnchor))
  327. }
  328. }
  329. for (auto peerInControlAnchorFromData : outDataAnchor->GetPeerInControlAnchors()) {
  330. GE_IF_BOOL_EXEC(peerInControlAnchorFromData == nullptr, continue;)
  331. if (anchorSet.count(peerInControlAnchorFromData.get()) == 0) {
  332. std::pair<int, ge::InControlAnchorPtr> pairPeerInControlAnchorFromData;
  333. pairPeerInControlAnchorFromData.first = index;
  334. pairPeerInControlAnchorFromData.second = peerInControlAnchorFromData;
  335. peerInControlFromOutDataVec.push_back(pairPeerInControlAnchorFromData);
  336. anchorSet.insert(peerInControlAnchorFromData.get());
  337. GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(outDataAnchor, peerInControlAnchorFromData))
  338. }
  339. }
  340. }
  341. return SUCCESS;
  342. }
  343. Status AllReducePass::GetPeerInControlFromOutControl(std::unordered_set<void *> &anchorSet,
  344. vector<ge::InControlAnchorPtr> &peerInControlFromOutControlVec,
  345. ge::NodePtr &srcNodePtr) {
  346. OutControlAnchorPtr outControlAnchor = srcNodePtr->GetOutControlAnchor();
  347. GE_CHECK_NOTNULL(outControlAnchor);
  348. for (auto peerInControlAnchor : outControlAnchor->GetPeerInControlAnchors()) {
  349. GE_IF_BOOL_EXEC(peerInControlAnchor == nullptr, continue;)
  350. if (anchorSet.count(peerInControlAnchor.get()) == 0) {
  351. peerInControlFromOutControlVec.push_back(peerInControlAnchor);
  352. anchorSet.insert(peerInControlAnchor.get());
  353. GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(outControlAnchor, peerInControlAnchor))
  354. }
  355. }
  356. return SUCCESS;
  357. }
  358. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：