You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

get_original_format_pass.cc 10 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
5 years ago
4 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/get_original_format_pass.h"
  17. #include <vector>
  18. #include "framework/common/debug/log.h"
  19. #include "framework/common/types.h"
  20. #include "framework/common/util.h"
  21. #include "framework/common/debug/ge_log.h"
  22. #include "framework/omg/omg_inner_types.h"
  23. #include "graph/utils/attr_utils.h"
  24. #include "graph/utils/op_desc_utils.h"
  25. #include "common/local_context.h"
  26. using domi::DOMI_TENSOR_NCHW;
  27. using domi::DOMI_TENSOR_NHWC;
  28. using domi::DOMI_TENSOR_RESERVED;
  29. using domi::FAILED;
  30. using domi::PARAM_INVALID;
  31. using domi::SUCCESS;
  32. namespace ge {
  33. Status GetOriginalFormatPass::Run(ge::ComputeGraphPtr graph) {
  34. GE_CHECK_NOTNULL(graph);
  35. GE_RETURN_WITH_LOG_IF_ERROR(SetOriginalFormat(graph),
  36. "[Set][OriginalFormat] for graph:%s failed", graph->GetName().c_str());
  37. return SUCCESS;
  38. }
  39. Status GetOriginalFormatPass::SetOriginalFormat(const ge::ComputeGraphPtr &graph) {
  40. GE_CHECK_NOTNULL(graph);
  41. int64_t ori_format = 0;
  42. int64_t tmp_format = 0;
  43. for (auto &node_ptr : graph->GetDirectNode()) {
  44. GE_CHECK_NOTNULL(node_ptr);
  45. GE_IF_BOOL_EXEC(!AttrUtils::SetInt(node_ptr->GetOpDesc(), ATTR_NAME_INFERRED_FORMAT, DOMI_TENSOR_RESERVED),
  46. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed",
  47. ATTR_NAME_INFERRED_FORMAT.c_str(),
  48. node_ptr->GetName().c_str(), node_ptr->GetType().c_str());
  49. GELOGE(FAILED, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NAME_INFERRED_FORMAT.c_str(),
  50. node_ptr->GetName().c_str(), node_ptr->GetType().c_str());
  51. return FAILED);
  52. }
  53. for (auto &node_ptr : graph->GetDirectNode()) {
  54. GE_CHECK_NOTNULL(node_ptr);
  55. OpDescPtr desc_ptr = node_ptr->GetOpDesc();
  56. GE_CHECK_NOTNULL(desc_ptr);
  57. auto is_data = (desc_ptr->GetType() == DATA_TYPE || desc_ptr->GetType() == AIPP_DATA_TYPE);
  58. if (is_data) {
  59. GELOGI("Data node: %s,format :%d", node_ptr->GetName().c_str(), GetLocalOmgContext().format);
  60. ori_format = static_cast<int64_t>(GetLocalOmgContext().format);
  61. GE_IF_BOOL_EXEC(!AttrUtils::SetInt(desc_ptr, ATTR_NAME_FORMAT, ori_format),
  62. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed",
  63. ATTR_NAME_FORMAT.c_str(),
  64. desc_ptr->GetName().c_str(), desc_ptr->GetType().c_str());
  65. GELOGE(FAILED, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NAME_FORMAT.c_str(),
  66. desc_ptr->GetName().c_str(), desc_ptr->GetType().c_str());
  67. return FAILED);
  68. GE_IF_BOOL_EXEC(!AttrUtils::SetInt(desc_ptr, ATTR_NAME_INFERRED_FORMAT, ori_format),
  69. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed",
  70. ATTR_NAME_INFERRED_FORMAT.c_str(),
  71. desc_ptr->GetName().c_str(), desc_ptr->GetType().c_str());
  72. GELOGE(FAILED, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NAME_INFERRED_FORMAT.c_str(),
  73. desc_ptr->GetName().c_str(), desc_ptr->GetType().c_str());
  74. return FAILED);
  75. continue;
  76. }
  77. int32_t i = 0;
  78. bool continue_flag = false;
  79. bool ignore_pred_format = false;
  80. for (auto &bias_node_ptr : node_ptr->GetInDataNodes()) {
  81. GE_CHECK_NOTNULL(bias_node_ptr);
  82. OpDescPtr bias_op_ptr = bias_node_ptr->GetOpDesc();
  83. GE_CHECK_NOTNULL(bias_op_ptr);
  84. if (bias_op_ptr->GetType() == BIASADD) {
  85. ignore_pred_format = true;
  86. std::size_t tmp_size = ge::OpDescUtils::GetNonConstInputsSize(bias_node_ptr);
  87. GE_IF_BOOL_EXEC(tmp_size > 2 || tmp_size == 0,
  88. GELOGW("bias_node is node followed by %zu nodes, should be 1 or 2", tmp_size);
  89. continue_flag = true; break);
  90. OpDescPtr tmp_first_op_ptr = bias_node_ptr->GetInDataNodes().at(0)->GetOpDesc();
  91. GE_CHECK_NOTNULL(tmp_first_op_ptr);
  92. bias_op_ptr = tmp_first_op_ptr;
  93. // if biasadd have 2 input edges, format should be same
  94. if (tmp_size == 2) {
  95. int64_t first_input_format = 0;
  96. int64_t second_input_format = 0;
  97. OpDescPtr tmpSecondOpPtr = bias_node_ptr->GetInDataNodes().at(1)->GetOpDesc();
  98. GE_CHECK_NOTNULL(tmpSecondOpPtr);
  99. GE_IF_BOOL_EXEC(
  100. !AttrUtils::GetInt(tmp_first_op_ptr, ATTR_NAME_FORMAT, first_input_format), continue_flag = true; break);
  101. GE_IF_BOOL_EXEC(
  102. !AttrUtils::GetInt(tmpSecondOpPtr, ATTR_NAME_FORMAT, second_input_format), continue_flag = true; break);
  103. if (first_input_format != second_input_format) {
  104. GELOGW("biasadd node is followed two nodes with different format, get original format failed");
  105. continue_flag = true;
  106. break;
  107. }
  108. }
  109. }
  110. GE_IF_BOOL_EXEC(!AttrUtils::GetInt(bias_op_ptr, ATTR_NAME_FORMAT, tmp_format), continue_flag = true; break;);
  111. if (i == 0) {
  112. ori_format = tmp_format;
  113. }
  114. GE_IF_BOOL_EXEC(tmp_format != ori_format,
  115. GELOGW("node: %s , original format of src nodes must be same!", bias_node_ptr->GetName().c_str());
  116. continue_flag = true; break;);
  117. i++;
  118. }
  119. GE_IF_BOOL_EXEC(continue_flag, continue);
  120. OpDescPtr tmp_op_ptr = node_ptr->GetOpDesc();
  121. GE_CHECK_NOTNULL(tmp_op_ptr);
  122. if (IsFormatTranspose(tmp_op_ptr, static_cast<int32_t>(ori_format))) {
  123. ori_format = (ori_format == DOMI_TENSOR_NCHW) ? DOMI_TENSOR_NHWC : DOMI_TENSOR_NCHW;
  124. }
  125. if (ignore_pred_format) {
  126. GE_IF_BOOL_EXEC(!AttrUtils::SetBool(tmp_op_ptr, ATTR_NAME_IGNORE_PRED_FORMAT, true),
  127. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed",
  128. ATTR_NAME_IGNORE_PRED_FORMAT.c_str(),
  129. tmp_op_ptr->GetName().c_str(), tmp_op_ptr->GetType().c_str());
  130. GELOGE(FAILED, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NAME_IGNORE_PRED_FORMAT.c_str(),
  131. tmp_op_ptr->GetName().c_str(), tmp_op_ptr->GetType().c_str());
  132. return FAILED);
  133. }
  134. // Do not reset ATTR_NAME_FORMAT if it is set in the OpParser.
  135. if (!tmp_op_ptr->HasAttr(ATTR_NAME_FORMAT)) {
  136. GE_IF_BOOL_EXEC(!AttrUtils::SetInt(tmp_op_ptr, ATTR_NAME_FORMAT, ori_format),
  137. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed",
  138. ATTR_NAME_FORMAT.c_str(),
  139. tmp_op_ptr->GetName().c_str(), tmp_op_ptr->GetType().c_str());
  140. GELOGE(FAILED, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NAME_FORMAT.c_str(),
  141. tmp_op_ptr->GetName().c_str(), tmp_op_ptr->GetType().c_str());
  142. return FAILED);
  143. GE_IF_BOOL_EXEC(!AttrUtils::SetInt(tmp_op_ptr, ATTR_NAME_INFERRED_FORMAT, ori_format),
  144. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed",
  145. ATTR_NAME_INFERRED_FORMAT.c_str(),
  146. tmp_op_ptr->GetName().c_str(), tmp_op_ptr->GetType().c_str());
  147. GELOGE(FAILED, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NAME_INFERRED_FORMAT.c_str(),
  148. tmp_op_ptr->GetName().c_str(), tmp_op_ptr->GetType().c_str());
  149. return FAILED);
  150. } else {
  151. int64_t existingFormat = 0;
  152. GE_RETURN_WITH_LOG_IF_FALSE(AttrUtils::GetInt(tmp_op_ptr, ATTR_NAME_FORMAT, existingFormat),
  153. "[Get][Attr] %s from op:%s(%s) failed", ATTR_NAME_FORMAT.c_str(),
  154. tmp_op_ptr->GetName().c_str(), tmp_op_ptr->GetType().c_str());
  155. if (!AttrUtils::SetInt(tmp_op_ptr, ATTR_NAME_INFERRED_FORMAT, existingFormat)) {
  156. REPORT_CALL_ERROR("E19999", "Set Attr:%s to op:%s(%s) failed",
  157. ATTR_NAME_INFERRED_FORMAT.c_str(),
  158. tmp_op_ptr->GetName().c_str(), tmp_op_ptr->GetType().c_str());
  159. GELOGE(FAILED, "[Set][Attr] %s to op:%s(%s) failed", ATTR_NAME_INFERRED_FORMAT.c_str(),
  160. tmp_op_ptr->GetName().c_str(), tmp_op_ptr->GetType().c_str());
  161. return FAILED;
  162. }
  163. }
  164. }
  165. return SUCCESS;
  166. }
  167. bool GetOriginalFormatPass::IsFormatTranspose(const ge::OpDescPtr op_ptr, int32_t ori_format) {
  168. GE_CHK_BOOL_EXEC(op_ptr != nullptr, return false, "[Check][Param] op_ptr is nullptr");
  169. if (op_ptr->GetType() == PERMUTE) {
  170. vector<int32_t> index_list;
  171. GE_IF_BOOL_EXEC(!AttrUtils::GetListInt(op_ptr, PERMUTE_ATTR_ORDER, index_list), return false);
  172. auto index_size = index_list.size();
  173. GE_IF_BOOL_EXEC(static_cast<int32_t>(index_size) != PERMUTE_ORDER_NUM, return false);
  174. int32_t perm_nchw[4] = {0, 2, 3, 1}; // 4 format nums, {0,2,3,1} NCHW -> NHWC
  175. int32_t perm_nhwc[4] = {0, 3, 1, 2}; // 4 format nums, {0,3,1,2} NHWC -> NCHW
  176. bool is_nchw = true;
  177. bool is_nhwc = true;
  178. for (size_t i = 0; i < index_size; ++i) {
  179. is_nchw = (perm_nchw[i] != index_list[i]) ? false : is_nchw;
  180. is_nhwc = (perm_nhwc[i] != index_list[i]) ? false : is_nhwc;
  181. }
  182. bool ret = (is_nchw && ori_format == DOMI_TENSOR_NCHW && !is_nhwc) ||
  183. (is_nhwc && ori_format == DOMI_TENSOR_NHWC && !is_nchw);
  184. return ret;
  185. }
  186. return false;
  187. }
  188. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示