You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

control_trigger_pass.cc 22 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/control_trigger_pass.h"
  17. #include <stack>
  18. #include "common/ge/ge_util.h"
  19. #include "common/omg_util.h"
  20. #include "graph/utils/type_utils.h"
  21. namespace ge {
  22. Status ControlTriggerPass::Run(ComputeGraphPtr graph) {
  23. GELOGD("ControlTriggerPass Enter");
  24. for (NodePtr &node : graph->GetDirectNode()) {
  25. if (node->GetType() != CONTROLTRIGGER) {
  26. continue;
  27. }
  28. auto in_ctrl_nodes = node->GetInControlNodes();
  29. for (NodePtr &in_ctrl_node : in_ctrl_nodes) {
  30. if (HandleDynamicCtrlEdges(graph, node, in_ctrl_node) != SUCCESS) {
  31. GELOGE(FAILED, "[Handle][DynamicCtrlEdges] for node:%s->node:%s failed.", in_ctrl_node->GetName().c_str(),
  32. node->GetName().c_str());
  33. return FAILED;
  34. }
  35. }
  36. }
  37. GELOGD("ControlTriggerPass Leave");
  38. return SUCCESS;
  39. }
  40. ///
  41. /// @brief Handle input ctrl edges for ControlTrigger node
  42. /// @param [in] graph
  43. /// @param [in] node
  44. /// @param [in] in_ctrl_node
  45. /// @return Status
  46. ///
  47. Status ControlTriggerPass::HandleDynamicCtrlEdges(ComputeGraphPtr &graph, NodePtr &node, NodePtr &in_ctrl_node) {
  48. GE_CHECK_NOTNULL(node);
  49. GE_CHECK_NOTNULL(in_ctrl_node);
  50. GELOGI("HandleDynamicCtrlEdges: node=%s, in_ctrl_node=%s", node->GetName().c_str(), in_ctrl_node->GetName().c_str());
  51. NodePtr switch_node = nullptr;
  52. bool branch_flag = false;
  53. if (FindSwitchNode(in_ctrl_node, switch_node, branch_flag) != SUCCESS) {
  54. GELOGE(FAILED, "[Find][SwitchNode] failed, in_ctrl_node:%s.", in_ctrl_node->GetName().c_str());
  55. return FAILED;
  56. }
  57. if (switch_node == nullptr) {
  58. GELOGI("Not find valid switch node.");
  59. return SUCCESS;
  60. }
  61. auto iter1 = control_trigger_map_.find(node);
  62. if (iter1 != control_trigger_map_.end()) {
  63. auto iter2 = iter1->second.find(switch_cond_map_[switch_node]);
  64. if (iter2 != iter1->second.end()) {
  65. NodePtr constant = (branch_flag ? iter2->second.second : iter2->second.first);
  66. if ((GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), node->GetInControlAnchor()) != GRAPH_SUCCESS) ||
  67. (GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), constant->GetInControlAnchor()) != GRAPH_SUCCESS)) {
  68. REPORT_CALL_ERROR("E19999", "Remove control edge between op:%s(%s) and op:%s(%s), then "
  69. "add control edge between op:%s(%s) and op:%s(%s) failed",
  70. in_ctrl_node->GetName().c_str(), in_ctrl_node->GetType().c_str(),
  71. node->GetName().c_str(), node->GetType().c_str(),
  72. in_ctrl_node->GetName().c_str(), in_ctrl_node->GetType().c_str(),
  73. constant->GetName().c_str(), constant->GetType().c_str());
  74. GELOGE(FAILED, "[Replace][CtrlEdge] failed, remove edge:%s->%s, add edge:%s->%s.",
  75. in_ctrl_node->GetName().c_str(), node->GetName().c_str(),
  76. in_ctrl_node->GetName().c_str(), constant->GetName().c_str());
  77. return FAILED;
  78. }
  79. GELOGI("No need to insert new branch.");
  80. return SUCCESS;
  81. }
  82. }
  83. if (InsertOppositeBranch(graph, node, in_ctrl_node, switch_node, branch_flag) != SUCCESS) {
  84. GELOGE(FAILED, "[Insert][OppositeBranch] failed, node:%s, in_ctrl_node:%s.",
  85. node->GetName().c_str(), in_ctrl_node->GetName().c_str());
  86. return FAILED;
  87. }
  88. return SUCCESS;
  89. }
  90. ///
  91. /// @brief Find switch_node for ControlTrigger node
  92. /// @param [in] node
  93. /// @param [out] switch_node
  94. /// @param [out] branch_flag
  95. /// @return Status
  96. ///
  97. Status ControlTriggerPass::FindSwitchNode(const NodePtr &node, NodePtr &switch_node, bool &branch_flag) {
  98. std::set<std::pair<NodePtr, uint32_t>> handle_nodes;
  99. // {node, <idx, <cond_merge_num, loop_switchf_num>>}
  100. std::stack<std::pair<NodePtr, std::pair<uint32_t, std::pair<uint32_t, uint32_t>>>> nodes;
  101. nodes.push(std::make_pair(node, std::make_pair(UINT32_MAX, std::make_pair(0, 0))));
  102. std::set<std::pair<NodePtr, uint32_t>> in_nodes;
  103. while (!nodes.empty()) {
  104. auto iter = nodes.top();
  105. NodePtr tmp_node = iter.first;
  106. GE_CHECK_NOTNULL(tmp_node);
  107. nodes.pop();
  108. uint32_t index = iter.second.first;
  109. auto num_pair = iter.second.second;
  110. if (handle_nodes.count(std::make_pair(tmp_node, index)) > 0) {
  111. continue;
  112. }
  113. switch (TransferNodeType(tmp_node, index)) {
  114. case kCondSwitch:
  115. if (num_pair.first == 0) {
  116. switch_node = tmp_node;
  117. branch_flag = (index == SWITCH_TRUE_OUTPUT);
  118. GELOGI("FindSwitchNode succ, switch_node=%s, idx=%u", switch_node->GetName().c_str(), index);
  119. return SUCCESS;
  120. }
  121. num_pair.first--;
  122. break;
  123. case kCondMerge:
  124. num_pair.first++;
  125. break;
  126. case kLoopSwitchT:
  127. GELOGI("in while_body, no need handle");
  128. return SUCCESS;
  129. case kLoopSwitchF:
  130. num_pair.second++;
  131. break;
  132. case kEnter:
  133. if (num_pair.second > 0) {
  134. num_pair.second--;
  135. }
  136. break;
  137. case kNotControlOp:
  138. break;
  139. default:
  140. GELOGE(FAILED, "[Check][Param] invalid node type");
  141. return FAILED;
  142. }
  143. GetInNodes(tmp_node, in_nodes);
  144. for (auto &node_idx : in_nodes) {
  145. nodes.push(std::make_pair(node_idx.first, std::make_pair(node_idx.second, num_pair)));
  146. }
  147. (void)handle_nodes.insert(std::make_pair(tmp_node, index));
  148. }
  149. return SUCCESS;
  150. }
  151. ///
  152. /// @brief Check if need insert opposite branch
  153. /// @param [in] node
  154. /// @param [in] index
  155. /// @return ControlNodeType
  156. ///
  157. ControlNodeType ControlTriggerPass::TransferNodeType(const NodePtr &node, uint32_t index) {
  158. OpDescPtr merge_desc = node->GetOpDesc();
  159. if (merge_desc == nullptr) {
  160. REPORT_INNER_ERROR("E19999", "op_desc in merge node is nullptr, check invalid");
  161. GELOGE(INTERNAL_ERROR, "[Get][OpDesc] failed, merge_desc is nullptr.");
  162. return kInvalidType;
  163. }
  164. const std::string type = node->GetType();
  165. if ((type == SWITCH) || (type == REFSWITCH)) {
  166. if ((index != SWITCH_TRUE_OUTPUT) && (index != SWITCH_FALSE_OUTPUT)) {
  167. GELOGI("TransferNodeType: neither true nor false branch.");
  168. return kNotControlOp;
  169. }
  170. if (FindPredInput(node) != SUCCESS) {
  171. GELOGE(INTERNAL_ERROR, "[Find][PredInput] failed, switch_node:%s.", node->GetName().c_str());
  172. return kInvalidType;
  173. }
  174. NodePtr pred_node = switch_cond_map_[node];
  175. bool branch_flag = (index == SWITCH_TRUE_OUTPUT);
  176. if (pred_node->GetType() != LOOPCOND) {
  177. GELOGI("TransferNodeType: kCondSwitch node=%s, idx=%u", node->GetName().c_str(), index);
  178. return kCondSwitch;
  179. } else {
  180. GELOGI("TransferNodeType: kLoopSwitch node=%s, idx=%u", node->GetName().c_str(), index);
  181. return branch_flag ? kLoopSwitchT : kLoopSwitchF;
  182. }
  183. } else if ((type == MERGE) || (type == REFMERGE)) {
  184. if (!merge_desc->HasAttr(ATTR_NAME_NEXT_ITERATION)) {
  185. return kCondMerge;
  186. }
  187. } else if ((type == ENTER) || (type == REFENTER)) {
  188. return kEnter;
  189. }
  190. return kNotControlOp;
  191. }
  192. ///
  193. /// @brief Get in_node & idx pairs
  194. /// @param [in] node
  195. /// @param [out] in_nodes
  196. /// @return void
  197. ///
  198. void ControlTriggerPass::GetInNodes(const NodePtr &node, std::set<std::pair<NodePtr, uint32_t>> &in_nodes) {
  199. in_nodes.clear();
  200. for (auto &in_ctrl_node : node->GetInControlNodes()) {
  201. (void)in_nodes.insert(std::make_pair(in_ctrl_node, UINT32_MAX));
  202. }
  203. for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) {
  204. OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
  205. if (peer_out_anchor == nullptr) {
  206. continue;
  207. }
  208. (void)in_nodes.insert(std::make_pair(peer_out_anchor->GetOwnerNode(), peer_out_anchor->GetIdx()));
  209. }
  210. return;
  211. }
  212. ///
  213. /// @brief Insert opposite branch for ControlTrigger
  214. /// @param [in] graph
  215. /// @param [in] ControlTrigger node
  216. /// @param [in] in_ctrl_node
  217. /// @param [in] switch_node
  218. /// @param [in] branch_flag
  219. /// @return Status
  220. ///
  221. Status ControlTriggerPass::InsertOppositeBranch(ComputeGraphPtr &graph, NodePtr &node, NodePtr &in_ctrl_node,
  222. NodePtr &switch_node, bool branch_flag) {
  223. GE_CHECK_NOTNULL(node);
  224. GE_CHECK_NOTNULL(in_ctrl_node);
  225. GE_CHECK_NOTNULL(switch_node);
  226. OpDescPtr switch_desc = switch_node->GetOpDesc();
  227. GE_CHECK_NOTNULL(switch_desc);
  228. GeTensorDesc data_desc(GeShape(), FORMAT_NCHW, DT_INT32);
  229. NodePtr merge_node = InsertMergeNode(graph, node, in_ctrl_node, data_desc);
  230. if (merge_node == nullptr) {
  231. GELOGE(FAILED, "[Insert][MergeNode] failed, node:%s, in_ctrl_node:%s.",
  232. node->GetName().c_str(), in_ctrl_node->GetName().c_str());
  233. return FAILED;
  234. }
  235. NodePtr const_f = InsertConstNode(graph, merge_node, data_desc, false);
  236. NodePtr const_t = InsertConstNode(graph, merge_node, data_desc, true);
  237. if ((const_f == nullptr) || (const_t == nullptr)) {
  238. GELOGE(FAILED, "[Insert][ConstNode] failed, graph:%s, merge_node:%s.",
  239. graph->GetName().c_str(), merge_node->GetName().c_str());
  240. return FAILED;
  241. }
  242. NodePtr orig_const = branch_flag ? const_t : const_f;
  243. NodePtr new_const = !branch_flag ? const_t : const_f;
  244. uint32_t new_idx = branch_flag ? SWITCH_FALSE_OUTPUT : SWITCH_TRUE_OUTPUT;
  245. const std::string identity_name = switch_desc->GetName() + "_" + IDENTITY;
  246. NodePtr identity_node = InsertIdentityNode(graph, identity_name, switch_desc->GetOutputDesc(new_idx));
  247. if (identity_node == nullptr) {
  248. GELOGE(FAILED, "[Insert][IdentityNode] name:%s failed, graph:%s.",
  249. identity_name.c_str(), graph->GetName().c_str());
  250. return FAILED;
  251. }
  252. if (GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), orig_const->GetInControlAnchor()) != GRAPH_SUCCESS) {
  253. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  254. in_ctrl_node->GetName().c_str(), in_ctrl_node->GetType().c_str(),
  255. orig_const->GetName().c_str(), orig_const->GetType().c_str());
  256. GELOGE(FAILED, "[Add][CtrlEdge] failed, %s->%s.", in_ctrl_node->GetName().c_str(), orig_const->GetName().c_str());
  257. return FAILED;
  258. }
  259. if (GraphUtils::AddEdge(switch_node->GetOutDataAnchor(new_idx), identity_node->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
  260. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:%u) and op:%s(%s)(index:0) failed",
  261. switch_node->GetName().c_str(), switch_node->GetType().c_str(), new_idx,
  262. identity_node->GetName().c_str(), identity_node->GetType().c_str());
  263. GELOGE(FAILED, "[Add][Edge] between op:%s(%s)(index:%u) and op:%s(%s)(index:0) failed",
  264. switch_node->GetName().c_str(), switch_node->GetType().c_str(), new_idx,
  265. identity_node->GetName().c_str(), identity_node->GetType().c_str());
  266. return FAILED;
  267. }
  268. if (GraphUtils::AddEdge(identity_node->GetOutControlAnchor(), new_const->GetInControlAnchor()) != GRAPH_SUCCESS) {
  269. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  270. identity_node->GetName().c_str(), identity_node->GetType().c_str(),
  271. new_const->GetName().c_str(), new_const->GetType().c_str());
  272. GELOGE(FAILED, "[Add][ControlEdge] between op:%s(%s) and op:%s(%s) failed",
  273. identity_node->GetName().c_str(), identity_node->GetType().c_str(),
  274. new_const->GetName().c_str(), new_const->GetType().c_str());
  275. return FAILED;
  276. }
  277. auto pred_const = std::make_pair(switch_cond_map_[switch_node], std::make_pair(const_f, const_t));
  278. auto iter = control_trigger_map_.find(node);
  279. if (iter == control_trigger_map_.end()) {
  280. control_trigger_map_[node] = {pred_const};
  281. } else {
  282. if (!iter->second.insert(pred_const).second) {
  283. REPORT_INNER_ERROR("E19999", "Insert to control_trigger_map_ failed");
  284. GELOGE(FAILED, "[Check][Param] control_trigger_map_ insert failed.");
  285. return FAILED;
  286. }
  287. }
  288. return SUCCESS;
  289. }
  290. ///
  291. /// @brief Insert Merge Node
  292. /// @param [in] graph
  293. /// @param [in] node
  294. /// @param [in] in_ctrl_node
  295. /// @param [in] data_desc
  296. /// @return NodePtr
  297. ///
  298. NodePtr ControlTriggerPass::InsertMergeNode(ComputeGraphPtr &graph, NodePtr &node, NodePtr &in_ctrl_node,
  299. const GeTensorDesc &data_desc) {
  300. const std::string name = node->GetName() + "_" + MERGE;
  301. OpDescPtr op_desc = MakeShared<OpDesc>(name, MERGE);
  302. if (op_desc == nullptr) {
  303. REPORT_CALL_ERROR("E19999", "New OpDesc failed");
  304. GELOGE(FAILED, "[New][OpDesc] failed");
  305. return nullptr;
  306. }
  307. if ((op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) || (op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) ||
  308. (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS) || (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS)) {
  309. REPORT_CALL_ERROR("E19999", "Add input or ouput desc to op:%s(%s) failed",
  310. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  311. GELOGE(INTERNAL_ERROR, "[Add][GeTensorDesc] to op:%s(%s) failed",
  312. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  313. return nullptr;
  314. }
  315. GELOGI("Create Merge op:%s.", name.c_str());
  316. NodePtr merge_node = graph->AddNode(op_desc);
  317. if (merge_node == nullptr) {
  318. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
  319. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  320. GELOGE(INTERNAL_ERROR, "[Add][Node] %s(%s) to graph:%s failed",
  321. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  322. return nullptr;
  323. }
  324. if ((GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), node->GetInControlAnchor()) != GRAPH_SUCCESS) ||
  325. (GraphUtils::AddEdge(merge_node->GetOutControlAnchor(), node->GetInControlAnchor()) != GRAPH_SUCCESS)) {
  326. REPORT_CALL_ERROR("E19999", "Remove control edge between op:%s(%s) and op:%s(%s), then "
  327. "add control edge between op:%s(%s) and op:%s(%s) failed",
  328. in_ctrl_node->GetName().c_str(), in_ctrl_node->GetType().c_str(),
  329. node->GetName().c_str(), node->GetType().c_str(),
  330. merge_node->GetName().c_str(), merge_node->GetType().c_str(),
  331. node->GetName().c_str(), node->GetType().c_str());
  332. GELOGE(FAILED, "[Replace][CtrlEdge] failed, remove edge:%s->%s, add edge:%s->%s",
  333. in_ctrl_node->GetName().c_str(), node->GetName().c_str(),
  334. merge_node->GetName().c_str(), node->GetName().c_str());
  335. return nullptr;
  336. }
  337. return merge_node;
  338. }
  339. ///
  340. /// @brief Insert Const Node
  341. /// @param [in] graph
  342. /// @param [in] merge_node
  343. /// @param [in] data_desc
  344. /// @param [in] flag
  345. /// @return NodePtr
  346. ///
  347. NodePtr ControlTriggerPass::InsertConstNode(ComputeGraphPtr &graph, NodePtr &merge_node, const GeTensorDesc &data_desc,
  348. bool flag) {
  349. const std::string name = merge_node->GetName() + "_" + CONSTANT + (flag ? "_t" : "_f");
  350. OpDescPtr op_desc = MakeShared<OpDesc>(name, CONSTANT);
  351. if (op_desc == nullptr) {
  352. REPORT_CALL_ERROR("E19999", "New OpDesc failed");
  353. GELOGE(FAILED, "[New][OpDesc] failed.");
  354. return nullptr;
  355. }
  356. int32_t value = 0;
  357. GeTensorPtr const_value = MakeShared<GeTensor>(data_desc, reinterpret_cast<uint8_t *>(&value), sizeof(int32_t));
  358. if (const_value == nullptr) {
  359. REPORT_CALL_ERROR("E19999", "New GeTensor failed");
  360. GELOGE(FAILED, "[New][GeTensor] failed.");
  361. return nullptr;
  362. }
  363. if (!AttrUtils::SetTensor(op_desc, ATTR_NAME_WEIGHTS, const_value)) {
  364. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed", ATTR_NAME_WEIGHTS.c_str(),
  365. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  366. GELOGE(INTERNAL_ERROR, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NAME_WEIGHTS.c_str(),
  367. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  368. return nullptr;
  369. }
  370. if (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS) {
  371. REPORT_CALL_ERROR("E19999", "Add ouput desc to op:%s(%s) failed",
  372. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  373. GELOGE(INTERNAL_ERROR, "[Add][OutputDesc] to op:%s(%s) failed",
  374. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  375. return nullptr;
  376. }
  377. GELOGI("Create Const op: %s", name.c_str());
  378. NodePtr const_node = graph->AddNode(op_desc);
  379. if (const_node == nullptr) {
  380. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
  381. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  382. GELOGE(INTERNAL_ERROR, "[Add][Node] %s(%s) to graph:%s failed",
  383. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  384. return nullptr;
  385. }
  386. uint32_t out_idx = (flag ? SWITCH_TRUE_OUTPUT : SWITCH_FALSE_OUTPUT);
  387. if (GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), merge_node->GetInDataAnchor(out_idx)) != GRAPH_SUCCESS) {
  388. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(index:0) and op:%s(%s)(index:%u) failed",
  389. const_node->GetName().c_str(), const_node->GetType().c_str(),
  390. merge_node->GetName().c_str(), merge_node->GetType().c_str(), out_idx);
  391. GELOGE(FAILED, "[Add][Edge] between op:%s(%s)(index:0) and op:%s(%s)(index:%u) failed",
  392. const_node->GetName().c_str(), const_node->GetType().c_str(),
  393. merge_node->GetName().c_str(), merge_node->GetType().c_str(), out_idx);
  394. return nullptr;
  395. }
  396. return const_node;
  397. }
  398. ///
  399. /// @brief Insert Identity Node
  400. /// @param [in] graph
  401. /// @param [in] name
  402. /// @param [in] data_desc
  403. /// @return NodePtr
  404. ///
  405. NodePtr ControlTriggerPass::InsertIdentityNode(ComputeGraphPtr &graph, const std::string &name,
  406. const GeTensorDesc &data_desc) {
  407. OpDescPtr op_desc = MakeShared<OpDesc>(name, IDENTITY);
  408. if (op_desc == nullptr) {
  409. REPORT_CALL_ERROR("E19999", "New OpDesc failed");
  410. GELOGE(FAILED, "[New][OpDesc] failed");
  411. return nullptr;
  412. }
  413. if ((op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) || (op_desc->AddOutputDesc(data_desc) != GRAPH_SUCCESS)) {
  414. REPORT_CALL_ERROR("E19999", "Add input or output desc to op:%s(%s) failed",
  415. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  416. GELOGE(INTERNAL_ERROR, "[Add][GeTensorDesc] to op:%s(%s) failed",
  417. op_desc->GetName().c_str(), op_desc->GetType().c_str());
  418. return nullptr;
  419. }
  420. GELOGI("Create Identity op:%s.", name.c_str());
  421. NodePtr identity_node = graph->AddNode(op_desc);
  422. if (identity_node == nullptr) {
  423. REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed",
  424. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  425. GELOGE(INTERNAL_ERROR, "[Add][Node] %s(%s) to graph:%s failed",
  426. op_desc->GetName().c_str(), op_desc->GetType().c_str(), graph->GetName().c_str());
  427. return nullptr;
  428. }
  429. return identity_node;
  430. }
  431. ///
  432. /// @brief Find pred_input of switch_node
  433. /// @param [in] switch_node
  434. /// @param [in] name
  435. /// @param [in] data_desc
  436. /// @return Status
  437. ///
  438. Status ControlTriggerPass::FindPredInput(const NodePtr &switch_node) {
  439. if (switch_node == nullptr) {
  440. REPORT_INNER_ERROR("E19999", "Param switch_node is nullptr, check invalid");
  441. GELOGE(INTERNAL_ERROR, "[Check][Param] switch_node is nullptr");
  442. return INTERNAL_ERROR;
  443. }
  444. InDataAnchorPtr in_cond_anchor = switch_node->GetInDataAnchor(SWITCH_PRED_INPUT);
  445. if (in_cond_anchor == nullptr) {
  446. REPORT_INNER_ERROR("E19999", "Index:%d in anchor of switch_node:%s(%s) is nullptr, check invalid",
  447. SWITCH_PRED_INPUT,
  448. switch_node->GetName().c_str(), switch_node->GetType().c_str());
  449. GELOGE(INTERNAL_ERROR, "[Get][InDataAnchor] Index:%d in anchor of switch_node:%s(%s) is nullptr",
  450. SWITCH_PRED_INPUT, switch_node->GetName().c_str(), switch_node->GetType().c_str());
  451. return INTERNAL_ERROR;
  452. }
  453. OutDataAnchorPtr pred_cond_anchor = in_cond_anchor->GetPeerOutAnchor();
  454. if (pred_cond_anchor == nullptr) {
  455. REPORT_INNER_ERROR("E19999", "Index:%d in anchor of switch_node:%s(%s), it's peer anchor is nullptr, "
  456. "check invalid", SWITCH_PRED_INPUT,
  457. switch_node->GetName().c_str(), switch_node->GetType().c_str());
  458. GELOGE(INTERNAL_ERROR, "Index:%d in anchor of switch_node:%s(%s), it's peer anchor is nullptr",
  459. SWITCH_PRED_INPUT, switch_node->GetName().c_str(), switch_node->GetType().c_str());
  460. return INTERNAL_ERROR;
  461. }
  462. switch_cond_map_[switch_node] = pred_cond_anchor->GetOwnerNode();
  463. return SUCCESS;
  464. }
  465. ///
  466. /// @brief Clear Status, used for subgraph pass
  467. /// @return SUCCESS
  468. ///
  469. Status ControlTriggerPass::ClearStatus() {
  470. switch_cond_map_.clear();
  471. control_trigger_map_.clear();
  472. return SUCCESS;
  473. }
  474. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示：