You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

switch_to_stream_switch_pass.cc 46 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/switch_to_stream_switch_pass.h"
  17. #include <stack>
  18. #include "common/ge/ge_util.h"
  19. #include "ge/ge_api_types.h"
  20. #include "graph/common/omg_util.h"
  21. #include "graph/ge_context.h"
  22. #include "graph/utils/type_utils.h"
  23. namespace ge {
  24. Status SwitchToStreamSwitchPass::Run(ComputeGraphPtr graph) {
  25. GELOGD("SwitchToStreamSwitchPass Enter");
  26. GE_CHK_STATUS_RET(CheckCycleDependence(graph),
  27. "[Check][CycleDependence] in graph:%s failed.", graph->GetName().c_str());
  28. for (const auto &switch_node : switch_nodes_) {
  29. GE_CHK_STATUS_RET(ReplaceSwitchNode(graph, switch_node),
  30. "[Replace][Switch] by StreamSwitch in graph:%s failed.", graph->GetName().c_str());
  31. }
  32. GE_CHK_STATUS_RET(CombineSwitchNode(graph),
  33. "[Combine][SwitchNode] in graph:%s failed.", graph->GetName().c_str());
  34. for (const auto &node : bypass_nodes_) {
  35. GE_CHK_BOOL_EXEC(graph->IsolateNode(node) == GRAPH_SUCCESS,
  36. REPORT_CALL_ERROR("E19999", "Isolate node:%s(%s) in graph:%s failed",
  37. node->GetName().c_str(), node->GetType().c_str(), graph->GetName().c_str());
  38. return FAILED,
  39. "[Isolate][Node] %s(%s) in graph:%s failed.",
  40. node->GetName().c_str(), node->GetType().c_str(), graph->GetName().c_str());
  41. GE_CHK_BOOL_EXEC(GraphUtils::RemoveNodeWithoutRelink(graph, node) == GRAPH_SUCCESS,
  42. REPORT_CALL_ERROR("E19999", "Remove node:%s(%s) without relink in graph:%s failed",
  43. node->GetName().c_str(), node->GetType().c_str(), graph->GetName().c_str());
  44. return FAILED,
  45. "[Remove][Node] %s(%s) without relink in graph:%s failed",
  46. node->GetName().c_str(), node->GetType().c_str(), graph->GetName().c_str());
  47. }
  48. GELOGD("SwitchToStreamSwitchPass Leave");
  49. return SUCCESS;
  50. }
  51. ///
  52. /// @brief Clear Status
  53. /// @return
  54. ///
  55. Status SwitchToStreamSwitchPass::ClearStatus() {
  56. switch_nodes_.clear();
  57. switch_cyclic_map_.clear();
  58. bypass_nodes_.clear();
  59. stream_switch_nodes_.clear();
  60. cond_node_map_.clear();
  61. switch_node_map_.clear();
  62. node_num_map_.clear();
  63. return SUCCESS;
  64. }
  65. ///
  66. /// @brief Check cyclic dependence
  67. /// @param [in] graph
  68. /// @return Status
  69. ///
  70. Status SwitchToStreamSwitchPass::CheckCycleDependence(const ComputeGraphPtr &graph) {
  71. std::string type;
  72. std::unordered_map<NodePtr, std::vector<NodePtr>> cond_switch_map;
  73. for (const NodePtr &node : graph->GetDirectNode()) {
  74. GE_CHK_STATUS_RET(GetOriginalType(node, type),
  75. "[Get][OriginalType] failed, graph:%s.", graph->GetName().c_str());
  76. if ((type != SWITCH) && (type != REFSWITCH)) {
  77. continue;
  78. }
  79. InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT);
  80. GE_CHECK_NOTNULL(in_cond_anchor);
  81. OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor();
  82. GE_CHECK_NOTNULL(peer_out_anchor);
  83. if (FindSwitchCondInput(peer_out_anchor) != SUCCESS) {
  84. GELOGE(FAILED, "[Find][PredInput] for switch_node %s failed.", node->GetName().c_str());
  85. return FAILED;
  86. }
  87. NodePtr cond_node = peer_out_anchor->GetOwnerNode();
  88. auto iter = cond_switch_map.find(cond_node);
  89. if (iter == cond_switch_map.end()) {
  90. cond_switch_map[cond_node] = { node };
  91. } else {
  92. iter->second.emplace_back(node);
  93. }
  94. switch_nodes_.emplace_back(node);
  95. }
  96. MarkCycleDependence(cond_switch_map);
  97. return SUCCESS;
  98. }
  99. ///
  100. /// @brief Mark cyclic dependence
  101. /// @param [in] graph
  102. /// @param [in] cond_switch_map
  103. /// @return void
  104. ///
  105. void SwitchToStreamSwitchPass::MarkCycleDependence(
  106. const std::unordered_map<NodePtr, std::vector<NodePtr>> &cond_switch_map) {
  107. std::stack<NodePtr> out_nodes;
  108. NodePtr tmp_node = nullptr;
  109. std::unordered_set<NodePtr> visited;
  110. for (const auto &iter : cond_switch_map) {
  111. std::set<NodePtr> switch_nodes(iter.second.begin(), iter.second.end());
  112. for (const auto &switch_node : switch_nodes) {
  113. GELOGD("MarkCycleDependence: cond_node=%s, switch=%s.", iter.first->GetName().c_str(),
  114. switch_node->GetName().c_str());
  115. for (const auto &node : switch_node->GetOutAllNodes()) {
  116. out_nodes.push(node);
  117. }
  118. }
  119. visited.clear();
  120. while (!out_nodes.empty()) {
  121. tmp_node = out_nodes.top();
  122. out_nodes.pop();
  123. if (visited.count(tmp_node) > 0) {
  124. continue;
  125. }
  126. for (const NodePtr &out_node : tmp_node->GetOutAllNodes()) {
  127. if (switch_nodes.find(out_node) == switch_nodes.end()) {
  128. out_nodes.push(out_node);
  129. continue;
  130. }
  131. GELOGD("MarkCycleDependence: tmp_node=%s, switch_node=%s.",
  132. tmp_node->GetName().c_str(), out_node->GetName().c_str());
  133. GE_IF_BOOL_EXEC(SetCyclicDependenceFlag(out_node) != SUCCESS,
  134. GELOGW("set cyclic dependence attr failed."); return );
  135. auto map_iter = switch_cyclic_map_.find(out_node);
  136. if (map_iter == switch_cyclic_map_.end()) {
  137. switch_cyclic_map_[out_node] = {tmp_node->GetName()};
  138. } else {
  139. map_iter->second.insert(tmp_node->GetName());
  140. }
  141. }
  142. visited.insert(tmp_node);
  143. }
  144. }
  145. return;
  146. }
  147. ///
  148. /// @brief Replace Switch Op
  149. /// @param [in] graph
  150. /// @param [in] switch_node
  151. /// @return Status
  152. ///
  153. Status SwitchToStreamSwitchPass::ReplaceSwitchNode(const ComputeGraphPtr &graph, const NodePtr &switch_node) {
  154. OutDataAnchorPtr peer_data_anchor = nullptr;
  155. OutDataAnchorPtr peer_cond_anchor = nullptr;
  156. GE_CHK_BOOL_EXEC(BypassSwitchNode(switch_node, peer_data_anchor, peer_cond_anchor) == SUCCESS, return FAILED,
  157. "[Bypass][SwitchNode] %s failed.", switch_node->GetName().c_str());
  158. GE_CHECK_NOTNULL(peer_data_anchor);
  159. GE_CHECK_NOTNULL(peer_cond_anchor);
  160. OpDescPtr cond_desc = peer_cond_anchor->GetOwnerNode()->GetOpDesc();
  161. GE_CHECK_NOTNULL(cond_desc);
  162. DataType cond_data_type = cond_desc->GetOutputDesc(peer_cond_anchor->GetIdx()).GetDataType();
  163. GE_CHK_BOOL_EXEC(cond_data_type == DT_BOOL,
  164. REPORT_INNER_ERROR("E19999", "Pred_input of Switch node:%s(%s) only support DT_BOOL data_type, "
  165. "but %s exactly", switch_node->GetName().c_str(), switch_node->GetType().c_str(),
  166. TypeUtils::DataTypeToSerialString(cond_data_type).c_str());
  167. return FAILED,
  168. "[Check][Param] Pred_input of Switch node:%s(%s) only support DT_BOOL data_type, but %s exactly",
  169. switch_node->GetName().c_str(), switch_node->GetType().c_str(),
  170. TypeUtils::DataTypeToSerialString(cond_data_type).c_str());
  171. OpDescPtr switch_desc = switch_node->GetOpDesc();
  172. GE_CHECK_NOTNULL(switch_desc);
  173. bool cyclic_flag = switch_desc->HasAttr(ATTR_NAME_CYCLIC_DEPENDENCE_FLAG);
  174. std::set<std::string> out_node_list;
  175. for (const auto &out_data_anchor : switch_node->GetAllOutDataAnchors()) {
  176. bool true_branch_flag = (static_cast<uint32_t>(out_data_anchor->GetIdx()) == SWITCH_TRUE_OUTPUT);
  177. NodePtr stream_switch = nullptr;
  178. out_node_list.clear();
  179. for (const auto &peer_in_anchor : out_data_anchor->GetPeerAnchors()) {
  180. GE_IF_BOOL_EXEC(stream_switch == nullptr, {
  181. stream_switch = CreateStreamSwitchNode(graph, switch_node, true_branch_flag ? "_t" : "_f", peer_cond_anchor);
  182. GE_CHK_BOOL_EXEC(stream_switch != nullptr, return FAILED,
  183. "[Create][StreamSwitchNode] for switch node:%s in graph:%s failed.",
  184. switch_node->GetName().c_str(), graph->GetName().c_str());
  185. if (SetSwitchTrueBranchFlag(stream_switch, true_branch_flag) != SUCCESS) {
  186. REPORT_CALL_ERROR("E19999", "Set switch true branch flag from node:%s(%s) failed",
  187. stream_switch->GetName().c_str(), stream_switch->GetType().c_str());
  188. GELOGE(FAILED, "[Set][SwitchTrueBranchFlag] for node %s failed.", stream_switch->GetName().c_str());
  189. return FAILED;
  190. }
  191. if (MarkBranches(peer_cond_anchor, stream_switch, true_branch_flag) != SUCCESS) {
  192. GELOGE(FAILED, "[Mark][Branches] for stream_switch %s failed.", stream_switch->GetName().c_str());
  193. return FAILED;
  194. }
  195. if (!cyclic_flag) {
  196. GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor->GetOwnerNode()->GetOutControlAnchor(),
  197. stream_switch->GetInControlAnchor()),
  198. "[Add][ControlEdge] between %s and %s failed.",
  199. peer_data_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str());
  200. }
  201. });
  202. GE_CHK_STATUS(GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor),
  203. "[Remove][Edge] between %s and %s failed.",
  204. switch_node->GetName().c_str(), peer_in_anchor->GetOwnerNode()->GetName().c_str());
  205. NodePtr out_node = peer_in_anchor->GetOwnerNode();
  206. GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor, peer_in_anchor),
  207. "[Add][Edge] between %s and %s failed.",
  208. peer_data_anchor->GetOwnerNode()->GetName().c_str(), out_node->GetName().c_str());
  209. GE_CHK_STATUS(GraphUtils::AddEdge(stream_switch->GetOutControlAnchor(), out_node->GetInControlAnchor()),
  210. "[Add][ControlEdge] between %s and %s failed.",
  211. stream_switch->GetName().c_str(), out_node->GetName().c_str());
  212. out_node_list.insert(out_node->GetName());
  213. }
  214. GE_IF_BOOL_EXEC(stream_switch != nullptr, {
  215. MoveCtrlEdges(switch_node, stream_switch);
  216. switch_node_map_[stream_switch] = out_node_list;
  217. if (SetOriginalNodeName(stream_switch, switch_node->GetName()) != SUCCESS) {
  218. REPORT_CALL_ERROR("E19999", "Set original node name:%s to node:%s(%s) failed", switch_node->GetName().c_str(),
  219. stream_switch->GetName().c_str(), stream_switch->GetType().c_str());
  220. GELOGE(FAILED, "[Set][OriginalNodeName] for node %s failed.", stream_switch->GetName().c_str());
  221. return FAILED;
  222. }
  223. });
  224. }
  225. (void)bypass_nodes_.insert(switch_node);
  226. return SUCCESS;
  227. }
  228. ///
  229. /// @brief Bypass Switch Node
  230. /// @param [in] switch_node
  231. /// @param [out] peer_data_anchor
  232. /// @param [out] peer_cond_anchor
  233. /// @return Status
  234. ///
  235. Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor,
  236. OutDataAnchorPtr &peer_cond_anchor) {
  237. for (uint32_t idx = 0; idx < SWITCH_INPUT_NUM; ++idx) {
  238. InDataAnchorPtr in_data_anchor = switch_node->GetInDataAnchor(idx);
  239. GE_CHECK_NOTNULL(in_data_anchor);
  240. OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
  241. GE_CHECK_NOTNULL(peer_out_anchor);
  242. // Remove Switch data input.
  243. if (GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor) != GRAPH_SUCCESS) {
  244. REPORT_CALL_ERROR("E19999", "Remove edge between op:%s(%s)(index:%d) and op:%s(%s)(index:%u) failed",
  245. peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  246. peer_out_anchor->GetOwnerNode()->GetType().c_str(), peer_out_anchor->GetIdx(),
  247. switch_node->GetName().c_str(), switch_node->GetType().c_str(), idx);
  248. GELOGE(FAILED, "[Remove][Edge] between op:%s(%s)(index:%d) and op:%s(%s)(index:%u) failed",
  249. peer_out_anchor->GetOwnerNode()->GetName().c_str(),
  250. peer_out_anchor->GetOwnerNode()->GetType().c_str(), peer_out_anchor->GetIdx(),
  251. switch_node->GetName().c_str(), switch_node->GetType().c_str(), idx);
  252. return FAILED;
  253. }
  254. if (idx == SWITCH_DATA_INPUT) {
  255. peer_data_anchor = peer_out_anchor;
  256. } else {
  257. peer_cond_anchor = peer_out_anchor;
  258. }
  259. }
  260. return SUCCESS;
  261. }
  262. ///
  263. /// @brief Find Switch cond input
  264. /// @param [out] peer_cond_anchor
  265. /// @return Status
  266. ///
  267. Status SwitchToStreamSwitchPass::FindSwitchCondInput(OutDataAnchorPtr &peer_cond_anchor) {
  268. NodePtr tmp_node = nullptr;
  269. std::string type;
  270. bool pass_flag = true;
  271. while (pass_flag) {
  272. if (tmp_node == nullptr) {
  273. tmp_node = peer_cond_anchor->GetOwnerNode();
  274. } else {
  275. InDataAnchorPtr in_data_anchor = tmp_node->GetInDataAnchor(SWITCH_DATA_INPUT);
  276. GE_CHECK_NOTNULL(in_data_anchor);
  277. peer_cond_anchor = in_data_anchor->GetPeerOutAnchor();
  278. GE_CHECK_NOTNULL(peer_cond_anchor);
  279. tmp_node = peer_cond_anchor->GetOwnerNode();
  280. }
  281. GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "[Get][OriginalType] failed.");
  282. pass_flag = ((type == SWITCH) || (type == REFSWITCH));
  283. }
  284. return SUCCESS;
  285. }
  286. ///
  287. /// @brief Create StreamSwitch Node
  288. /// @param [in] graph
  289. /// @param [in] switch_node
  290. /// @param [in] suffix
  291. /// @param [in] peer_cond_anchor
  292. /// @return ge::NodePtr
  293. ///
  294. NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr &graph, const NodePtr &switch_node,
  295. const std::string &suffix,
  296. const OutDataAnchorPtr &peer_cond_anchor) {
  297. OpDescPtr switch_op_desc = switch_node->GetOpDesc();
  298. GE_CHK_BOOL_EXEC(switch_op_desc != nullptr,
  299. REPORT_INNER_ERROR("E19999", "OpDesc in node is nullptr, check invalid");
  300. return nullptr, "[Get][OpDesc] failed, OpDesc of Switch node is invalid.");
  301. GE_IF_BOOL_EXEC(switch_op_desc->GetInputsSize() != SWITCH_INPUT_NUM, {
  302. REPORT_INNER_ERROR("E19999", "Input desc size:%zu of node:%s(%s) not equal to %u, check invalid",
  303. switch_op_desc->GetInputsSize(),
  304. switch_op_desc->GetName().c_str(), switch_op_desc->GetType().c_str(), SWITCH_INPUT_NUM);
  305. GELOGE(FAILED, "[Check][Param] Switch input param invalid, input_size=%lu, should be %u.",
  306. switch_op_desc->GetInputsSize(), SWITCH_INPUT_NUM);
  307. return nullptr;
  308. });
  309. const std::string &node_name = switch_node->GetName() + "_" + STREAMSWITCH + suffix;
  310. GELOGI("Create StreamSwitch, name=%s.", node_name.c_str());
  311. OpDescPtr op_desc = MakeShared<OpDesc>(node_name, STREAMSWITCH);
  312. if (op_desc == nullptr) {
  313. REPORT_CALL_ERROR("E19999", "New OpDesc failed");
  314. GELOGE(FAILED, "[New][OpDesc] failed.");
  315. return nullptr;
  316. }
  317. // mark hccl group id
  318. std::string hccl_group_id;
  319. if (AttrUtils::GetStr(switch_node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) {
  320. (void)AttrUtils::SetStr(op_desc, ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id);
  321. GELOGD("Set attr ATTR_NAME_HCCL_FUSED_GROUP for Stream_Switch %s, value is %s.", node_name.c_str(),
  322. hccl_group_id.c_str());
  323. }
  324. int64_t switch_type;
  325. if (AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_STREAM_SWITCH_TYPE, switch_type)) {
  326. (void)AttrUtils::SetInt(op_desc, ATTR_NAME_STREAM_SWITCH_TYPE, switch_type);
  327. GELOGD("Set attr ATTR_NAME_STREAM_SWITCH_TYPE for Stream_Switch %s, value is %ld.", node_name.c_str(),
  328. switch_type);
  329. }
  330. if (!AttrUtils::SetInt(op_desc, ATTR_NAME_SWITCH_DATA_TYPE, RT_SWITCH_INT32) ||
  331. !AttrUtils::SetInt(op_desc, ATTR_NAME_STREAM_SWITCH_COND, (int64_t)RT_EQUAL)) {
  332. REPORT_CALL_ERROR("E19999", "Set Attr:%s or Attr:%s to op:%s(%s) failed",
  333. ATTR_NAME_SWITCH_DATA_TYPE.c_str(), ATTR_NAME_STREAM_SWITCH_COND.c_str(),
  334. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  335. GELOGE(INTERNAL_ERROR, "[Set][Attr] %s or Attr:%s to op:%s(%s) failed",
  336. ATTR_NAME_SWITCH_DATA_TYPE.c_str(), ATTR_NAME_STREAM_SWITCH_COND.c_str(),
  337. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  338. return nullptr;
  339. }
  340. // Already checked, first input is Variable will passed, second is condition will checked.
  341. GeTensorDesc cond_input_desc = switch_op_desc->GetInputDesc(SWITCH_PRED_INPUT);
  342. GeTensorDesc input_desc(GeShape(cond_input_desc.GetShape().GetDims()), cond_input_desc.GetFormat(), DT_INT32);
  343. GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(input_desc) == GRAPH_SUCCESS,
  344. REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
  345. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  346. return nullptr,
  347. "[Add][InputDesc] to op:%s(%s) failed",
  348. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  349. GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(input_desc) == GRAPH_SUCCESS,
  350. REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
  351. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  352. return nullptr,
  353. "[Add][InputDesc] to op:%s(%s) failed",
  354. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  355. NodePtr stream_switch = graph->AddNode(op_desc);
  356. GE_CHK_BOOL_EXEC(stream_switch != nullptr,
  357. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
  358. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  359. return nullptr,
  360. "[Add][Node] %s(%s) to graph:%s failed",
  361. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  362. GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)),
  363. "[Add][Edge] between %s and %s failed.",
  364. peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str());
  365. int64_t group_index = -1;
  366. if (AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) {
  367. SetControlFlowGroup(stream_switch, group_index);
  368. }
  369. return stream_switch;
  370. }
  371. ///
  372. /// @brief Mark Switch Branch
  373. /// @param [in] peer_cond_anchor
  374. /// @param [in] stream_switch
  375. /// @param [in] true_branch_flag
  376. /// @return Status
  377. ///
  378. Status SwitchToStreamSwitchPass::MarkBranches(const OutDataAnchorPtr &peer_cond_anchor, const NodePtr &stream_switch,
  379. bool true_branch_flag) {
  380. uint32_t index = true_branch_flag ? SWITCH_TRUE_OUTPUT : SWITCH_FALSE_OUTPUT;
  381. auto it = cond_node_map_.find(peer_cond_anchor);
  382. if (it != cond_node_map_.end()) {
  383. int64_t switch_group_id = GetGroupId(stream_switch);
  384. auto switch_group_it = it->second.find(switch_group_id);
  385. if (switch_group_it == it->second.end()) {
  386. std::list<NodePtr> false_node_list;
  387. std::list<NodePtr> true_node_list;
  388. std::list<NodePtr> &node_list = true_branch_flag ? true_node_list : false_node_list;
  389. node_list.emplace_back(stream_switch);
  390. std::vector<std::list<NodePtr>> switch_list;
  391. switch_list.emplace_back(false_node_list);
  392. switch_list.emplace_back(true_node_list);
  393. it->second[switch_group_id] = switch_list;
  394. } else {
  395. GE_IF_BOOL_EXEC(switch_group_it->second.size() != SWITCH_OUTPUT_NUM, {
  396. REPORT_INNER_ERROR("E19999", "switch group size:%zu not equal to %u, group_id:%ld, check invalid",
  397. switch_group_it->second.size(), SWITCH_OUTPUT_NUM, switch_group_id);
  398. GELOGE(INTERNAL_ERROR, "[Check][Param] switch group size:%zu not equal to %u, group_id:%ld",
  399. switch_group_it->second.size(), SWITCH_OUTPUT_NUM, switch_group_id);
  400. return FAILED;
  401. });
  402. switch_group_it->second[index].emplace_back(stream_switch);
  403. }
  404. } else {
  405. int64_t switch_group_id = GetGroupId(stream_switch);
  406. std::map<int64_t, std::vector<std::list<NodePtr>>> switch_group_map;
  407. std::list<NodePtr> false_node_list;
  408. std::list<NodePtr> true_node_list;
  409. std::list<NodePtr> &node_list = true_branch_flag ? true_node_list : false_node_list;
  410. node_list.emplace_back(stream_switch);
  411. std::vector<std::list<NodePtr>> switch_list;
  412. switch_list.emplace_back(false_node_list);
  413. switch_list.emplace_back(true_node_list);
  414. switch_group_map[switch_group_id] = switch_list;
  415. cond_node_map_[peer_cond_anchor] = switch_group_map;
  416. }
  417. return SUCCESS;
  418. }
  419. ///
  420. /// @brief Get group_id for switch_node
  421. /// @param [in] node
  422. /// @return group_id
  423. ///
  424. int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) {
  425. std::string tailing_optimization_option;
  426. bool is_tailing_optimization = false;
  427. if (GetContext().GetOption(OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, tailing_optimization_option) == GRAPH_SUCCESS) {
  428. // "1" means it's True from frontend option
  429. is_tailing_optimization = (tailing_optimization_option == "1");
  430. GELOGI("Option ge.exec.isTailingOptimization is %s", tailing_optimization_option.c_str());
  431. }
  432. if (!is_tailing_optimization) {
  433. return 0;
  434. }
  435. std::string hccl_group_id;
  436. if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) {
  437. GELOGI("Node %s can not find hccl group id.", node->GetName().c_str());
  438. return 0;
  439. }
  440. auto key_index = hccl_group_id.find_last_of('_');
  441. auto key_num = hccl_group_id.substr(key_index + 1, hccl_group_id.length() - key_index);
  442. GELOGI("Node:%s, hccl_group_id=%s, key_num=%s", node->GetName().c_str(), hccl_group_id.c_str(), key_num.c_str());
  443. int64_t num = atoi(key_num.c_str());
  444. if (num == 0) {
  445. return 0;
  446. }
  447. GELOGI("Hccl_group_id is %s, group_id is %ld", hccl_group_id.c_str(), num);
  448. return num;
  449. }
  450. ///
  451. /// @brief Combine switch nodes link to same cond
  452. /// @param [in] graph
  453. /// @return Status
  454. ///
  455. Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) {
  456. for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) {
  457. for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) {
  458. const std::list<NodePtr> &false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT];
  459. const std::list<NodePtr> &true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT];
  460. std::set<NodePtr> same_cond_switch;
  461. same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end());
  462. same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end());
  463. OutDataAnchorPtr peer_cond_anchor = iter->first;
  464. GE_CHECK_NOTNULL(peer_cond_anchor);
  465. NodePtr cond_node = peer_cond_anchor->GetOwnerNode();
  466. GELOGI("CombineSwitchNode: cond_node=%s.", cond_node->GetName().c_str());
  467. NodePtr cast_node = CreateCastOp(graph, peer_cond_anchor);
  468. GE_CHK_BOOL_EXEC(cast_node != nullptr, return FAILED,
  469. "[Create][CastOp] for cond_node:%s failed.", cond_node->GetName().c_str());
  470. NodePtr active_node = CreateActiveNode(graph, cond_node);
  471. GE_CHK_BOOL_EXEC(active_node != nullptr, return FAILED,
  472. "[Create][StreamActiveNode] for cond node:%s failed.", cond_node->GetName().c_str());
  473. GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutControlAnchor(), active_node->GetInControlAnchor()),
  474. "[Add][ControlEdge] between %s and %s failed.",
  475. cond_node->GetName().c_str(), active_node->GetName().c_str());
  476. if (SetActiveLabelList(active_node, { cast_node->GetName() }) != SUCCESS) {
  477. REPORT_CALL_ERROR("E19999", "Set active label list:%s to op:%s(%s) failed",
  478. cast_node->GetName().c_str(), active_node->GetName().c_str(), active_node->GetType().c_str());
  479. GELOGE(FAILED, "[Set][ActiveLabelList] %s to op:%s(%s) failed.",
  480. cast_node->GetName().c_str(), active_node->GetName().c_str(), active_node->GetType().c_str());
  481. return FAILED;
  482. }
  483. int64_t group_index = -1;
  484. std::function<bool(const NodePtr &)> callback = [&group_index](const NodePtr &n) {
  485. return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index);
  486. };
  487. (void)std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback);
  488. SetControlFlowGroup(active_node, group_index);
  489. const std::string &cond_group = cond_node->GetName();
  490. for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) {
  491. bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT);
  492. const std::list<NodePtr> &switch_list = (true_branch_flag ? true_switch_list : false_switch_list);
  493. GE_IF_BOOL_EXEC(switch_list.empty(), continue);
  494. // select first stream_switch
  495. NodePtr stream_switch = switch_list.front();
  496. OpDescPtr switch_desc = stream_switch->GetOpDesc();
  497. GE_CHECK_NOTNULL(switch_desc);
  498. // set stream_label
  499. if (SetStreamLabel(stream_switch, cast_node->GetName()) != SUCCESS) {
  500. REPORT_CALL_ERROR("E19999", "Set stream_label:%s to op:%s(%s) failed",
  501. cast_node->GetName().c_str(), stream_switch->GetName().c_str(),
  502. stream_switch->GetType().c_str());
  503. GELOGE(FAILED, "[Set][StreamLabel] %s to op:%s(%s) failed", cast_node->GetName().c_str(),
  504. stream_switch->GetName().c_str(), stream_switch->GetType().c_str());
  505. return FAILED;
  506. }
  507. switch_desc->SetName(CheckDuplicateName(cond_group + "/" + STREAMSWITCH + (true_branch_flag ? "_t" : "_f")));
  508. stream_switch_nodes_.emplace_back(stream_switch);
  509. // 0_input: original pred input, 1_input: constant node
  510. GE_CHK_STATUS_RET(AddConstNode(graph, stream_switch),
  511. "[Add][ConstNode] failed, stream_switch:%s.", stream_switch->GetName().c_str());
  512. GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)),
  513. "[Remove][Edge] between %s and %s failed.",
  514. peer_cond_anchor->GetOwnerNode()->GetName().c_str(), stream_switch->GetName().c_str());
  515. GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0)),
  516. "[Add][Edge] between %s and %s failed.",
  517. cast_node->GetName().c_str(), stream_switch->GetName().c_str());
  518. SetControlFlowGroup(stream_switch, group_index);
  519. for (const NodePtr &node : switch_list) {
  520. GE_IF_BOOL_EXEC(node != stream_switch, {
  521. GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)),
  522. "[Remove][Edge] between %s and %s failed.",
  523. peer_cond_anchor->GetOwnerNode()->GetName().c_str(), node->GetName().c_str());
  524. });
  525. GE_CHK_STATUS(ModifySwitchInCtlEdges(node, cast_node, same_cond_switch),
  526. "[Modify][SwitchInCtlEdges] failed, switch node:%s, cast node:%s.",
  527. node->GetName().c_str(), cast_node->GetName().c_str());
  528. GE_CHK_STATUS(ModifySwitchOutCtlEdges(node, stream_switch, active_node),
  529. "[Modify][SwitchOutCtlEdges] failed, node:%s, stream_switch:%s.",
  530. node->GetName().c_str(), stream_switch->GetName().c_str());
  531. }
  532. GE_CHK_STATUS(GraphUtils::AddEdge(active_node->GetOutControlAnchor(), stream_switch->GetInControlAnchor()),
  533. "[Add][ControlEdge] between %s and %s failed.",
  534. active_node->GetName().c_str(), stream_switch->GetName().c_str());
  535. }
  536. }
  537. }
  538. return SUCCESS;
  539. }
  540. ///
  541. /// @brief Create Active Op
  542. /// @param [in] graph
  543. /// @param [in] cond_node
  544. /// @return ge::NodePtr
  545. ///
  546. NodePtr SwitchToStreamSwitchPass::CreateActiveNode(const ComputeGraphPtr &graph, const NodePtr &node) {
  547. const std::string &node_name = CheckDuplicateName(node->GetName() + "_" + STREAMACTIVE);
  548. GELOGI("Create StreamActive op:%s.", node_name.c_str());
  549. OpDescPtr op_desc = MakeShared<OpDesc>(node_name, STREAMACTIVE);
  550. if (op_desc == nullptr) {
  551. REPORT_CALL_ERROR("E19999", "New OpDesc failed");
  552. GELOGE(FAILED, "[New][OpDesc] failed.");
  553. return nullptr;
  554. }
  555. NodePtr active_node = graph->AddNode(op_desc);
  556. GE_CHK_BOOL_EXEC(active_node != nullptr,
  557. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
  558. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  559. return nullptr,
  560. "[Add][Node] %s(%s) to graph:%s failed",
  561. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  562. GE_IF_BOOL_EXEC(SetSwitchBranchNodeLabel(active_node, node_name) != SUCCESS,
  563. REPORT_CALL_ERROR("E19999", "Set switch branch node label:%s to node:%s(%s) failed",
  564. node_name.c_str(), active_node->GetName().c_str(), active_node->GetType().c_str());
  565. GELOGE(INTERNAL_ERROR, "[Set][SwitchBranchNodeLabel] %s to node:%s(%s) failed",
  566. node_name.c_str(), active_node->GetName().c_str(), active_node->GetType().c_str());
  567. return nullptr);
  568. return active_node;
  569. }
  570. ///
  571. /// @brief Create cast node
  572. /// @param [in] graph
  573. /// @param [in] peer_cond_anchor
  574. /// @return NodePtr
  575. ///
  576. NodePtr SwitchToStreamSwitchPass::CreateCastOp(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_cond_anchor) {
  577. OpDescPtr cond_desc = peer_cond_anchor->GetOwnerNode()->GetOpDesc();
  578. GE_CHK_BOOL_EXEC(cond_desc != nullptr, return nullptr,
  579. "[Get][OpDesc] failed, opdesc of Param peer_cond_anchor's owner node is nullptr.");
  580. const std::string &cast_name = CheckDuplicateName(cond_desc->GetName() + "_" + CAST);
  581. GELOGI("Create cast_node: %s, input datatype:DT_BOOL, out datatype:DT_INT32", cast_name.c_str());
  582. OpDescPtr cast_desc = MakeShared<OpDesc>(cast_name, CAST);
  583. if (cast_desc == nullptr) {
  584. REPORT_CALL_ERROR("E19999", "New OpDesc failed");
  585. GELOGE(FAILED, "[New][OpDesc] failed.");
  586. return nullptr;
  587. }
  588. if (!(AttrUtils::SetInt(cast_desc, CAST_ATTR_SRCT, (int64_t)DT_BOOL) &&
  589. AttrUtils::SetInt(cast_desc, CAST_ATTR_DSTT, (int64_t)DT_INT32) &&
  590. AttrUtils::SetInt(cast_desc, CAST_ATTR_DST_TYPE, (int64_t)DT_INT32) &&
  591. AttrUtils::SetBool(cast_desc, CAST_ATTR_TRUNCATE, false))) {
  592. REPORT_CALL_ERROR("E19999", "Set Attr:%s or %s or %s or %s to op:%s(%s) failed",
  593. CAST_ATTR_SRCT.c_str(), CAST_ATTR_DSTT.c_str(),
  594. CAST_ATTR_DST_TYPE.c_str(), CAST_ATTR_TRUNCATE.c_str(),
  595. cast_desc->GetName().c_str(), cast_desc->GetType().c_str());
  596. GELOGE(FAILED, "[Set][Attr] %s or %s or %s or %s to op:%s(%s) failed",
  597. CAST_ATTR_SRCT.c_str(), CAST_ATTR_DSTT.c_str(),
  598. CAST_ATTR_DST_TYPE.c_str(), CAST_ATTR_TRUNCATE.c_str(),
  599. cast_desc->GetName().c_str(), cast_desc->GetType().c_str());
  600. return nullptr;
  601. }
  602. GeTensorDesc tensor_desc = cond_desc->GetOutputDesc(peer_cond_anchor->GetIdx());
  603. tensor_desc.SetDataType(DT_BOOL);
  604. GE_CHK_BOOL_EXEC(cast_desc->AddInputDesc(tensor_desc) == SUCCESS,
  605. REPORT_CALL_ERROR("E19999", "Add input desc to op:%s(%s) failed",
  606. cast_desc->GetName().c_str(), cast_desc->GetType().c_str());
  607. return nullptr,
  608. "[Add][InputDesc] to op:%s(%s) failed",
  609. cast_desc->GetName().c_str(), cast_desc->GetType().c_str());
  610. tensor_desc.SetDataType(DT_INT32);
  611. GE_CHK_BOOL_EXEC(cast_desc->AddOutputDesc(tensor_desc) == SUCCESS,
  612. REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed",
  613. cast_desc->GetName().c_str(), cast_desc->GetType().c_str());
  614. return nullptr,
  615. "[Add][OutputDesc] to op:%s(%s) failed",
  616. cast_desc->GetName().c_str(), cast_desc->GetType().c_str());
  617. NodePtr cast_node = graph->AddNode(cast_desc);
  618. GE_CHK_BOOL_EXEC(cast_node != nullptr,
  619. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
  620. cast_desc->GetName().c_str(), cast_desc->GetType().c_str(),
  621. graph->GetName().c_str());
  622. return nullptr,
  623. "[Add][Node] %s(%s) to graph:%s failed",
  624. cast_desc->GetName().c_str(), cast_desc->GetType().c_str(),
  625. graph->GetName().c_str());
  626. // Cast node has and only has one input
  627. GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, cast_node->GetInDataAnchor(0)),
  628. "[Add][Edge] between %s and %s failed.",
  629. cond_desc->GetName().c_str(), cast_node->GetName().c_str());
  630. return cast_node;
  631. }
  632. ///
  633. /// @brief Add const node as switch input1
  634. /// @param [in] graph
  635. /// @param [in] stream_switch
  636. /// @return Status
  637. ///
  638. Status SwitchToStreamSwitchPass::AddConstNode(const ComputeGraphPtr &graph, const NodePtr &stream_switch) {
  639. OpDescPtr op_desc = stream_switch->GetOpDesc();
  640. GE_CHECK_NOTNULL(op_desc);
  641. bool value = false;
  642. GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value),
  643. REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed",
  644. ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG.c_str(),
  645. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  646. return FAILED,
  647. "[Get][Attr] %s from op:%s(%s) failed", ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG.c_str(),
  648. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  649. const std::string &const_node_name = op_desc->GetName() + "_Constant_" + (value ? "t" : "f");
  650. GELOGI("Create const op: %s", const_node_name.c_str());
  651. OpDescPtr const_op_desc = MakeShared<OpDesc>(const_node_name, CONSTANT);
  652. if (const_op_desc == nullptr) {
  653. REPORT_CALL_ERROR("E19999", "New OpDesc failed");
  654. GELOGE(FAILED, "New OpDesc failed.");
  655. return FAILED;
  656. }
  657. auto resize_value = (int32_t)value;
  658. GeTensorDesc data_desc = op_desc->GetInputDesc(1);
  659. GeTensorPtr const_value =
  660. MakeShared<GeTensor>(data_desc, reinterpret_cast<uint8_t *>(&resize_value), sizeof(int32_t));
  661. if (const_value == nullptr) {
  662. REPORT_CALL_ERROR("E19999", "New GeTensor failed");
  663. GELOGE(FAILED, "[New][GeTensor] failed.");
  664. return FAILED;
  665. }
  666. GE_CHK_BOOL_EXEC(AttrUtils::SetTensor(const_op_desc, ATTR_NAME_WEIGHTS, const_value),
  667. REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_NAME_WEIGHTS.c_str(),
  668. const_op_desc->GetName().c_str(), const_op_desc->GetType().c_str());
  669. return FAILED,
  670. "[Get][Attr] %s from op:%s(%s) failed", ATTR_NAME_WEIGHTS.c_str(),
  671. const_op_desc->GetName().c_str(), const_op_desc->GetType().c_str());
  672. GE_CHK_BOOL_EXEC(const_op_desc->AddOutputDesc(data_desc) == GRAPH_SUCCESS,
  673. REPORT_CALL_ERROR("E19999", "Add output desc to op:%s(%s) failed",
  674. const_op_desc->GetName().c_str(), const_op_desc->GetType().c_str());
  675. return FAILED,
  676. "[Add][OutputDesc] to op:%s(%s) failed",
  677. const_op_desc->GetName().c_str(), const_op_desc->GetType().c_str());
  678. NodePtr const_node = graph->AddNode(const_op_desc);
  679. GE_CHK_BOOL_EXEC(const_node != nullptr,
  680. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
  681. const_op_desc->GetName().c_str(), const_op_desc->GetType().c_str(),
  682. graph->GetName().c_str());
  683. return FAILED,
  684. "[Add][Node] %s(%s) to graph:%s failed",
  685. const_op_desc->GetName().c_str(), const_op_desc->GetType().c_str(),
  686. graph->GetName().c_str());
  687. GE_CHK_STATUS(GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(1)),
  688. "[Add][Edge] between %s and %s failed.",
  689. const_node->GetName().c_str(), stream_switch->GetName().c_str());
  690. return SUCCESS;
  691. }
  692. ///
  693. /// @brief Modify in ctl edge for switch_node
  694. /// @param [in] switch_node
  695. /// @param [in] cast_node
  696. /// @param [in] same_cond_switch
  697. /// @return Status
  698. ///
  699. Status SwitchToStreamSwitchPass::ModifySwitchInCtlEdges(const NodePtr &switch_node, const NodePtr &cast_node,
  700. const std::set<NodePtr> &same_cond_switch) {
  701. GELOGD("ModifySwitchInCtlEdges: switch_node=%s, cast_node=%s", switch_node->GetName().c_str(),
  702. cast_node->GetName().c_str());
  703. std::string orig_switch_name = switch_node->GetName();
  704. OpDescPtr switch_desc = switch_node->GetOpDesc();
  705. GE_CHECK_NOTNULL(switch_desc);
  706. if (!AttrUtils::GetStr(switch_desc, ATTR_NAME_ORIG_NODE_NAME, orig_switch_name) || orig_switch_name.empty()) {
  707. REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_NAME_ORIG_NODE_NAME.c_str(),
  708. switch_desc->GetName().c_str(), switch_desc->GetType().c_str());
  709. GELOGE(INTERNAL_ERROR, "[Get][Attr] %s from op:%s(%s) failed", ATTR_NAME_ORIG_NODE_NAME.c_str(),
  710. switch_desc->GetName().c_str(), switch_desc->GetType().c_str());
  711. return INTERNAL_ERROR;
  712. }
  713. for (const NodePtr &in_ctrl_node : switch_node->GetInControlNodes()) {
  714. GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()),
  715. "[Remove][ControlEdge] between %s and %s failed.",
  716. in_ctrl_node->GetName().c_str(), switch_node->GetName().c_str());
  717. GE_IF_BOOL_EXEC(!in_ctrl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), {
  718. GE_CHK_STATUS(GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()),
  719. "[Add][ControlEdge] between %s and %s failed.",
  720. in_ctrl_node->GetName().c_str(), cast_node->GetName().c_str());
  721. });
  722. GE_IF_BOOL_EXEC(in_ctrl_node->GetType() != STREAMSWITCH, continue);
  723. if (same_cond_switch.count(in_ctrl_node) > 0) {
  724. GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()),
  725. "[Remove][ControlEdge] between %s and %s failed.",
  726. in_ctrl_node->GetName().c_str(), cast_node->GetName().c_str());
  727. continue;
  728. }
  729. auto find_res1 = switch_node_map_.find(in_ctrl_node);
  730. GE_IF_BOOL_EXEC(find_res1 == switch_node_map_.end(), {
  731. REPORT_INNER_ERROR("E19999", "Node:%s(%s) can't find in switch_node_map_, check invalid",
  732. in_ctrl_node->GetName().c_str(), in_ctrl_node->GetType().c_str());
  733. GELOGE(INTERNAL_ERROR, "[Check][Param] StreamSwitch node %s not found in switch_node_map_.",
  734. in_ctrl_node->GetName().c_str());
  735. return INTERNAL_ERROR;
  736. });
  737. auto find_res2 = find_res1->second.find(orig_switch_name);
  738. auto find_res3 = find_res1->second.find(cast_node->GetName());
  739. GE_IF_BOOL_EXEC((find_res2 != find_res1->second.end()) && (find_res3 == find_res1->second.end()), {
  740. find_res1->second.erase(find_res2);
  741. find_res1->second.insert(cast_node->GetName());
  742. continue;
  743. });
  744. }
  745. return SUCCESS;
  746. }
  747. ///
  748. /// @brief Modify out ctl edge for switch_node
  749. /// @param [in] switch_node
  750. /// @param [in] stream_switch
  751. /// @param [in] active_node
  752. /// @return Status
  753. ///
  754. Status SwitchToStreamSwitchPass::ModifySwitchOutCtlEdges(const NodePtr &switch_node, const NodePtr &stream_switch,
  755. const NodePtr &active_node) {
  756. GELOGD("ModifySwitchOutCtlEdges: switch_node=%s, stream_switch=%s, active_node=%s", switch_node->GetName().c_str(),
  757. stream_switch->GetName().c_str(), active_node->GetName().c_str());
  758. auto find_res = switch_node_map_.find(switch_node);
  759. GE_IF_BOOL_EXEC(find_res == switch_node_map_.end(), {
  760. REPORT_INNER_ERROR("E19999", "Node:%s(%s) can't find in switch_node_map_, check invalid",
  761. switch_node->GetName().c_str(), switch_node->GetType().c_str());
  762. GELOGE(INTERNAL_ERROR, "[Check][Param] StreamSwitch node %s not found in switch_node_map_.",
  763. switch_node->GetName().c_str());
  764. return INTERNAL_ERROR;
  765. });
  766. GE_IF_BOOL_EXEC(find_res->second.empty(), {
  767. REPORT_INNER_ERROR("E19999", "True_nodes of StreamSwitch node:%s(%s) is empty, check invalid",
  768. switch_node->GetName().c_str(), switch_node->GetType().c_str());
  769. GELOGE(INTERNAL_ERROR, "[Check][Param] true_nodes of StreamSwitch node %s is empty.",
  770. switch_node->GetName().c_str());
  771. return INTERNAL_ERROR;
  772. });
  773. for (const NodePtr &node : switch_node->GetOutControlNodes()) {
  774. OpDescPtr op_desc = node->GetOpDesc();
  775. GE_CHECK_NOTNULL(op_desc);
  776. GE_CHK_STATUS(GraphUtils::RemoveEdge(switch_node->GetOutControlAnchor(), node->GetInControlAnchor()),
  777. "[Remove][ControlEdge] between %s and %s failed.",
  778. switch_node->GetName().c_str(), node->GetName().c_str());
  779. std::string orig_name = op_desc->GetName();
  780. GE_IF_BOOL_EXEC(op_desc->HasAttr(ATTR_NAME_ORIG_NODE_NAME), {
  781. if (!AttrUtils::GetStr(op_desc, ATTR_NAME_ORIG_NODE_NAME, orig_name) || orig_name.empty()) {
  782. REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_NAME_ORIG_NODE_NAME.c_str(),
  783. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  784. GELOGE(INTERNAL_ERROR, "[Get][Attr] %s from op:%s(%s) failed", ATTR_NAME_ORIG_NODE_NAME.c_str(),
  785. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  786. return INTERNAL_ERROR;
  787. }
  788. });
  789. if (find_res->second.find(orig_name) == find_res->second.end()) {
  790. auto active_out_ctrl_anchor = active_node->GetOutControlAnchor();
  791. GE_CHECK_NOTNULL(active_out_ctrl_anchor);
  792. GE_IF_BOOL_EXEC(!active_out_ctrl_anchor->IsLinkedWith(node->GetInControlAnchor()), {
  793. GE_CHK_STATUS(GraphUtils::AddEdge(active_out_ctrl_anchor, node->GetInControlAnchor()),
  794. "[Add][ControlEdge] between %s and %s failed.",
  795. active_node->GetName().c_str(), node->GetName().c_str());
  796. });
  797. } else {
  798. auto switch_out_ctrl_anchor = stream_switch->GetOutControlAnchor();
  799. GE_CHECK_NOTNULL(switch_out_ctrl_anchor);
  800. GE_IF_BOOL_EXEC(!switch_out_ctrl_anchor->IsLinkedWith(node->GetInControlAnchor()), {
  801. GE_CHK_STATUS(GraphUtils::AddEdge(switch_out_ctrl_anchor, node->GetInControlAnchor()),
  802. "[Add][ControlEdge] between %s and %s failed.",
  803. stream_switch->GetName().c_str(), node->GetName().c_str());
  804. });
  805. }
  806. }
  807. GE_IF_BOOL_EXEC(switch_node != stream_switch, (void)bypass_nodes_.insert(switch_node));
  808. return SUCCESS;
  809. }
  810. ///
  811. /// @brief Check duplicate node_name
  812. /// @param [in] node_name
  813. /// @return std::string
  814. ///
  815. std::string SwitchToStreamSwitchPass::CheckDuplicateName(const std::string &node_name) {
  816. std::string tmp_name = node_name;
  817. auto iter = node_num_map_.find(tmp_name);
  818. if (iter != node_num_map_.end()) {
  819. tmp_name = tmp_name + "_" + std::to_string(iter->second);
  820. (iter->second)++;
  821. } else {
  822. node_num_map_[tmp_name] = 1;
  823. }
  824. return tmp_name;
  825. }
  826. ///
  827. /// @brief Move Control Edges
  828. /// @param [in] old_node
  829. /// @param [in] new_node
  830. /// @return void
  831. ///
  832. void SwitchToStreamSwitchPass::MoveCtrlEdges(const NodePtr &old_node, const NodePtr &new_node) {
  833. GE_IF_BOOL_EXEC(old_node == new_node, return );
  834. auto iter = switch_cyclic_map_.find(old_node);
  835. bool check_flag = (iter != switch_cyclic_map_.end());
  836. for (const NodePtr &in_node : old_node->GetInControlNodes()) {
  837. auto out_ctrl_anchor = in_node->GetOutControlAnchor();
  838. GE_CHECK_NOTNULL_JUST_RETURN(out_ctrl_anchor);
  839. if (check_flag && (iter->second.count(in_node->GetName()) > 0)) {
  840. for (const auto &out_node : old_node->GetOutAllNodes()) {
  841. GE_IF_BOOL_EXEC(!out_ctrl_anchor->IsLinkedWith(out_node->GetInControlAnchor()), {
  842. GE_CHK_STATUS(GraphUtils::AddEdge(out_ctrl_anchor, out_node->GetInControlAnchor()),
  843. "[Add][ControlEdge] between %s and %s failed.",
  844. in_node->GetName().c_str(), out_node->GetName().c_str());
  845. });
  846. }
  847. } else {
  848. GE_IF_BOOL_EXEC(!out_ctrl_anchor->IsLinkedWith(new_node->GetInControlAnchor()), {
  849. GE_CHK_STATUS(GraphUtils::AddEdge(out_ctrl_anchor, new_node->GetInControlAnchor()),
  850. "[Add][ControlEdge] between %s and %s failed.",
  851. in_node->GetName().c_str(), new_node->GetName().c_str());
  852. });
  853. }
  854. GE_CHK_STATUS(GraphUtils::RemoveEdge(out_ctrl_anchor, old_node->GetInControlAnchor()),
  855. "[Remove][ControlEdge] between %s and %s failed.",
  856. in_node->GetName().c_str(), old_node->GetName().c_str());
  857. }
  858. for (const NodePtr &out_node : old_node->GetOutControlNodes()) {
  859. GE_IF_BOOL_EXEC(!new_node->GetOutControlAnchor()->IsLinkedWith(out_node->GetInControlAnchor()), {
  860. GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), out_node->GetInControlAnchor()),
  861. "[Add][ControlEdge] between %s and %s failed.",
  862. new_node->GetName().c_str(), out_node->GetName().c_str());
  863. });
  864. GE_CHK_STATUS(GraphUtils::RemoveEdge(old_node->GetOutControlAnchor(), out_node->GetInControlAnchor()),
  865. "[Remove][ControlEdge] between %s and %s failed.",
  866. old_node->GetName().c_str(), out_node->GetName().c_str());
  867. }
  868. }
  869. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起承接作用。图引擎模块以ME下发的图作为输入，进行一系列深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。