You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

assign_remove_pass.cc 11 kB

4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
4 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/assign_remove_pass.h"
  17. #include "framework/common/debug/log.h"
  18. #include "graph/utils/graph_utils.h"
  19. #include "graph/debug/ge_attr_define.h"
  20. namespace ge {
  21. namespace {
  22. constexpr uint32_t kValidInputNodeOutputNum = 1;
  23. constexpr int32_t kAssignRefInputIndex = 0;
  24. constexpr int32_t kAssignValueInputIndex = 1;
  25. const std::set<std::string> kNoTaskNodeTypes = { ge::DATA, ge::ANN_DATA, ge::AIPPDATA,
  26. ge::CONSTANT, ge::CONSTANTOP,
  27. ge::VARIABLE, ge::VARIABLEV2 };
  28. }
  29. Status AssignRemovePass::Run(NodePtr &node) {
  30. GELOGD("AssignRemovePass running");
  31. if (TransformAttr(node) != SUCCESS) {
  32. GELOGE(FAILED, "[Call][TransformAttr] Transform assign_var_name attr failed, node=%s", node->GetName().c_str());
  33. return FAILED;
  34. }
  35. if (node->GetType() == ASSIGN) {
  36. if (OptimizedAssignNode(node) != SUCCESS) {
  37. GELOGE(FAILED, "[Call][Optimize] for assign_node %s failed", node->GetName().c_str());
  38. return FAILED;
  39. }
  40. }
  41. GELOGD("AssignRemovePass success");
  42. return SUCCESS;
  43. }
  44. ///
  45. /// @brief Optimize for assign_node
  46. /// @param [in] assign_node
  47. /// @return Status
  48. ///
  49. Status AssignRemovePass::OptimizedAssignNode(NodePtr &assign_node) {
  50. const auto &ref_in_anchor = assign_node->GetInDataAnchor(kAssignRefInputIndex);
  51. const auto &value_in_anchor = assign_node->GetInDataAnchor(kAssignValueInputIndex);
  52. if ((ref_in_anchor == nullptr) || (value_in_anchor == nullptr)) {
  53. REPORT_INNER_ERROR("E19999", "Index %d or %d input anchor of node:%s(%s) is nullptr, check invalid",
  54. kAssignRefInputIndex, kAssignValueInputIndex,
  55. assign_node->GetName().c_str(), assign_node->GetType().c_str());
  56. GELOGE(FAILED, "[Check][Param] Index %d or %d input anchor of node:%s(%s) is nullptr",
  57. kAssignRefInputIndex, kAssignValueInputIndex,
  58. assign_node->GetName().c_str(), assign_node->GetType().c_str());
  59. return FAILED;
  60. }
  61. const auto &ref_peer_anchor = ref_in_anchor->GetPeerOutAnchor();
  62. const auto &value_peer_anchor = value_in_anchor->GetPeerOutAnchor();
  63. if ((ref_peer_anchor == nullptr) || (value_peer_anchor == nullptr)) {
  64. REPORT_INNER_ERROR("E19999", "Index %d or %d input anchor of node:%s(%s), peer anchor is nullptr, check invalid",
  65. kAssignRefInputIndex, kAssignValueInputIndex,
  66. assign_node->GetName().c_str(), assign_node->GetType().c_str());
  67. GELOGE(FAILED, "[Check][Param] Index %d or %d input anchor of node:%s(%s), peer anchor is nullptr",
  68. kAssignRefInputIndex, kAssignValueInputIndex,
  69. assign_node->GetName().c_str(), assign_node->GetType().c_str());
  70. return FAILED;
  71. }
  72. if (IsCondMatch(assign_node, ref_peer_anchor, value_peer_anchor)) {
  73. ///
  74. /// variable not-const not-const
  75. /// \ / |
  76. /// \ / |
  77. /// Assign ----> variable
  78. /// | |
  79. /// | |
  80. /// node node
  81. ///
  82. GELOGD("Optimization for assign_node %s start", assign_node->GetName().c_str());
  83. if (IsolateAndDeleteNode(assign_node, {kAssignRefInputIndex}) != SUCCESS) {
  84. REPORT_CALL_ERROR("E19999", "Isolate and delete node:%s(%s) failed",
  85. assign_node->GetName().c_str(), assign_node->GetType().c_str());
  86. GELOGE(FAILED, "[IsolateAndDelete][Node] %s failed.", assign_node->GetName().c_str());
  87. return FAILED;
  88. }
  89. const auto &ref_input = ref_peer_anchor->GetOwnerNode()->GetOpDesc();
  90. const auto &value_input = value_peer_anchor->GetOwnerNode()->GetOpDesc();
  91. if ((ref_input == nullptr) || (value_input == nullptr)) {
  92. REPORT_INNER_ERROR("E19999", "Input index %d or %d of node:%s(%s), peer op is nullptr, check invalid",
  93. kAssignRefInputIndex, kAssignValueInputIndex,
  94. assign_node->GetName().c_str(), assign_node->GetType().c_str());
  95. GELOGE(FAILED, "[Check][Param] Input index %d or %d of node:%s(%s), peer op is nullptr",
  96. kAssignRefInputIndex, kAssignValueInputIndex,
  97. assign_node->GetName().c_str(), assign_node->GetType().c_str());
  98. return FAILED;
  99. }
  100. // variable has and only has one input
  101. if (ref_input->UpdateInputDesc(0, value_input->GetOutputDesc(value_peer_anchor->GetIdx())) != GRAPH_SUCCESS) {
  102. REPORT_CALL_ERROR("E19999", "Input index %d of node:%s(%s), update it's peer op input:0 desc failed",
  103. kAssignRefInputIndex, assign_node->GetName().c_str(), assign_node->GetType().c_str());
  104. GELOGE(FAILED, "[Update][InputDesc] Input index %d of node:%s(%s), update it's peer op input:0 desc failed",
  105. kAssignRefInputIndex, assign_node->GetName().c_str(), assign_node->GetType().c_str());
  106. return FAILED;
  107. }
  108. if (GraphUtils::AddEdge(value_peer_anchor, ref_peer_anchor->GetOwnerNode()->GetInDataAnchor(0)) != GRAPH_SUCCESS) {
  109. REPORT_CALL_ERROR("E19999", "Add edge between op:%s(%s)(out_index:%d) and op:%s(%s)(in_index:0) failed",
  110. value_peer_anchor->GetOwnerNode()->GetName().c_str(),
  111. value_peer_anchor->GetOwnerNode()->GetType().c_str(), value_peer_anchor->GetIdx(),
  112. ref_peer_anchor->GetOwnerNode()->GetName().c_str(),
  113. ref_peer_anchor->GetOwnerNode()->GetType().c_str());
  114. GELOGE(FAILED, "[Add][Edge] between op:%s(%s)(out_index:%d) and op:%s(%s)(in_index:0) failed",
  115. value_peer_anchor->GetOwnerNode()->GetName().c_str(),
  116. value_peer_anchor->GetOwnerNode()->GetType().c_str(), value_peer_anchor->GetIdx(),
  117. ref_peer_anchor->GetOwnerNode()->GetName().c_str(), ref_peer_anchor->GetOwnerNode()->GetType().c_str());
  118. return FAILED;
  119. }
  120. GELOGD("add attr ASSIGN_VAR_NAME on node %s, var_name=%s",
  121. value_input->GetName().c_str(), ref_input->GetName().c_str());
  122. if (!AttrUtils::SetStr(value_input->MutableOutputDesc(value_peer_anchor->GetIdx()), ASSIGN_VAR_NAME,
  123. ref_input->GetName())) {
  124. REPORT_CALL_ERROR("E19999", "Set Attr:%s to output:%d desc of node:%s(%s) failed",
  125. ASSIGN_VAR_NAME.c_str(), value_peer_anchor->GetIdx(),
  126. value_input->GetName().c_str(), value_input->GetType().c_str());
  127. GELOGE(FAILED, "[Set][Attr] %s to output:%d desc of node:%s(%s) failed",
  128. ASSIGN_VAR_NAME.c_str(), value_peer_anchor->GetIdx(),
  129. value_input->GetName().c_str(), value_input->GetType().c_str());
  130. return FAILED;
  131. }
  132. auto value_node = value_peer_anchor->GetOwnerNode();
  133. AddRePassNode(value_node);
  134. }
  135. return SUCCESS;
  136. }
  137. ///
  138. /// @brief Transform assign_var_name attr
  139. /// @param [in] node
  140. /// @return Status
  141. ///
  142. Status AssignRemovePass::TransformAttr(NodePtr &node) {
  143. GE_CHECK_NOTNULL(node->GetOpDesc());
  144. for (const auto &output_desc : node->GetOpDesc()->GetAllOutputsDesc()) {
  145. int32_t inplace_input_idx = -1;
  146. std::string assign_var_name;
  147. if (AttrUtils::GetInt(output_desc, INPLACE_SUPPORT_INPUT_INDEX, inplace_input_idx) &&
  148. AttrUtils::GetStr(output_desc, ASSIGN_VAR_NAME, assign_var_name)) {
  149. GELOGD("Transform attr ASSIGN_VAR_NAME on node %s, assign_var_name=%s, inplace_input_idx=%d, ",
  150. node->GetName().c_str(), assign_var_name.c_str(), inplace_input_idx);
  151. const auto &in_data_anchor = node->GetInDataAnchor(inplace_input_idx);
  152. GE_CHECK_NOTNULL(in_data_anchor);
  153. const auto &peer_data_anchor = in_data_anchor->GetPeerOutAnchor();
  154. GE_CHECK_NOTNULL(peer_data_anchor);
  155. auto in_node = peer_data_anchor->GetOwnerNode();
  156. GE_CHECK_NOTNULL(in_node->GetOpDesc());
  157. GELOGD("add attr ASSIGN_VAR_NAME on node %s, var_name=%s", in_node->GetName().c_str(), assign_var_name.c_str());
  158. if (!AttrUtils::SetStr(in_node->GetOpDesc()->MutableOutputDesc(peer_data_anchor->GetIdx()),
  159. ASSIGN_VAR_NAME, assign_var_name)) {
  160. REPORT_CALL_ERROR("E19999", "Set Attr:%s to output:%d desc of node:%s(%s) failed",
  161. ASSIGN_VAR_NAME.c_str(), peer_data_anchor->GetIdx(),
  162. in_node->GetName().c_str(), in_node->GetType().c_str());
  163. GELOGE(FAILED, "[Set][Attr] %s to output:%d desc of node:%s(%s) failed",
  164. ASSIGN_VAR_NAME.c_str(), peer_data_anchor->GetIdx(),
  165. in_node->GetName().c_str(), in_node->GetType().c_str());
  166. return FAILED;
  167. }
  168. AddRePassNode(in_node);
  169. }
  170. }
  171. return SUCCESS;
  172. }
  173. ///
  174. /// @brief Check if need optimize for assign_node
  175. /// @param [in] assign_node
  176. /// @param [in] peer_data_anchor for ref_input of assign_node
  177. /// @param [in] peer_data_anchor for value_input of assign_node
  178. /// @return Status
  179. ///
  180. bool AssignRemovePass::IsCondMatch(const NodePtr &node, const OutDataAnchorPtr &ref_peer_anchor,
  181. const OutDataAnchorPtr &value_peer_anchor) {
  182. GELOGD("Check if assign_node %s match optimization condition, ref_input: %s, value_input: %s",
  183. node->GetName().c_str(), ref_peer_anchor->GetOwnerNode()->GetName().c_str(),
  184. value_peer_anchor->GetOwnerNode()->GetName().c_str());
  185. if (kNoTaskNodeTypes.count(value_peer_anchor->GetOwnerNode()->GetType()) > 0) {
  186. GELOGD("value input is not calculate node");
  187. return false;
  188. }
  189. const std::string &ref_type = ref_peer_anchor->GetOwnerNode()->GetType();
  190. if ((ref_type != VARIABLE) && (ref_type != VARIABLEV2)) {
  191. GELOGD("ref input is not var");
  192. return false;
  193. }
  194. if (!ref_peer_anchor->GetOwnerNode()->GetInDataNodes().empty()) {
  195. GELOGD("ref input has data input");
  196. return false;
  197. }
  198. if ((ref_peer_anchor->GetPeerInDataNodesSize() != kValidInputNodeOutputNum) ||
  199. (value_peer_anchor->GetPeerInDataNodesSize() != kValidInputNodeOutputNum)) {
  200. GELOGD("ref / value input has other output(s)");
  201. return false;
  202. }
  203. GELOGD("Optimization condition matches, assign_node: %s", node->GetName().c_str());
  204. return true;
  205. }
  206. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：