You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

switch_to_stream_switch_pass.cc 41 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/switch_to_stream_switch_pass.h"
  17. #include <stack>
  18. #include "common/ge/ge_util.h"
  19. #include "ge/ge_api_types.h"
  20. #include "graph/common/omg_util.h"
  21. #include "graph/ge_context.h"
  22. #include "graph/utils/type_utils.h"
  23. namespace ge {
  24. Status SwitchToStreamSwitchPass::Run(ComputeGraphPtr graph) {
  25. GELOGD("SwitchToStreamSwitchPass Enter");
  26. GE_CHK_STATUS_RET(CheckCycleDependence(graph), "Check cyclic dependence failed.");
  27. for (const auto &switch_node : switch_nodes_) {
  28. GE_CHK_STATUS_RET(ReplaceSwitchNode(graph, switch_node), "Replace Switch by StreamSwitch failed.");
  29. }
  30. GE_CHK_STATUS_RET(CombineSwitchNode(graph), "Combine StreamSwitch nodes failed.");
  31. for (const auto &node : bypass_nodes_) {
  32. GE_CHK_BOOL_EXEC(graph->IsolateNode(node) == GRAPH_SUCCESS,
  33. REPORT_CALL_ERROR("E19999", "Isolate node:%s(%s) in graph:%s failed",
  34. node->GetName().c_str(), node->GetType().c_str(), graph->GetName().c_str());
  35. return FAILED, "Isolate node failed.");
  36. GE_CHK_BOOL_EXEC(GraphUtils::RemoveNodeWithoutRelink(graph, node) == GRAPH_SUCCESS,
  37. REPORT_CALL_ERROR("E19999", "Remove node:%s(%s) without relink in graph:%s failed",
  38. node->GetName().c_str(), node->GetType().c_str(), graph->GetName().c_str());
  39. return FAILED,
  40. "Remove switch node failed.");
  41. }
  42. GELOGD("SwitchToStreamSwitchPass Leave");
  43. return SUCCESS;
  44. }
  45. ///
  46. /// @brief Clear Status
  47. /// @return
  48. ///
  49. Status SwitchToStreamSwitchPass::ClearStatus() {
  50. switch_nodes_.clear();
  51. switch_cyclic_map_.clear();
  52. bypass_nodes_.clear();
  53. stream_switch_nodes_.clear();
  54. cond_node_map_.clear();
  55. switch_node_map_.clear();
  56. node_num_map_.clear();
  57. return SUCCESS;
  58. }
  59. ///
  60. /// @brief Check cyclic dependence
  61. /// @param [in] graph
  62. /// @return Status
  63. ///
  64. Status SwitchToStreamSwitchPass::CheckCycleDependence(const ComputeGraphPtr &graph) {
  65. std::string type;
  66. std::unordered_map<NodePtr, std::vector<NodePtr>> cond_switch_map;
  67. for (const NodePtr &node : graph->GetDirectNode()) {
  68. GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type failed.");
  69. if ((type != SWITCH) && (type != REFSWITCH)) {
  70. continue;
  71. }
  72. InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
  73. GE_CHECK_NOTNULL(in_cond_anchor);
  74. OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor();
  75. GE_CHECK_NOTNULL(peer_out_anchor);
  76. if (FindSwitchCondInput(peer_out_anchor) != SUCCESS) {
  77. GELOGE(FAILED, "Find pred_input for switch_node %s failed.", node->GetName().c_str());
  78. return FAILED;
  79. }
  80. NodePtr cond_node = peer_out_anchor->GetOwnerNode();
  81. auto iter = cond_switch_map.find(cond_node);
  82. if (iter == cond_switch_map.end()) {
  83. cond_switch_map[cond_node] = { node };
  84. } else {
  85. iter->second.emplace_back(node);
  86. }
  87. switch_nodes_.emplace_back(node);
  88. }
  89. MarkCycleDependence(cond_switch_map);
  90. return SUCCESS;
  91. }
  92. ///
  93. /// @brief Mark cyclic dependence
  94. /// @param [in] graph
  95. /// @param [in] cond_switch_map
  96. /// @return void
  97. ///
  98. void SwitchToStreamSwitchPass::MarkCycleDependence(
  99. const std::unordered_map<NodePtr, std::vector<NodePtr>> &cond_switch_map) {
  100. std::stack<NodePtr> out_nodes;
  101. NodePtr tmp_node = nullptr;
  102. std::unordered_set<NodePtr> visited;
  103. for (const auto &iter : cond_switch_map) {
  104. std::set<NodePtr> switch_nodes(iter.second.begin(), iter.second.end());
  105. for (const auto &switch_node : switch_nodes) {
  106. GELOGD("MarkCycleDependence: cond_node=%s, switch=%s.", iter.first->GetName().c_str(),
  107. switch_node->GetName().c_str());
  108. for (const auto &node : switch_node->GetOutAllNodes()) {
  109. out_nodes.push(node);
  110. }
  111. }
  112. visited.clear();
  113. while (!out_nodes.empty()) {
  114. tmp_node = out_nodes.top();
  115. out_nodes.pop();
  116. if (visited.count(tmp_node) > 0) {
  117. continue;
  118. }
  119. for (const NodePtr &out_node : tmp_node->GetOutAllNodes()) {
  120. if (switch_nodes.find(out_node) == switch_nodes.end()) {
  121. out_nodes.push(out_node);
  122. continue;
  123. }
  124. GELOGD("MarkCycleDependence: tmp_node=%s, switch_node=%s.",
  125. tmp_node->GetName().c_str(), out_node->GetName().c_str());
  126. GE_IF_BOOL_EXEC(SetCyclicDependenceFlag(out_node) != SUCCESS,
  127. GELOGW("set cyclic dependence attr failed."); return );
  128. auto map_iter = switch_cyclic_map_.find(out_node);
  129. if (map_iter == switch_cyclic_map_.end()) {
  130. switch_cyclic_map_[out_node] = {tmp_node->GetName()};
  131. } else {
  132. map_iter->second.insert(tmp_node->GetName());
  133. }
  134. }
  135. visited.insert(tmp_node);
  136. }
  137. }
  138. return;
  139. }
  140. ///
  141. /// @brief Replace Switch Op
  142. /// @param [in] graph
  143. /// @param [in] switch_node
  144. /// @return Status
  145. ///
  146. Status SwitchToStreamSwitchPass::ReplaceSwitchNode(const ComputeGraphPtr &graph, const NodePtr &switch_node) {
  147. OutDataAnchorPtr peer_data_anchor = nullptr;
  148. OutDataAnchorPtr peer_cond_anchor = nullptr;
  149. GE_CHK_BOOL_EXEC(BypassSwitchNode(switch_node, peer_data_anchor, peer_cond_anchor) == SUCCESS, return FAILED,
  150. "Bypass switch node %s failed.", switch_node->GetName().c_str());
  151. GE_CHECK_NOTNULL(peer_data_anchor);
  152. GE_CHECK_NOTNULL(peer_cond_anchor);
  153. OpDescPtr cond_desc = peer_cond_anchor->GetOwnerNode()->GetOpDesc();
  154. GE_CHECK_NOTNULL(cond_desc);
  155. DataType cond_data_type = cond_desc->GetOutputDesc(peer_cond_anchor->GetIdx()).GetDataType();
  156. GE_CHK_BOOL_EXEC(cond_data_type == DT_BOOL,
  157. REPORT_INNER_ERROR("E19999", "Pred_input of Switch node:%s(%s) only support DT_BOOL data_type, "
  158. "but %s exactly", switch_node->GetName().c_str(), switch_node->GetType().c_str(),
  159. TypeUtils::DataTypeToSerialString(cond_data_type).c_str());
  160. return FAILED,
  161. "pred_input of Switch only support DT_BOOL data_type, but %s exactly.",
  162. TypeUtils::DataTypeToSerialString(cond_data_type).c_str());
  163. OpDescPtr switch_desc = switch_node->GetOpDesc();
  164. GE_CHECK_NOTNULL(switch_desc);
  165. bool cyclic_flag = switch_desc->HasAttr(ATTR_NAME_CYCLIC_DEPENDENCE_FLAG);
  166. std::set<std::string> out_node_list;
  167. for (const auto &out_data_anchor : switch_node->GetAllOutDataAnchors()) {
  168. bool true_branch_flag = (static_cast<uint32_t>(out_data_anchor->GetIdx()) == SWITCH_TRUE_OUTPUT);
  169. NodePtr stream_switch = nullptr;
  170. out_node_list.clear();
  171. for (const auto &peer_in_anchor : out_data_anchor->GetPeerAnchors()) {
  172. GE_IF_BOOL_EXEC(stream_switch == nullptr, {
  173. stream_switch = CreateStreamSwitchNode(graph, switch_node, true_branch_flag ? "_t" : "_f", peer_cond_anchor);
  174. GE_CHK_BOOL_EXEC(stream_switch != nullptr, return FAILED, "Create stream_switch node failed.");
  175. if (SetSwitchTrueBranchFlag(stream_switch, true_branch_flag) != SUCCESS) {
  176. REPORT_CALL_ERROR("E19999", "Set switch true branch flag from node:%s(%s) failed",
  177. stream_switch->GetName().c_str(), stream_switch->GetType().c_str());
  178. GELOGE(FAILED, "SetSwitchTrueBranchFlag for node %s failed.", stream_switch->GetName().c_str());
  179. return FAILED;
  180. }
  181. if (MarkBranches(peer_cond_anchor, stream_switch, true_branch_flag) != SUCCESS) {
  182. GELOGE(FAILED, "Mark branches for stream_switch %s failed.", stream_switch->GetName().c_str());
  183. return FAILED;
  184. }
  185. if (!cyclic_flag) {
  186. GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor->GetOwnerNode()->GetOutControlAnchor(),
  187. stream_switch->GetInControlAnchor()),
  188. "StreamSwitch node add ctl edge failed.");
  189. }
  190. });
  191. GE_CHK_STATUS(GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor), "Remove Switch data output failed.");
  192. NodePtr out_node = peer_in_anchor->GetOwnerNode();
  193. GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor, peer_in_anchor), "StreamSwitch node add edge failed.");
  194. GE_CHK_STATUS(GraphUtils::AddEdge(stream_switch->GetOutControlAnchor(), out_node->GetInControlAnchor()),
  195. "StreamSwitch node add ctl edge failed.");
  196. out_node_list.insert(out_node->GetName());
  197. }
  198. GE_IF_BOOL_EXEC(stream_switch != nullptr, {
  199. MoveCtrlEdges(switch_node, stream_switch);
  200. switch_node_map_[stream_switch] = out_node_list;
  201. if (SetOriginalNodeName(stream_switch, switch_node->GetName()) != SUCCESS) {
  202. REPORT_CALL_ERROR("E19999", "Set original node name:%s to node:%s(%s) failed", switch_node->GetName().c_str(),
  203. stream_switch->GetName().c_str(), stream_switch->GetType().c_str());
  204. GELOGE(FAILED, "SetOriginalNodeName for node %s failed.", stream_switch->GetName().c_str());
  205. return FAILED;
  206. }
  207. });
  208. }
  209. (void)bypass_nodes_.insert(switch_node);
  210. return SUCCESS;
  211. }
  212. ///
  213. /// @brief Bypass Switch Node
  214. /// @param [in] switch_node
  215. /// @param [out] peer_data_anchor
  216. /// @param [out] peer_cond_anchor
  217. /// @return Status
  218. ///
  219. Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor,
  220. OutDataAnchorPtr &peer_cond_anchor) {
  221. for (uint32_t idx = 0; idx < SWITCH_INPUT_NUM; ++idx) {
  222. InDataAnchorPtr in_data_anchor = switch_node->GetInDataAnchor(idx);
  223. GE_CHECK_NOTNULL(in_data_anchor);
  224. OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
  225. GE_CHECK_NOTNULL(peer_out_anchor);
  226. // Remove Switch data input.
  227. if (GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor) != GRAPH_SUCCESS) {
  228. REPORT_CALL_ERROR("E19999", "Remove edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%u) failed",
  229. peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  230. peer_out_anchor->GetOwnerNode()->GetType().c_str(), peer_out_anchor->GetIdx(),
  231. switch_node->GetName().c_str(), switch_node->GetType().c_str(), idx);
  232. GELOGE(FAILED, "Remove data edge %s->%s failed.", peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  233. switch_node->GetName().c_str());
  234. return FAILED;
  235. }
  236. if (idx == SWITCH_DATA_INPUT) {
  237. peer_data_anchor = peer_out_anchor;
  238. } else {
  239. peer_cond_anchor = peer_out_anchor;
  240. }
  241. }
  242. return SUCCESS;
  243. }
  244. ///
  245. /// @brief Find Switch cond input
  246. /// @param [out] peer_cond_anchor
  247. /// @return Status
  248. ///
  249. Status SwitchToStreamSwitchPass::FindSwitchCondInput(OutDataAnchorPtr &peer_cond_anchor) {
  250. NodePtr tmp_node = nullptr;
  251. std::string type;
  252. bool pass_flag = true;
  253. while (pass_flag) {
  254. if (tmp_node == nullptr) {
  255. tmp_node = peer_cond_anchor->GetOwnerNode();
  256. } else {
  257. InDataAnchorPtr in_data_anchor = tmp_node->GetInDataAnchor(SWITCH_DATA_INPUT);
  258. GE_CHECK_NOTNULL(in_data_anchor);
  259. peer_cond_anchor = in_data_anchor->GetPeerOutAnchor();
  260. GE_CHECK_NOTNULL(peer_cond_anchor);
  261. tmp_node = peer_cond_anchor->GetOwnerNode();
  262. }
  263. GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "Get node type failed.");
  264. pass_flag = ((type == SWITCH) || (type == REFSWITCH));
  265. }
  266. return SUCCESS;
  267. }
  268. ///
  269. /// @brief Create StreamSwitch Node
  270. /// @param [in] graph
  271. /// @param [in] switch_node
  272. /// @param [in] suffix
  273. /// @param [in] peer_cond_anchor
  274. /// @return ge::NodePtr
  275. ///
  276. NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr &graph, const NodePtr &switch_node,
  277. const std::string &suffix,
  278. const OutDataAnchorPtr &peer_cond_anchor) {
  279. OpDescPtr switch_op_desc = switch_node->GetOpDesc();
  280. GE_CHK_BOOL_EXEC(switch_op_desc != nullptr,
  281. REPORT_INNER_ERROR("E19999", "OpDesc in node is nullptr, check invalid");
  282. return nullptr, "OpDesc of Switch node is invalid.");
  283. GE_IF_BOOL_EXEC(switch_op_desc->GetInputsSize() != SWITCH_INPUT_NUM, {
  284. REPORT_INNER_ERROR("E19999", "Input desc size:%zu of node:%s(%s) not equal to %u, check invalid",
  285. switch_op_desc->GetInputsSize(),
  286. switch_op_desc->GetName().c_str(), switch_op_desc->GetType().c_str(), SWITCH_INPUT_NUM);
  287. GELOGE(FAILED, "Switch input param invalid, input_size=%lu, should be %u.", switch_op_desc->GetInputsSize(),
  288. SWITCH_INPUT_NUM);
  289. return nullptr;
  290. });
  291. const std::string &node_name = switch_node->GetName() + "_" + STREAMSWITCH + suffix;
  292. GELOGI("Create StreamSwitch, name=%s.", node_name.c_str());
  293. OpDescPtr op_desc = MakeShared<OpDesc>(node_name, STREAMSWITCH);
  294. if (op_desc == nullptr) {
  295. REPORT_CALL_ERROR("E19999", "New OpDesc failed");
  296. GELOGE(FAILED, "Create op_desc failed, StreamSwitch:%s.", node_name.c_str());
  297. return nullptr;
  298. }
  299. // mark hccl group id
  300. std::string hccl_group_id;
  301. if (AttrUtils::GetStr(switch_node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) {
  302. (void)AttrUtils::SetStr(op_desc, ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id);
  303. GELOGD("Set attr ATTR_NAME_HCCL_FUSED_GROUP for Stream_Switch %s, value is %s.", node_name.c_str(),
  304. hccl_group_id.c_str());
  305. }
  306. int64_t switch_type;
  307. if (AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_TYPE, switch_type)) {
  308. (void)AttrUtils::SetInt(op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, switch_type);
  309. GELOGD("Set attr ATTR_NAME_STREAM_SWITCH_TYPE for Stream_Switch %s, value is %ld.", node_name.c_str(),
  310. switch_type);
  311. }
  312. if (!AttrUtils::SetInt(op_desc, ATTR_NAME_SWITCH_DATA_TYPE, RT_SWITCH_INT32) ||
  313. !AttrUtils::SetInt(op_desc, ATTR_NAME_STREAM_SWITCH_COND, (int64_t)RT_EQUAL)) {
  314. REPORT_CALL_ERROR("E19999", "Set Attr:%s or Attr:%s to op:%s(%s) failed",
  315. ATTR_NAME_SWITCH_DATA_TYPE.c_str(), ATTR_NAME_STREAM_SWITCH_COND.c_str(),
  316. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  317. GELOGE(INTERNAL_ERROR, "set int failed");
  318. return nullptr;
  319. }
  320. // Already checked, first input is Variable will passed, second is condition will checked.
  321. GeTensorDesc cond_input_desc = switch_op_desc->GetInputDesc(SWITCH_PRED_INPUT);
  322. GeTensorDesc input_desc(GeShape(cond_input_desc.GetShape().GetDims()), cond_input_desc.GetFormat(), DT_INT32);
  323. GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(input_desc) == GRAPH_SUCCESS,
  324. REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
  325. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  326. return nullptr,
  327. "Create StreamSwitch node: add input desc failed.");
  328. GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(input_desc) == GRAPH_SUCCESS,
  329. REPORT_CALL_ERROR("E19999", "Add ouput desc to op:%s(%s) failed",
  330. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  331. return nullptr,
  332. "Create StreamSwitch node: add input desc failed.");
  333. NodePtr stream_switch = graph->AddNode(op_desc);
  334. GE_CHK_BOOL_EXEC(stream_switch != nullptr,
  335. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
  336. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  337. return nullptr, "Insert StreamSwitch node failed.");
  338. GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)),
  339. "StreamSwitch node add cond edge failed.");
  340. MarkForceUnknownShape(stream_switch, switch_node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE));
  341. return stream_switch;
  342. }
  343. ///
  344. /// @brief Mark Switch Branch
  345. /// @param [in] peer_cond_anchor
  346. /// @param [in] stream_switch
  347. /// @param [in] true_branch_flag
  348. /// @return Status
  349. ///
  350. Status SwitchToStreamSwitchPass::MarkBranches(const OutDataAnchorPtr &peer_cond_anchor, const NodePtr &stream_switch,
  351. bool true_branch_flag) {
  352. uint32_t index = true_branch_flag ? SWITCH_TRUE_OUTPUT : SWITCH_FALSE_OUTPUT;
  353. auto it = cond_node_map_.find(peer_cond_anchor);
  354. if (it != cond_node_map_.end()) {
  355. int64_t switch_group_id = GetGroupId(stream_switch);
  356. auto switch_group_it = it->second.find(switch_group_id);
  357. if (switch_group_it == it->second.end()) {
  358. std::list<NodePtr> false_node_list;
  359. std::list<NodePtr> true_node_list;
  360. std::list<NodePtr> &node_list = true_branch_flag ? true_node_list : false_node_list;
  361. node_list.emplace_back(stream_switch);
  362. std::vector<std::list<NodePtr>> switch_list;
  363. switch_list.emplace_back(false_node_list);
  364. switch_list.emplace_back(true_node_list);
  365. it->second[switch_group_id] = switch_list;
  366. } else {
  367. GE_IF_BOOL_EXEC(switch_group_it->second.size() != SWITCH_OUTPUT_NUM, {
  368. REPORT_INNER_ERROR("E19999", "switch group size:%zu not equal to %u, group_id:%ld, check invalid",
  369. switch_group_it->second.size(), SWITCH_OUTPUT_NUM, switch_group_id);
  370. GELOGE(INTERNAL_ERROR, "Check size failed, node: %s", stream_switch->GetName().c_str());
  371. return FAILED;
  372. });
  373. switch_group_it->second[index].emplace_back(stream_switch);
  374. }
  375. } else {
  376. int64_t switch_group_id = GetGroupId(stream_switch);
  377. std::map<int64_t, std::vector<std::list<NodePtr>>> switch_group_map;
  378. std::list<NodePtr> false_node_list;
  379. std::list<NodePtr> true_node_list;
  380. std::list<NodePtr> &node_list = true_branch_flag ? true_node_list : false_node_list;
  381. node_list.emplace_back(stream_switch);
  382. std::vector<std::list<NodePtr>> switch_list;
  383. switch_list.emplace_back(false_node_list);
  384. switch_list.emplace_back(true_node_list);
  385. switch_group_map[switch_group_id] = switch_list;
  386. cond_node_map_[peer_cond_anchor] = switch_group_map;
  387. }
  388. return SUCCESS;
  389. }
  390. ///
  391. /// @brief Get group_id for switch_node
  392. /// @param [in] node
  393. /// @return group_id
  394. ///
  395. int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) {
  396. std::string tailing_optimization_option;
  397. bool is_tailing_optimization = false;
  398. if (GetContext().GetOption(OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, tailing_optimization_option) == GRAPH_SUCCESS) {
  399. // "1" means it's True from frontend option
  400. is_tailing_optimization = (tailing_optimization_option == "1");
  401. GELOGI("Option ge.exec.isTailingOptimization is %s", tailing_optimization_option.c_str());
  402. }
  403. if (!is_tailing_optimization) {
  404. return 0;
  405. }
  406. std::string hccl_group_id;
  407. if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) {
  408. GELOGI("Node %s can not find hccl group id.", node->GetName().c_str());
  409. return 0;
  410. }
  411. auto key_index = hccl_group_id.find_last_of('_');
  412. auto key_num = hccl_group_id.substr(key_index + 1, hccl_group_id.length() - key_index);
  413. GELOGI("Node:%s, hccl_group_id=%s, key_num=%s", node->GetName().c_str(), hccl_group_id.c_str(), key_num.c_str());
  414. int64_t num = atoi(key_num.c_str());
  415. if (num == 0) {
  416. return 0;
  417. }
  418. GELOGI("Hccl_group_id is %s, group_id is %ld", hccl_group_id.c_str(), num);
  419. return num;
  420. }
  421. ///
  422. /// @brief Combine switch nodes link to same cond
  423. /// @param [in] graph
  424. /// @return Status
  425. ///
  426. Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) {
  427. for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) {
  428. for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) {
  429. std::list<NodePtr> false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT];
  430. std::list<NodePtr> true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT];
  431. std::set<NodePtr> same_cond_switch;
  432. same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end());
  433. same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end());
  434. OutDataAnchorPtr peer_cond_anchor = iter->first;
  435. GE_CHECK_NOTNULL(peer_cond_anchor);
  436. NodePtr cond_node = peer_cond_anchor->GetOwnerNode();
  437. GELOGI("CombineSwitchNode: cond_node=%s.", cond_node->GetName().c_str());
  438. NodePtr cast_node = CreateCastOp(graph, peer_cond_anchor);
  439. GE_CHK_BOOL_EXEC(cast_node != nullptr, return FAILED, "Create cast_node failed.");
  440. NodePtr active_node = CreateActiveNode(graph, cond_node);
  441. GE_CHK_BOOL_EXEC(active_node != nullptr, return FAILED, "Create StreamActive node failed.");
  442. GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutControlAnchor(), active_node->GetInControlAnchor()),
  443. "StreamActive add ctl edge failed.");
  444. if (SetActiveLabelList(active_node, { cast_node->GetName() }) != SUCCESS) {
  445. REPORT_CALL_ERROR("E19999", "Set active label list:%s to op:%s(%s) failed",
  446. cast_node->GetName().c_str(), active_node->GetName().c_str(), active_node->GetType().c_str());
  447. GELOGE(FAILED, "Set active_label_list attr for node %s failed.", active_node->GetName().c_str());
  448. return FAILED;
  449. }
  450. std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) {
  451. return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE);
  452. };
  453. bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback);
  454. MarkForceUnknownShape(active_node, is_unknown_shape);
  455. const std::string &cond_group = cond_node->GetName();
  456. for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) {
  457. bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT);
  458. std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list);
  459. GE_IF_BOOL_EXEC(switch_list.empty(), continue);
  460. // select first stream_switch
  461. NodePtr stream_switch = switch_list.front();
  462. // set stream_label
  463. if (SetStreamLabel(stream_switch, cast_node->GetName()) != SUCCESS) {
  464. REPORT_CALL_ERROR("E19999", "Set stream_label:%s to op:%s(%s) failed",
  465. cast_node->GetName().c_str(), stream_switch->GetName().c_str(),
  466. stream_switch->GetType().c_str());
  467. GELOGE(FAILED, "Set stream label failed.");
  468. return FAILED;
  469. }
  470. OpDescPtr switch_desc = stream_switch->GetOpDesc();
  471. GE_CHECK_NOTNULL(switch_desc);
  472. switch_desc->SetName(CheckDuplicateName(cond_group + "/" + STREAMSWITCH + (true_branch_flag ? "_t" : "_f")));
  473. stream_switch_nodes_.emplace_back(stream_switch);
  474. // 0_input: original pred input, 1_input: constant node
  475. GE_CHK_STATUS_RET(AddConstNode(graph, stream_switch), "Add const node failed.");
  476. GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)),
  477. "StreamSwitch remove data edge failed.");
  478. GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0)),
  479. "Cast add data edge failed.");
  480. MarkForceUnknownShape(stream_switch, is_unknown_shape);
  481. for (const NodePtr &node : switch_list) {
  482. GE_IF_BOOL_EXEC(node != stream_switch, {
  483. GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)),
  484. "StreamSwitch remove data edge failed.");
  485. });
  486. GE_CHK_STATUS(ModifySwitchInCtlEdges(node, cast_node, same_cond_switch), "ModifySwitchInCtlEdges failed.");
  487. GE_CHK_STATUS(ModifySwitchOutCtlEdges(node, stream_switch, active_node), "ModifySwitchOutCtlEdges failed.");
  488. }
  489. GE_CHK_STATUS(GraphUtils::AddEdge(active_node->GetOutControlAnchor(), stream_switch->GetInControlAnchor()),
  490. "StreamActive add ctl edge failed.");
  491. }
  492. }
  493. }
  494. return SUCCESS;
  495. }
  496. ///
  497. /// @brief Create Active Op
  498. /// @param [in] graph
  499. /// @param [in] cond_node
  500. /// @return ge::NodePtr
  501. ///
  502. NodePtr SwitchToStreamSwitchPass::CreateActiveNode(const ComputeGraphPtr &graph, const NodePtr &node) {
  503. const std::string &node_name = CheckDuplicateName(node->GetName() + "_" + STREAMACTIVE);
  504. GELOGI("Create StreamActive op:%s.", node_name.c_str());
  505. OpDescPtr op_desc = MakeShared<OpDesc>(node_name, STREAMACTIVE);
  506. if (op_desc == nullptr) {
  507. REPORT_CALL_ERROR("E19999", "New OpDesc failed");
  508. GELOGE(FAILED, "Create op_desc failed, StreamActive:%s.", node_name.c_str());
  509. return nullptr;
  510. }
  511. NodePtr active_node = graph->AddNode(op_desc);
  512. GE_CHK_BOOL_EXEC(active_node != nullptr,
  513. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
  514. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  515. return nullptr, "Create StreamActive node failed.");
  516. GE_IF_BOOL_EXEC(GraphUtils::AddEdge(node->GetOutControlAnchor(), active_node->GetInControlAnchor()) != SUCCESS,
  517. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  518. node->GetName().c_str(), node->GetType().c_str(),
  519. active_node->GetName().c_str(), active_node->GetType().c_str());
  520. GELOGE(INTERNAL_ERROR, "add edge failed");
  521. return nullptr);
  522. GE_IF_BOOL_EXEC(SetSwitchBranchNodeLabel(active_node, node_name) != SUCCESS,
  523. REPORT_CALL_ERROR("E19999", "Set switch branch node label:%s to node:%s(%s) failed",
  524. node_name.c_str(), active_node->GetName().c_str(), active_node->GetType().c_str());
  525. GELOGE(INTERNAL_ERROR, "set switch branch node label failed");
  526. return nullptr);
  527. return active_node;
  528. }
  529. ///
  530. /// @brief Create cast node
  531. /// @param [in] graph
  532. /// @param [in] peer_cond_anchor
  533. /// @return NodePtr
  534. ///
  535. NodePtr SwitchToStreamSwitchPass::CreateCastOp(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_cond_anchor) {
  536. OpDescPtr cond_desc = peer_cond_anchor->GetOwnerNode()->GetOpDesc();
  537. GE_CHK_BOOL_EXEC(cond_desc != nullptr, return nullptr, "Get cond_desc failed.");
  538. const std::string &cast_name = CheckDuplicateName(cond_desc->GetName() + "_" + CAST);
  539. GELOGI("Create cast_node: %s, input datatype:DT_BOOL, out datatype:DT_INT32", cast_name.c_str());
  540. OpDescPtr cast_desc = MakeShared<OpDesc>(cast_name, CAST);
  541. if (cast_desc == nullptr) {
  542. REPORT_CALL_ERROR("E19999", "New OpDesc failed");
  543. GELOGE(FAILED, "Create op_desc failed, Cast:%s.", cast_name.c_str());
  544. return nullptr;
  545. }
  546. if (!(AttrUtils::SetInt(cast_desc, CAST_ATTR_SRCT, (int64_t)DT_BOOL) &&
  547. AttrUtils::SetInt(cast_desc, CAST_ATTR_DSTT, (int64_t)DT_INT32) &&
  548. AttrUtils::SetInt(cast_desc, CAST_ATTR_DST_TYPE, (int64_t)DT_INT32) &&
  549. AttrUtils::SetBool(cast_desc, CAST_ATTR_TRUNCATE, false))) {
  550. REPORT_CALL_ERROR("E19999", "Set Attr:%s or %s or %s or %s to op:%s(%s) failed",
  551. CAST_ATTR_SRCT.c_str(), CAST_ATTR_DSTT.c_str(),
  552. CAST_ATTR_DST_TYPE.c_str(), CAST_ATTR_TRUNCATE.c_str(),
  553. cast_desc->GetName().c_str(), cast_desc->GetType().c_str());
  554. GELOGE(FAILED, "Set CAST_ATTR_SRCT or CAST_ATTR_DSTT or CAST_ATTR_DST_TYPE or CAST_ATTR_TRUNCATE failed, node: %s.",
  555. cast_name.c_str());
  556. return nullptr;
  557. }
  558. GeTensorDesc tensor_desc = cond_desc->GetOutputDesc(peer_cond_anchor->GetIdx());
  559. tensor_desc.SetDataType(DT_BOOL);
  560. GE_CHK_BOOL_EXEC(cast_desc->AddInputDesc(tensor_desc) == SUCCESS,
  561. REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
  562. cast_desc->GetName().c_str(), cast_desc->GetType().c_str());
  563. return nullptr,
  564. "Cast_node add input desc failed.");
  565. tensor_desc.SetDataType(DT_INT32);
  566. GE_CHK_BOOL_EXEC(cast_desc->AddOutputDesc(tensor_desc) == SUCCESS,
  567. REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed",
  568. cast_desc->GetName().c_str(), cast_desc->GetType().c_str());
  569. return nullptr,
  570. "Cast_node add output desc failed.");
  571. NodePtr cast_node = graph->AddNode(cast_desc);
  572. GE_CHK_BOOL_EXEC(cast_node != nullptr,
  573. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
  574. cast_desc->GetName().c_str(), cast_desc->GetType().c_str(),
  575. graph->GetName().c_str());
  576. return nullptr, "Create cast_node failed.");
  577. // Cast node has and only has one input
  578. GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, cast_node->GetInDataAnchor(0)), "Cast add data edge failed.");
  579. return cast_node;
  580. }
  581. ///
  582. /// @brief Add const node as switch input1
  583. /// @param [in] graph
  584. /// @param [in] stream_switch
  585. /// @return Status
  586. ///
  587. Status SwitchToStreamSwitchPass::AddConstNode(const ComputeGraphPtr &graph, const NodePtr &stream_switch) {
  588. OpDescPtr op_desc = stream_switch->GetOpDesc();
  589. GE_CHECK_NOTNULL(op_desc);
  590. bool value = false;
  591. GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value),
  592. REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed",
  593. ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG.c_str(),
  594. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  595. return FAILED,
  596. "StreamSwitch get attr TRUE_BRANCH_STREAM failed.");
  597. const std::string &const_node_name = op_desc->GetName() + "_Constant_" + (value ? "t" : "f");
  598. GELOGI("Create const op: %s", const_node_name.c_str());
  599. OpDescPtr const_op_desc = MakeShared<OpDesc>(const_node_name, CONSTANT);
  600. if (const_op_desc == nullptr) {
  601. REPORT_CALL_ERROR("E19999", "New OpDesc failed");
  602. GELOGE(FAILED, "Create op_desc failed, Constant:%s.", const_node_name.c_str());
  603. return FAILED;
  604. }
  605. auto resize_value = (int32_t)value;
  606. GeTensorDesc data_desc = op_desc->GetInputDesc(1);
  607. GeTensorPtr const_value =
  608. MakeShared<GeTensor>(data_desc, reinterpret_cast<uint8_t *>(&resize_value), sizeof(int32_t));
  609. if (const_value == nullptr) {
  610. REPORT_CALL_ERROR("E19999", "New GeTensor failed");
  611. GELOGE(FAILED, "Create tensor failed.");
  612. return FAILED;
  613. }
  614. GE_CHK_BOOL_EXEC(AttrUtils::SetTensor(const_op_desc, ATTR_NAME_WEIGHTS, const_value),
  615. REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_NAME_WEIGHTS.c_str(),
  616. const_op_desc->GetName().c_str(), const_op_desc->GetType().c_str());
  617. return FAILED);
  618. GE_CHK_BOOL_EXEC(const_op_desc->AddOutputDesc(data_desc) == GRAPH_SUCCESS,
  619. REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed",
  620. const_op_desc->GetName().c_str(), const_op_desc->GetType().c_str());
  621. return FAILED,
  622. "Create Const op: add output desc failed.");
  623. NodePtr const_node = graph->AddNode(const_op_desc);
  624. GE_CHK_BOOL_EXEC(const_node != nullptr,
  625. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
  626. const_op_desc->GetName().c_str(), const_op_desc->GetType().c_str(),
  627. graph->GetName().c_str());
  628. return FAILED, "Insert Const node failed.");
  629. GE_CHK_STATUS(GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(1)),
  630. "StreamSwitch node add ctl edge failed.");
  631. return SUCCESS;
  632. }
  633. ///
  634. /// @brief Modify in ctl edge for switch_node
  635. /// @param [in] switch_node
  636. /// @param [in] cast_node
  637. /// @param [in] same_cond_switch
  638. /// @return Status
  639. ///
  640. Status SwitchToStreamSwitchPass::ModifySwitchInCtlEdges(const NodePtr &switch_node, const NodePtr &cast_node,
  641. const std::set<NodePtr> &same_cond_switch) {
  642. GELOGD("ModifySwitchInCtlEdges: switch_node=%s, cast_node=%s", switch_node->GetName().c_str(),
  643. cast_node->GetName().c_str());
  644. std::string orig_switch_name = switch_node->GetName();
  645. OpDescPtr switch_desc = switch_node->GetOpDesc();
  646. GE_CHECK_NOTNULL(switch_desc);
  647. if (!AttrUtils::GetStr(switch_desc, ATTR_NAME_ORIG_NODE_NAME, orig_switch_name) || orig_switch_name.empty()) {
  648. REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_NAME_ORIG_NODE_NAME.c_str(),
  649. switch_desc->GetName().c_str(), switch_desc->GetType().c_str());
  650. GELOGE(INTERNAL_ERROR, "Get attr ATTR_NAME_ORIG_NODE_NAME failed, node: %s", switch_desc->GetName().c_str());
  651. return INTERNAL_ERROR;
  652. }
  653. for (const NodePtr &in_ctrl_node : switch_node->GetInControlNodes()) {
  654. GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()),
  655. "Remove ctl edge failed.");
  656. GE_IF_BOOL_EXEC(!in_ctrl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), {
  657. GE_CHK_STATUS(GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()),
  658. "Add ctl edge failed.");
  659. });
  660. GE_IF_BOOL_EXEC(in_ctrl_node->GetType() != STREAMSWITCH, continue);
  661. if (same_cond_switch.count(in_ctrl_node) > 0) {
  662. GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()),
  663. "Remove ctl edge failed.");
  664. continue;
  665. }
  666. auto find_res1 = switch_node_map_.find(in_ctrl_node);
  667. GE_IF_BOOL_EXEC(find_res1 == switch_node_map_.end(), {
  668. REPORT_INNER_ERROR("E19999", "Node:%s(%s) can't find in switch_node_map_, check invalid",
  669. in_ctrl_node->GetName().c_str(), in_ctrl_node->GetType().c_str());
  670. GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctrl_node->GetName().c_str());
  671. return INTERNAL_ERROR;
  672. });
  673. auto find_res2 = find_res1->second.find(orig_switch_name);
  674. auto find_res3 = find_res1->second.find(cast_node->GetName());
  675. GE_IF_BOOL_EXEC((find_res2 != find_res1->second.end()) && (find_res3 == find_res1->second.end()), {
  676. find_res1->second.erase(find_res2);
  677. find_res1->second.insert(cast_node->GetName());
  678. continue;
  679. });
  680. }
  681. return SUCCESS;
  682. }
  683. ///
  684. /// @brief Modify out ctl edge for switch_node
  685. /// @param [in] switch_node
  686. /// @param [in] stream_switch
  687. /// @param [in] active_node
  688. /// @return Status
  689. ///
  690. Status SwitchToStreamSwitchPass::ModifySwitchOutCtlEdges(const NodePtr &switch_node, const NodePtr &stream_switch,
  691. const NodePtr &active_node) {
  692. GELOGD("ModifySwitchOutCtlEdges: switch_node=%s, stream_switch=%s, active_node=%s", switch_node->GetName().c_str(),
  693. stream_switch->GetName().c_str(), active_node->GetName().c_str());
  694. auto find_res = switch_node_map_.find(switch_node);
  695. GE_IF_BOOL_EXEC(find_res == switch_node_map_.end(), {
  696. REPORT_INNER_ERROR("E19999", "Node:%s(%s) can't find in switch_node_map_, check invalid",
  697. switch_node->GetName().c_str(), switch_node->GetType().c_str());
  698. GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", switch_node->GetName().c_str());
  699. return INTERNAL_ERROR;
  700. });
  701. GE_IF_BOOL_EXEC(find_res->second.empty(), {
  702. REPORT_INNER_ERROR("E19999", "True_nodes of StreamSwitch node:%s(%s) is empty, check invalid",
  703. switch_node->GetName().c_str(), switch_node->GetType().c_str());
  704. GELOGE(INTERNAL_ERROR, "true_nodes of StreamSwitch node %s is empty.", switch_node->GetName().c_str());
  705. return INTERNAL_ERROR;
  706. });
  707. for (const NodePtr &node : switch_node->GetOutControlNodes()) {
  708. GE_CHK_STATUS(GraphUtils::RemoveEdge(switch_node->GetOutControlAnchor(), node->GetInControlAnchor()),
  709. "Remove ctl edge failed.");
  710. OpDescPtr op_desc = node->GetOpDesc();
  711. GE_CHECK_NOTNULL(op_desc);
  712. std::string orig_name = op_desc->GetName();
  713. GE_IF_BOOL_EXEC(op_desc->HasAttr(ATTR_NAME_ORIG_NODE_NAME), {
  714. if (!AttrUtils::GetStr(op_desc, ATTR_NAME_ORIG_NODE_NAME, orig_name) || orig_name.empty()) {
  715. REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_NAME_ORIG_NODE_NAME.c_str(),
  716. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  717. GELOGE(INTERNAL_ERROR, "Get attr ATTR_NAME_ORIG_NODE_NAME failed, node: %s.", op_desc->GetName().c_str());
  718. return INTERNAL_ERROR;
  719. }
  720. });
  721. if (find_res->second.find(orig_name) == find_res->second.end()) {
  722. auto active_out_ctrl_anchor = active_node->GetOutControlAnchor();
  723. GE_CHECK_NOTNULL(active_out_ctrl_anchor);
  724. GE_IF_BOOL_EXEC(!active_out_ctrl_anchor->IsLinkedWith(node->GetInControlAnchor()), {
  725. GE_CHK_STATUS(GraphUtils::AddEdge(active_out_ctrl_anchor, node->GetInControlAnchor()), "Add ctl edge failed.");
  726. });
  727. } else {
  728. auto switch_out_ctrl_anchor = stream_switch->GetOutControlAnchor();
  729. GE_CHECK_NOTNULL(switch_out_ctrl_anchor);
  730. GE_IF_BOOL_EXEC(!switch_out_ctrl_anchor->IsLinkedWith(node->GetInControlAnchor()), {
  731. GE_CHK_STATUS(GraphUtils::AddEdge(switch_out_ctrl_anchor, node->GetInControlAnchor()), "Add ctl edge failed.");
  732. });
  733. }
  734. }
  735. GE_IF_BOOL_EXEC(switch_node != stream_switch, (void)bypass_nodes_.insert(switch_node));
  736. return SUCCESS;
  737. }
  738. ///
  739. /// @brief Check duplicate node_name
  740. /// @param [in] node_name
  741. /// @return std::string
  742. ///
  743. std::string SwitchToStreamSwitchPass::CheckDuplicateName(const std::string &node_name) {
  744. std::string tmp_name = node_name;
  745. auto iter = node_num_map_.find(tmp_name);
  746. if (iter != node_num_map_.end()) {
  747. tmp_name = tmp_name + "_" + std::to_string(iter->second);
  748. (iter->second)++;
  749. } else {
  750. node_num_map_[tmp_name] = 1;
  751. }
  752. return tmp_name;
  753. }
  754. ///
  755. /// @brief Move Control Edges
  756. /// @param [in] old_node
  757. /// @param [in] new_node
  758. /// @return void
  759. ///
  760. void SwitchToStreamSwitchPass::MoveCtrlEdges(const NodePtr &old_node, const NodePtr &new_node) {
  761. GE_IF_BOOL_EXEC(old_node == new_node, return );
  762. auto iter = switch_cyclic_map_.find(old_node);
  763. bool check_flag = (iter != switch_cyclic_map_.end());
  764. for (const NodePtr &in_node : old_node->GetInControlNodes()) {
  765. auto out_ctrl_anchor = in_node->GetOutControlAnchor();
  766. GE_CHECK_NOTNULL_JUST_RETURN(out_ctrl_anchor);
  767. if (check_flag && (iter->second.count(in_node->GetName()) > 0)) {
  768. for (const auto &out_node : old_node->GetOutAllNodes()) {
  769. GE_IF_BOOL_EXEC(!out_ctrl_anchor->IsLinkedWith(out_node->GetInControlAnchor()), {
  770. GE_CHK_STATUS(GraphUtils::AddEdge(out_ctrl_anchor, out_node->GetInControlAnchor()),
  771. "Add in ctrl edge failed.");
  772. });
  773. }
  774. } else {
  775. GE_IF_BOOL_EXEC(!out_ctrl_anchor->IsLinkedWith(new_node->GetInControlAnchor()), {
  776. GE_CHK_STATUS(GraphUtils::AddEdge(out_ctrl_anchor, new_node->GetInControlAnchor()),
  777. "Add in ctrl edge failed.");
  778. });
  779. }
  780. GE_CHK_STATUS(GraphUtils::RemoveEdge(out_ctrl_anchor, old_node->GetInControlAnchor()),
  781. "Remove in ctrl edge failed.");
  782. }
  783. for (const NodePtr &out_node : old_node->GetOutControlNodes()) {
  784. GE_IF_BOOL_EXEC(!new_node->GetOutControlAnchor()->IsLinkedWith(out_node->GetInControlAnchor()), {
  785. GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), out_node->GetInControlAnchor()),
  786. "Add out ctrl edge failed.");
  787. });
  788. GE_CHK_STATUS(GraphUtils::RemoveEdge(old_node->GetOutControlAnchor(), out_node->GetInControlAnchor()),
  789. "Remove out ctrl edge failed.");
  790. }
  791. }
  792. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示：