You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

ref_identity_delete_op_pass.cc 12 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. * Licensed under the Apache License, Version 2.0 (the "License");
  4. * you may not use this file except in compliance with the License.
  5. * You may obtain a copy of the License at
  6. * http://www.apache.org/licenses/LICENSE-2.0
  7. * Unless required by applicable law or agreed to in writing, software
  8. * distributed under the License is distributed on an "AS IS" BASIS,
  9. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  10. * See the License for the specific language governing permissions and
  11. * limitations under the License.
  12. */
  13. #include "ref_identity_delete_op_pass.h"
  14. #include <map>
  15. #include <stack>
  16. #include "graph/common/transop_util.h"
  17. namespace ge {
  18. Status RefIdentityDeleteOpPass::Run(ComputeGraphPtr graph) {
  19. GE_CHECK_NOTNULL(graph);
  20. for (auto &node : graph->GetAllNodes()) {
  21. if (node->GetType() != REFIDENTITY) {
  22. continue;
  23. }
  24. int input_index = 0;
  25. NodePtr ref_node = GetRefNode(node, input_index);
  26. CHECK_FALSE_EXEC(GetRefNode(node, input_index) != nullptr,
  27. REPORT_CALL_ERROR("E19999", "Get Ref node of node:%s(%s) failed",
  28. node->GetName().c_str(), node->GetType().c_str());
  29. GELOGE(FAILED, "[Get][RefNode] of node:%s(%s) failed",
  30. node->GetName().c_str(), node->GetType().c_str());
  31. return FAILED);
  32. CHECK_FALSE_EXEC(DealNoOutputRef(ref_node, node, input_index, graph) == SUCCESS,
  33. GELOGE(FAILED, "[Deal][NoOutputRef] for node:%s failed, index:%d",
  34. node->GetName().c_str(), input_index);
  35. return FAILED);
  36. }
  37. return SUCCESS;
  38. }
  39. NodePtr RefIdentityDeleteOpPass::GetRefNode(const NodePtr &node, int &input_index) {
  40. OutDataAnchorPtr out_anchor = node->GetOutDataAnchor(0);
  41. CHECK_FALSE_EXEC(out_anchor != nullptr, return nullptr);
  42. for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) {
  43. CHECK_FALSE_EXEC(peer_in_anchor != nullptr, continue);
  44. auto peer_node = peer_in_anchor->GetOwnerNode();
  45. CHECK_FALSE_EXEC(peer_node != nullptr, continue);
  46. const auto &peer_op_desc = peer_node->GetOpDesc();
  47. CHECK_FALSE_EXEC(peer_op_desc != nullptr, return nullptr);
  48. const auto &peer_input_desc = peer_op_desc->GetInputDescPtr(static_cast<uint32_t>(peer_in_anchor->GetIdx()));
  49. if (!peer_input_desc->GetRefPortIndex().empty()) {
  50. input_index = peer_in_anchor->GetIdx();
  51. return peer_node;
  52. }
  53. }
  54. return nullptr;
  55. }
  56. Status RefIdentityDeleteOpPass::DealNoOutputRef(const NodePtr &node, const NodePtr &ref_identity, int input_index,
  57. const ComputeGraphPtr &graph) {
  58. NodePtr first_node = nullptr;
  59. NodePtr variable_ref = GetVariableRef(node, ref_identity, first_node);
  60. if (variable_ref == nullptr) {
  61. REPORT_CALL_ERROR("E19999", "Get variable ref of node:%s(%s) failed",
  62. node->GetName().c_str(), node->GetType().c_str());
  63. GELOGE(FAILED, "[Get][VariableRef] of node:%s(%s) failed", node->GetName().c_str(), node->GetType().c_str());
  64. return FAILED;
  65. }
  66. if (first_node->GetName() != variable_ref->GetName()) {
  67. // Remove the control edge between ref node and variable ref
  68. // Add a control edge between ref node and trans node
  69. // +-----------+ +-----------+
  70. // +---------+RefIdentity| +-----------+RefIdentity|
  71. // | +-----+-----+ | +-----+-----+
  72. // | | | |
  73. // | v | v
  74. // +-----v-----+ +----+----+ +-----v-----+ +----+----+
  75. // | TransNode | | RefNode | ==> | TransNode +<--C--+ RefNode |
  76. // +-----+-----+ +----+----+ +-----+-----+ +---------+
  77. // | | |
  78. // v C v
  79. // +-----+-----+ | +-----+-----+
  80. // |VariableRef+<--------+ |VariableRef|
  81. // +-----------+ +-----------+
  82. auto ret = ge::GraphUtils::AddEdge(node->GetOutControlAnchor(), first_node->GetInControlAnchor());
  83. if (ret != SUCCESS) {
  84. REPORT_CALL_ERROR("E19999", "Add control edge between op:%s(%s) and op:%s(%s) failed",
  85. node->GetName().c_str(), node->GetType().c_str(),
  86. first_node->GetName().c_str(), first_node->GetType().c_str());
  87. GELOGE(FAILED, "[Add][ControlEdge] between op:%s(%s) and op:%s(%s) failed",
  88. node->GetName().c_str(), node->GetType().c_str(),
  89. first_node->GetName().c_str(), first_node->GetType().c_str());
  90. return FAILED;
  91. }
  92. ret = ge::GraphUtils::RemoveEdge(node->GetOutControlAnchor(), variable_ref->GetInControlAnchor());
  93. if (ret != SUCCESS) {
  94. REPORT_CALL_ERROR("E19999", "Remove control edge between op:%s(%s) and op:%s(%s) failed",
  95. node->GetName().c_str(), node->GetType().c_str(),
  96. first_node->GetName().c_str(), first_node->GetType().c_str());
  97. GELOGE(FAILED, "[Remove][ControlEdge] between op:%s(%s) and op:%s(%s) failed",
  98. node->GetName().c_str(), node->GetType().c_str(),
  99. first_node->GetName().c_str(), first_node->GetType().c_str());
  100. return FAILED;
  101. }
  102. } else {
  103. // +-----------+ +-----------+
  104. // +-----------+RefIdentity| +-----------+RefIdentity|
  105. // | +-----+-----+ | +-----+-----+
  106. // | | | |
  107. // | v | v
  108. // +-----v-----+ +----+----+ +-----v-----+ +----+----+
  109. // |VariableRef+<--C--+ RefNode | ==> |VariableRef+<--C--+ RefNode |
  110. // +-----+-----+ +----+----+ +-----------+ +----+----+
  111. // | | |
  112. // | v v
  113. // | +---+----+ +---+----+
  114. // +-----C------>+ | | |
  115. // +--------+ +--------+
  116. auto ret = RemoveUselessControlEdge(node, variable_ref);
  117. if (ret != SUCCESS) {
  118. GELOGE(FAILED, "[Remove][UselessControlEdge] between node:%s(%s) and node:%s(%s) failed.",
  119. node->GetName().c_str(), node->GetType().c_str(),
  120. variable_ref->GetName().c_str(), variable_ref->GetType().c_str());
  121. return FAILED;
  122. }
  123. }
  124. // remove ref identity
  125. if (GraphUtils::IsolateNode(ref_identity, {0}) != GRAPH_SUCCESS) {
  126. REPORT_CALL_ERROR("E19999", "Isolate op:%s(%s) failed",
  127. ref_identity->GetName().c_str(), ref_identity->GetType().c_str());
  128. GELOGE(INTERNAL_ERROR, "[Isolate][Node] %s, type:%s failed", ref_identity->GetName().c_str(),
  129. variable_ref->GetType().c_str());
  130. return FAILED;
  131. }
  132. if (GraphUtils::RemoveNodeWithoutRelink(graph, ref_identity) != GRAPH_SUCCESS) {
  133. REPORT_CALL_ERROR("E19999", "Remove node:%s(%s) without relink in graph:%s failed",
  134. ref_identity->GetName().c_str(), ref_identity->GetType().c_str(), graph->GetName().c_str());
  135. GELOGE(INTERNAL_ERROR, "[Remove][Node] %s, type:%s without relink in graph:%s failed",
  136. ref_identity->GetName().c_str(), ref_identity->GetType().c_str(), graph->GetName().c_str());
  137. return FAILED;
  138. }
  139. return SUCCESS;
  140. }
  141. ge::NodePtr RefIdentityDeleteOpPass::GetVariableRef(const NodePtr &ref, const NodePtr &ref_identity,
  142. NodePtr &first_node) {
  143. const auto &ref_identity_out_anchor = ref_identity->GetOutDataAnchor(0);
  144. if (ref_identity_out_anchor == nullptr) {
  145. return nullptr;
  146. }
  147. for (auto &peer_in_anchor : ref_identity_out_anchor->GetPeerInDataAnchors()) {
  148. const auto &peer_node = peer_in_anchor->GetOwnerNode();
  149. if (peer_node == nullptr || peer_node->GetName() == ref->GetName()) {
  150. continue;
  151. }
  152. // DFS to find variable ref node.
  153. std::stack<NodePtr> nodes_to_check;
  154. nodes_to_check.push(peer_node);
  155. GELOGI("[RefIdentityDeleteOpPass]Start to search variable ref node from %s.", peer_node->GetName().c_str());
  156. NodePtr cur_node = nullptr;
  157. while (!nodes_to_check.empty()) {
  158. cur_node = nodes_to_check.top();
  159. nodes_to_check.pop();
  160. const auto &type = cur_node->GetType();
  161. if (type == VARIABLE && CheckControlEdge(ref, cur_node)) {
  162. // Target variable ref node found.
  163. GELOGI("[RefIdentityDeleteOpPass]variable ref node[%s] found.", cur_node->GetName().c_str());
  164. first_node = peer_node;
  165. return cur_node;
  166. }
  167. int data_index = TransOpUtil::GetTransOpDataIndex(type);
  168. if (data_index < 0) {
  169. GELOGI("[RefIdentityDeleteOpPass]Find node[%s] that is not trans op[%s], stop to search its output.",
  170. cur_node->GetName().c_str(), type.c_str());
  171. continue;
  172. }
  173. const auto &cur_out_anchor = cur_node->GetOutDataAnchor(0);
  174. if (cur_out_anchor == nullptr) {
  175. GELOGI("[RefIdentityDeleteOpPass]Get out anchor of [%s] failed, stop to search its output.",
  176. cur_node->GetName().c_str());
  177. continue;
  178. }
  179. for (const auto &cur_peer_in_anchor : cur_out_anchor->GetPeerInDataAnchors()) {
  180. const auto &cur_peer_node = cur_peer_in_anchor->GetOwnerNode();
  181. if (cur_peer_node == nullptr) {
  182. continue;
  183. }
  184. nodes_to_check.push(cur_peer_node);
  185. }
  186. }
  187. GELOGI("[RefIdentityDeleteOpPass]Can not find variable ref node from %s.", peer_node->GetName().c_str());
  188. }
  189. GELOGI("[RefIdentityDeleteOpPass]Can not find variable ref node, return nullptr.");
  190. return nullptr;
  191. }
  192. bool RefIdentityDeleteOpPass::CheckControlEdge(const NodePtr &ref, const NodePtr &variable_ref) {
  193. const auto &control_out_anchor = ref->GetOutControlAnchor();
  194. if (control_out_anchor == nullptr) {
  195. return false;
  196. }
  197. const string &variable_ref_name = variable_ref->GetName();
  198. for (const auto &peer_in_control_anchor : control_out_anchor->GetPeerInControlAnchors()) {
  199. const auto &node = peer_in_control_anchor->GetOwnerNode();
  200. if (node != nullptr && node->GetName() == variable_ref_name) {
  201. return true;
  202. }
  203. }
  204. return false;
  205. }
  206. Status RefIdentityDeleteOpPass::RemoveUselessControlEdge(const NodePtr &ref, const NodePtr &variable_ref) {
  207. map<string, NodePtr> out_nodes_map;
  208. for (const auto &out_anchor : ref->GetAllOutDataAnchors()) {
  209. for (const auto &peer_in_anchor : out_anchor->GetPeerAnchors()) {
  210. const auto &peer_node = peer_in_anchor->GetOwnerNode();
  211. if (peer_node == nullptr) {
  212. continue;
  213. }
  214. out_nodes_map[peer_node->GetName()] = peer_node;
  215. }
  216. }
  217. const auto &out_control_anchor = variable_ref->GetOutControlAnchor();
  218. GE_CHECK_NOTNULL(out_control_anchor);
  219. for (const auto &peer_in_control_anchor : out_control_anchor->GetPeerInControlAnchors()) {
  220. const auto &peer_node = peer_in_control_anchor->GetOwnerNode();
  221. if (peer_node == nullptr) {
  222. continue;
  223. }
  224. if (out_nodes_map.find(peer_node->GetName()) != out_nodes_map.end()) {
  225. auto ret = ge::GraphUtils::RemoveEdge(out_control_anchor, peer_in_control_anchor);
  226. if (ret != SUCCESS) {
  227. REPORT_CALL_ERROR("E19999", "Remove control edge between op:%s(%s) and op:%s(%s) failed",
  228. variable_ref->GetName().c_str(), variable_ref->GetType().c_str(),
  229. peer_node->GetName().c_str(), peer_node->GetType().c_str());
  230. GELOGE(FAILED, "[Remove][ControlEdge] between variable ref node[%s] and ref node's peer node[%s] failed",
  231. variable_ref->GetName().c_str(), peer_node->GetName().c_str());
  232. return FAILED;
  233. }
  234. }
  235. }
  236. return SUCCESS;
  237. }
  238. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。