You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

omg_util.cc 9.6 kB

5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  #include "graph/common/omg_util.h"

  #include <limits>

  #include "graph/debug/ge_attr_define.h"
  #include "graph/utils/graph_utils.h"
  #include "graph/utils/tensor_utils.h"
  #include "common/math/math_util.h"
  21. namespace ge {
  22. ///
  23. /// @brief get the Original Type of FrameworkOp
  24. /// @param [in] node
  25. /// @param [out] type
  26. /// @return Status
  27. ///
  28. Status GetOriginalType(const ge::NodePtr &node, string &type) {
  29. GE_CHECK_NOTNULL(node);
  30. type = node->GetType();
  31. GE_IF_BOOL_EXEC(type != FRAMEWORKOP, return SUCCESS);
  32. GE_CHECK_NOTNULL(node->GetOpDesc());
  33. bool ret = ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type);
  34. if (!ret) {
  35. REPORT_INNER_ERROR("E19999", "Get Attr:%s fail for op:%s(%s)", ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE.c_str(),
  36. node->GetName().c_str(), node->GetType().c_str());
  37. GELOGE(INTERNAL_ERROR, "Get FrameWorkOp original type [%s]", type.c_str());
  38. return INTERNAL_ERROR;
  39. }
  40. GELOGD("Get FrameWorkOp original type [%s]", type.c_str());
  41. return SUCCESS;
  42. }
  43. ///
  44. /// @brief set op stream_label
  45. /// @param [in] node
  46. /// @param [in] label
  47. /// @return Status
  48. ///
  49. Status SetStreamLabel(const ge::NodePtr &node, const std::string &label) {
  50. GE_CHECK_NOTNULL(node);
  51. OpDescPtr tmp_desc = node->GetOpDesc();
  52. GE_CHECK_NOTNULL(tmp_desc);
  53. if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_STREAM_LABEL, label)) {
  54. REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_STREAM_LABEL.c_str(),
  55. node->GetName().c_str(), node->GetType().c_str());
  56. GELOGE(FAILED, "Op: %s set ATTR_NAME_STREAM_LABEL failed", node->GetName().c_str());
  57. return FAILED;
  58. }
  59. return SUCCESS;
  60. }
  61. ///
  62. /// @brief set op cycle_event flag
  63. /// @param [in] node
  64. /// @return Status
  65. ///
  66. Status SetCycleEvent(const ge::NodePtr &node) {
  67. GE_CHECK_NOTNULL(node);
  68. OpDescPtr tmp_desc = node->GetOpDesc();
  69. GE_CHECK_NOTNULL(tmp_desc);
  70. if (!AttrUtils::SetBool(tmp_desc, ge::ATTR_NAME_STREAM_CYCLE_EVENT_FLAG, true)) {
  71. REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_STREAM_CYCLE_EVENT_FLAG.c_str(),
  72. node->GetName().c_str(), node->GetType().c_str());
  73. GELOGE(FAILED, "Op: %s set ATTR_NAME_STREAM_CYCLE_EVENT_FLAG failed", node->GetName().c_str());
  74. return FAILED;
  75. }
  76. return SUCCESS;
  77. }
  78. ///
  79. /// @brief set op active_label_list
  80. /// @param [in] node
  81. /// @param [in] active_label_list
  82. /// @return Status
  83. ///
  84. Status SetActiveLabelList(const ge::NodePtr &node, const std::vector<std::string> &active_label_list) {
  85. GE_CHECK_NOTNULL(node);
  86. OpDescPtr tmp_desc = node->GetOpDesc();
  87. GE_CHECK_NOTNULL(tmp_desc);
  88. if (!AttrUtils::SetListStr(tmp_desc, ge::ATTR_NAME_ACTIVE_LABEL_LIST, active_label_list)) {
  89. REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_ACTIVE_LABEL_LIST.c_str(),
  90. node->GetName().c_str(), node->GetType().c_str());
  91. GELOGE(FAILED, "Op: %s set ATTR_NAME_ACTIVE_LABEL_LIST failed", node->GetName().c_str());
  92. return FAILED;
  93. }
  94. return SUCCESS;
  95. }
  96. ///
  97. /// @brief set op branch_label
  98. /// @param [in] node
  99. /// @param [in] branch_label
  100. /// @return Status
  101. ///
  102. Status SetSwitchBranchNodeLabel(const ge::NodePtr &node, const std::string &branch_label) {
  103. GE_CHECK_NOTNULL(node);
  104. OpDescPtr tmp_desc = node->GetOpDesc();
  105. GE_CHECK_NOTNULL(tmp_desc);
  106. if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_SWITCH_BRANCH_NODE_LABEL, branch_label)) {
  107. REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_SWITCH_BRANCH_NODE_LABEL.c_str(),
  108. node->GetName().c_str(), node->GetType().c_str());
  109. GELOGE(FAILED, "Op: %s set ATTR_NAME_SWITCH_BRANCH_NODE_LABEL failed", node->GetName().c_str());
  110. return FAILED;
  111. }
  112. return SUCCESS;
  113. }
  114. ///
  115. /// @brief set op true_branch flag
  116. /// @param [in] node
  117. /// @param [in] value
  118. /// @return Status
  119. ///
  120. Status SetSwitchTrueBranchFlag(const ge::NodePtr &node, bool value) {
  121. GE_CHECK_NOTNULL(node);
  122. OpDescPtr tmp_desc = node->GetOpDesc();
  123. GE_CHECK_NOTNULL(tmp_desc);
  124. if (!AttrUtils::SetBool(tmp_desc, ge::ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value)) {
  125. REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG.c_str(),
  126. node->GetName().c_str(), node->GetType().c_str());
  127. GELOGE(FAILED, "Op: %s set ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG failed", node->GetName().c_str());
  128. return FAILED;
  129. }
  130. return SUCCESS;
  131. }
  132. ///
  133. /// @brief set op original name
  134. /// @param [in] node
  135. /// @param [in] orig_name
  136. /// @return Status
  137. ///
  138. Status SetOriginalNodeName(const ge::NodePtr &node, const std::string &orig_name) {
  139. GE_CHECK_NOTNULL(node);
  140. OpDescPtr tmp_desc = node->GetOpDesc();
  141. GE_CHECK_NOTNULL(tmp_desc);
  142. if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_ORIG_NODE_NAME, orig_name)) {
  143. REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_ORIG_NODE_NAME.c_str(),
  144. node->GetName().c_str(), node->GetType().c_str());
  145. GELOGE(FAILED, "Op: %s set ATTR_NAME_ORIG_NODE_NAME failed", node->GetName().c_str());
  146. return FAILED;
  147. }
  148. return SUCCESS;
  149. }
  150. ///
  151. /// @brief set op cyclic_dependence flag
  152. /// @param [in] node
  153. /// @return Status
  154. ///
  155. Status SetCyclicDependenceFlag(const ge::NodePtr &node) {
  156. GE_CHECK_NOTNULL(node);
  157. OpDescPtr tmp_desc = node->GetOpDesc();
  158. GE_CHECK_NOTNULL(tmp_desc);
  159. if (!AttrUtils::SetBool(tmp_desc, ge::ATTR_NAME_CYCLIC_DEPENDENCE_FLAG, true)) {
  160. REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_CYCLIC_DEPENDENCE_FLAG.c_str(),
  161. node->GetName().c_str(), node->GetType().c_str());
  162. GELOGE(FAILED, "Op: %s set ATTR_NAME_CYCLIC_DEPENDENCE_FLAG failed", node->GetName().c_str());
  163. return FAILED;
  164. }
  165. return SUCCESS;
  166. }
  167. ///
  168. /// @brief set op next_iteration name
  169. /// @param [in] node
  170. /// @param [in] next
  171. /// @return Status
  172. ///
  173. Status SetNextIteration(const ge::NodePtr &node, const std::string &next) {
  174. GE_CHECK_NOTNULL(node);
  175. OpDescPtr tmp_desc = node->GetOpDesc();
  176. GE_CHECK_NOTNULL(tmp_desc);
  177. if (!AttrUtils::SetStr(tmp_desc, ge::ATTR_NAME_NEXT_ITERATION, next)) {
  178. REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_NEXT_ITERATION.c_str(),
  179. node->GetName().c_str(), node->GetType().c_str());
  180. GELOGE(FAILED, "Op: %s set ATTR_NAME_NEXT_ITERATION failed", node->GetName().c_str());
  181. return FAILED;
  182. }
  183. return SUCCESS;
  184. }
  ///
  /// @brief Round a memory size up to the next multiple of an alignment.
  /// @param [in,out] mem_size  size to align; values <= 0 are left unchanged
  /// @param [in] align_size    alignment; values <= 0 are rejected (mem_size left unchanged)
  /// @return void
  /// @note If the aligned value would not fit in int64_t, mem_size is left unchanged;
  ///       callers that need to detect this (e.g. GetMemorySize) must pre-check the addition.
  ///
  void AlignMemSize(int64_t &mem_size, int64_t align_size) {
    // A zero alignment would divide by zero and a negative one has no meaning.
    if (mem_size <= 0 || align_size <= 0) {
      return;
    }
    const int64_t remainder = mem_size % align_size;
    if (remainder == 0) {
      return;  // already aligned
    }
    const int64_t padding = align_size - remainder;
    // Adding padding past INT64_MAX is signed overflow (undefined behaviour), so refuse instead.
    if (mem_size > std::numeric_limits<int64_t>::max() - padding) {
      return;
    }
    mem_size += padding;
  }
  197. ///
  198. /// @brief Get memory size from tensor desc
  199. /// @param [in] node
  200. /// @param [out] memory size
  201. /// @return Status
  202. ///
  203. Status GetMemorySize(const NodePtr &node, int64_t &output_size) {
  204. GE_CHECK_NOTNULL(node->GetOpDesc());
  205. auto output_op_desc = node->GetOpDesc()->GetOutputDescPtr(kBufferPoolNodeOutIndex);
  206. GE_CHECK_NOTNULL(output_op_desc);
  207. int64_t size = 0;
  208. auto ret = ge::TensorUtils::GetSize(*output_op_desc, size);
  209. if (ret != ge::GRAPH_SUCCESS) {
  210. GELOGE(INTERNAL_ERROR, "[Get][Size]Node:%s.", node->GetName().c_str());
  211. REPORT_INNER_ERROR("E19999", "Failed to get output size, node:%s.", node->GetName().c_str());
  212. return INTERNAL_ERROR;
  213. }
  214. FMK_INT64_ADDCHECK(size, kBufferPoolMemAlignSize);
  215. AlignMemSize(size, kBufferPoolMemAlignSize);
  216. // The HCOM operator requires an additional 512 bytes before and after
  217. FMK_INT64_ADDCHECK(size, (kBufferPoolMemAlignSize + kBufferPoolMemAlignSize));
  218. output_size = kBufferPoolMemAlignSize + size + kBufferPoolMemAlignSize;
  219. return SUCCESS;
  220. }
  221. ///
  222. /// @brief Check Is Unknown shape Tensor
  223. /// @param [in] tensor_desc
  224. /// @return true: Unknown / false: Known
  225. ///
  226. bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) {
  227. const static int kUnknowShape = -1;
  228. const static int kUnknowRank = -2;
  229. for (auto dim_size : tensor_desc.GetShape().GetDims()) {
  230. if (dim_size == kUnknowShape || dim_size == kUnknowRank) {
  231. return true;
  232. }
  233. }
  234. return false;
  235. }
  236. ///
  237. /// @brief Set Op _force_unknown_shape flag
  238. /// @param [in] node
  239. /// @param [in] force_unknown, set attribute if true
  240. /// @return
  241. ///
  242. void MarkForceUnknownShape(const NodePtr &node, bool force_unknown) {
  243. GE_RT_VOID_CHECK_NOTNULL(node);
  244. if (!force_unknown) {
  245. return;
  246. }
  247. GELOGD("[%s] mark as force unknown shape node", node->GetName().c_str());
  248. if (!AttrUtils::SetBool(node->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, force_unknown)) {
  249. REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(),
  250. node->GetName().c_str(), node->GetType().c_str());
  251. GELOGE(FAILED, "Op: %s set %s failed", node->GetName().c_str(), ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str());
  252. }
  253. }
  254. } // namespace ge

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用，而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示：