You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

omg_util.cc 5.5 kB

5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/common/omg_util.h"
  17. #include <algorithm>
  18. #include "framework/common/debug/ge_log.h"
  19. #include "graph/debug/ge_attr_define.h"
  20. #include "graph/utils/graph_utils.h"
  21. namespace ge {
  22. ///
  23. /// @brief get the Original Type of FrameworkOp
  24. /// @param [in] node
  25. /// @param [out] type
  26. /// @return Status
  27. ///
  28. Status GetOriginalType(const ge::NodePtr &node, string &type) {
  29. GE_CHECK_NOTNULL(node);
  30. type = node->GetType();
  31. GE_IF_BOOL_EXEC(type != FRAMEWORKOP, return SUCCESS);
  32. GE_CHECK_NOTNULL(node->GetOpDesc());
  33. bool ret = ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
  34. if (!ret) {
  35. GELOGE(INTERNAL_ERROR, "Get FrameWorkOp original type [%s]", type.c_str());
  36. return INTERNAL_ERROR;
  37. }
  38. GELOGD("Get FrameWorkOp original type [%s]", type.c_str());
  39. return SUCCESS;
  40. }
  41. ///
  42. /// @brief set op stream_label
  43. /// @param [in] node
  44. /// @param [in] label
  45. /// @return Status
  46. ///
  47. Status SetStreamLabel(const ge::NodePtr &node, const std::string &label) {
  48. GE_CHECK_NOTNULL(node);
  49. OpDescPtr tmp_desc = node->GetOpDesc();
  50. GE_CHECK_NOTNULL(tmp_desc);
  51. if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_STREAM_LABEL, label)) {
  52. GELOGE(FAILED, "Op: %s set ATTR_NAME_STREAM_LABEL failed", node->GetName().c_str());
  53. return FAILED;
  54. }
  55. return SUCCESS;
  56. }
  57. ///
  58. /// @brief set op cycle_event flag
  59. /// @param [in] node
  60. /// @return Status
  61. ///
  62. Status SetCycleEvent(const ge::NodePtr &node) {
  63. GE_CHECK_NOTNULL(node);
  64. OpDescPtr tmp_desc = node->GetOpDesc();
  65. GE_CHECK_NOTNULL(tmp_desc);
  66. if (!AttrUtils::SetBool(tmp_desc, ge::ATTR_NAME_STREAM_CYCLE_EVENT_FLAG, true)) {
  67. GELOGE(FAILED, "Op: %s set ATTR_NAME_STREAM_CYCLE_EVENT_FLAG failed", node->GetName().c_str());
  68. return FAILED;
  69. }
  70. return SUCCESS;
  71. }
  72. ///
  73. /// @brief set op active_label_list
  74. /// @param [in] node
  75. /// @param [in] active_label_list
  76. /// @return Status
  77. ///
  78. Status SetActiveLabelList(const ge::NodePtr &node, const std::vector<std::string> &active_label_list) {
  79. GE_CHECK_NOTNULL(node);
  80. OpDescPtr tmp_desc = node->GetOpDesc();
  81. GE_CHECK_NOTNULL(tmp_desc);
  82. if (!AttrUtils::SetListStr(tmp_desc, ge::ATTR_NAME_ACTIVE_LABEL_LIST, active_label_list)) {
  83. GELOGE(FAILED, "Op: %s set ATTR_NAME_ACTIVE_LABEL_LIST failed", node->GetName().c_str());
  84. return FAILED;
  85. }
  86. return SUCCESS;
  87. }
  88. ///
  89. /// @brief set op branch_label
  90. /// @param [in] node
  91. /// @param [in] branch_label
  92. /// @return Status
  93. ///
  94. Status SetSwitchBranchNodeLabel(const ge::NodePtr &node, const std::string &branch_label) {
  95. GE_CHECK_NOTNULL(node);
  96. OpDescPtr tmp_desc = node->GetOpDesc();
  97. GE_CHECK_NOTNULL(tmp_desc);
  98. if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_SWITCH_BRANCH_NODE_LABEL, branch_label)) {
  99. GELOGE(FAILED, "Op: %s set ATTR_NAME_SWITCH_BRANCH_NODE_LABEL failed", node->GetName().c_str());
  100. return FAILED;
  101. }
  102. return SUCCESS;
  103. }
  104. ///
  105. /// @brief set op true_branch flag
  106. /// @param [in] node
  107. /// @param [in] value
  108. /// @return Status
  109. ///
  110. Status SetSwitchTrueBranchFlag(const ge::NodePtr &node, bool value) {
  111. GE_CHECK_NOTNULL(node);
  112. OpDescPtr tmp_desc = node->GetOpDesc();
  113. GE_CHECK_NOTNULL(tmp_desc);
  114. if (!AttrUtils::SetBool(tmp_desc, ge::ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value)) {
  115. GELOGE(FAILED, "Op: %s set ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG failed", node->GetName().c_str());
  116. return FAILED;
  117. }
  118. return SUCCESS;
  119. }
  120. ///
  121. /// @brief set op original name
  122. /// @param [in] node
  123. /// @param [in] orig_name
  124. /// @return Status
  125. ///
  126. Status SetOriginalNodeName(const ge::NodePtr &node, const std::string &orig_name) {
  127. GE_CHECK_NOTNULL(node);
  128. OpDescPtr tmp_desc = node->GetOpDesc();
  129. GE_CHECK_NOTNULL(tmp_desc);
  130. if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_ORIG_NODE_NAME, orig_name)) {
  131. GELOGE(FAILED, "Op: %s set ATTR_NAME_ORIG_NODE_NAME failed", node->GetName().c_str());
  132. return FAILED;
  133. }
  134. return SUCCESS;
  135. }
  136. ///
  137. /// @brief set op cyclic_dependence flag
  138. /// @param [in] node
  139. /// @return Status
  140. ///
  141. Status SetCyclicDependenceFlag(const ge::NodePtr &node) {
  142. GE_CHECK_NOTNULL(node);
  143. OpDescPtr tmp_desc = node->GetOpDesc();
  144. GE_CHECK_NOTNULL(tmp_desc);
  145. if (!AttrUtils::SetBool(tmp_desc, ge::ATTR_NAME_CYCLIC_DEPENDENCE_FLAG, true)) {
  146. GELOGE(FAILED, "Op: %s set ATTR_NAME_CYCLIC_DEPENDENCE_FLAG failed", node->GetName().c_str());
  147. return FAILED;
  148. }
  149. return SUCCESS;
  150. }
  151. ///
  152. /// @brief set op next_iteration name
  153. /// @param [in] node
  154. /// @param [in] next
  155. /// @return Status
  156. ///
  157. Status SetNextIteration(const ge::NodePtr &node, const std::string &next) {
  158. GE_CHECK_NOTNULL(node);
  159. OpDescPtr tmp_desc = node->GetOpDesc();
  160. GE_CHECK_NOTNULL(tmp_desc);
  161. if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_NEXT_ITERATION, next)) {
  162. GELOGE(FAILED, "Op: %s set ATTR_NAME_NEXT_ITERATION failed", node->GetName().c_str());
  163. return FAILED;
  164. }
  165. return SUCCESS;
  166. }
  167. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：