You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

get_original_format_pass.cc 7.3 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
4 years ago
5 years ago
4 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "graph/passes/get_original_format_pass.h"
  17. #include <vector>
  18. #include "common/debug/log.h"
  19. #include "common/types.h"
  20. #include "common/util.h"
  21. #include "framework/common/debug/ge_log.h"
  22. #include "framework/omg/omg_inner_types.h"
  23. #include "graph/utils/attr_utils.h"
  24. #include "graph/utils/op_desc_utils.h"
  25. #include "graph/common/local_context.h"
  26. using domi::DOMI_TENSOR_NCHW;
  27. using domi::DOMI_TENSOR_NHWC;
  28. using domi::DOMI_TENSOR_RESERVED;
  29. using domi::FAILED;
  30. using domi::PARAM_INVALID;
  31. using domi::SUCCESS;
  32. namespace ge {
// Pass entry point: validates the graph pointer and delegates to
// SetOriginalFormat, which stamps format attributes on every direct node.
// Returns SUCCESS, or propagates the error status from SetOriginalFormat
// (GE_RETURN_WITH_LOG_IF_ERROR logs and returns on failure).
Status GetOriginalFormatPass::Run(ge::ComputeGraphPtr graph) {
  GE_CHECK_NOTNULL(graph);
  GE_RETURN_WITH_LOG_IF_ERROR(SetOriginalFormat(graph), "SetOriginalFormat failed");
  return SUCCESS;
}
// Infers and records an "original format" on every direct node of the graph:
//   - pass 1 initializes ATTR_NAME_INFERRED_FORMAT on all nodes to the
//     DOMI_TENSOR_RESERVED sentinel ("not yet inferred");
//   - pass 2 assigns Data/AIPP-data nodes the globally configured OMG format,
//     and derives every other node's format from its input data nodes,
//     with special look-through handling for BiasAdd predecessors.
// Returns SUCCESS, or FAILED when an attribute cannot be written.
Status GetOriginalFormatPass::SetOriginalFormat(const ge::ComputeGraphPtr &graph) {
  GE_CHECK_NOTNULL(graph);
  int64_t ori_format = 0;  // format agreed on by the current node's inputs
  int64_t tmp_format = 0;  // format read from the input currently inspected
  // Pass 1: mark every node "format not yet inferred".
  for (auto &node_ptr : graph->GetDirectNode()) {
    GE_CHECK_NOTNULL(node_ptr);
    GE_IF_BOOL_EXEC(!AttrUtils::SetInt(node_ptr->GetOpDesc(), ATTR_NAME_INFERRED_FORMAT, DOMI_TENSOR_RESERVED),
                    GELOGE(FAILED, "set ATTR_NAME_INFERRED_FORMAT failed");
                    return FAILED);
  }
  // Pass 2: infer each node's format.
  for (auto &node_ptr : graph->GetDirectNode()) {
    GE_CHECK_NOTNULL(node_ptr);
    OpDescPtr desc_ptr = node_ptr->GetOpDesc();
    GE_CHECK_NOTNULL(desc_ptr);
    auto is_data = (desc_ptr->GetType() == DATA_TYPE || desc_ptr->GetType() == AIPP_DATA_TYPE);
    if (is_data) {
      // Data / AIPP-data nodes are the format sources: they take the format
      // configured in the local OMG context.
      GELOGI("Data node: %s,format :%d", node_ptr->GetName().c_str(), GetLocalOmgContext().format);
      ori_format = static_cast<int64_t>(GetLocalOmgContext().format);
      GE_IF_BOOL_EXEC(!AttrUtils::SetInt(desc_ptr, ATTR_NAME_FORMAT, ori_format),
                      GELOGE(FAILED, "set ATTR_NAME_FORMAT failed");
                      return FAILED);
      GE_IF_BOOL_EXEC(!AttrUtils::SetInt(desc_ptr, ATTR_NAME_INFERRED_FORMAT, ori_format),
                      GELOGE(FAILED, "set ATTR_NAME_INFERRED_FORMAT failed");
                      return FAILED);
      continue;
    }
    int32_t i = 0;                    // index of the input node being inspected
    bool continue_flag = false;       // set when this node's format cannot be inferred
    bool ignore_pred_format = false;  // set when a BiasAdd predecessor was looked through
    // Collect the format from each input data node; all inputs must agree.
    for (auto &bias_node_ptr : node_ptr->GetInDataNodes()) {
      GE_CHECK_NOTNULL(bias_node_ptr);
      OpDescPtr bias_op_ptr = bias_node_ptr->GetOpDesc();
      GE_CHECK_NOTNULL(bias_op_ptr);
      if (bias_op_ptr->GetType() == BIASADD) {
        // Look through a BiasAdd predecessor: read the format from the
        // BiasAdd's own non-const inputs rather than from the BiasAdd itself.
        ignore_pred_format = true;
        std::size_t tmp_size = ge::OpDescUtils::GetNonConstInputsSize(bias_node_ptr);
        GE_IF_BOOL_EXEC(tmp_size > 2 || tmp_size == 0,
                        GELOGW("bias_node is node followed by %zu nodes, should be 1 or 2", tmp_size);
                        continue_flag = true; break);
        OpDescPtr tmp_first_op_ptr = bias_node_ptr->GetInDataNodes().at(0)->GetOpDesc();
        GE_CHECK_NOTNULL(tmp_first_op_ptr);
        // From here on the format is taken from the BiasAdd's first input.
        bias_op_ptr = tmp_first_op_ptr;
        // if biasadd have 2 input edges, format should be same
        if (tmp_size == 2) {
          int64_t first_input_format = 0;
          int64_t second_input_format = 0;
          OpDescPtr tmpSecondOpPtr = bias_node_ptr->GetInDataNodes().at(1)->GetOpDesc();
          GE_CHECK_NOTNULL(tmpSecondOpPtr);
          GE_IF_BOOL_EXEC(
              !AttrUtils::GetInt(tmp_first_op_ptr, ATTR_NAME_FORMAT, first_input_format), continue_flag = true; break);
          GE_IF_BOOL_EXEC(
              !AttrUtils::GetInt(tmpSecondOpPtr, ATTR_NAME_FORMAT, second_input_format), continue_flag = true; break);
          if (first_input_format != second_input_format) {
            GELOGW("biasadd node is followed two nodes with different format, get original format failed");
            continue_flag = true;
            break;
          }
        }
      }
      // Give up on this node if an input has no format attribute yet.
      GE_IF_BOOL_EXEC(!AttrUtils::GetInt(bias_op_ptr, ATTR_NAME_FORMAT, tmp_format), continue_flag = true; break;);
      if (i == 0) {
        // First input seen establishes the candidate format.
        ori_format = tmp_format;
      }
      // Any disagreement between inputs aborts inference for this node.
      GE_IF_BOOL_EXEC(tmp_format != ori_format,
                      GELOGW("node: %s , original format of src nodes must be same!", bias_node_ptr->GetName().c_str());
                      continue_flag = true; break;);
      i++;
    }
    GE_IF_BOOL_EXEC(continue_flag, continue);
    OpDescPtr tmp_op_ptr = node_ptr->GetOpDesc();
    GE_CHECK_NOTNULL(tmp_op_ptr);
    // A Permute whose order swaps NCHW<->NHWC flips the inferred format.
    if (IsFormatTranspose(tmp_op_ptr, static_cast<int32_t>(ori_format))) {
      ori_format = (ori_format == DOMI_TENSOR_NCHW) ? DOMI_TENSOR_NHWC : DOMI_TENSOR_NCHW;
    }
    if (ignore_pred_format) {
      // Record that this node's predecessor (BiasAdd) format was bypassed.
      GE_IF_BOOL_EXEC(!AttrUtils::SetBool(tmp_op_ptr, ATTR_NAME_IGNORE_PRED_FORMAT, true),
                      GELOGE(FAILED, "remove edge failed");
                      return FAILED);
    }
    // Do not reset ATTR_NAME_FORMAT if it is set in the OpParser.
    if (!tmp_op_ptr->HasAttr(ATTR_NAME_FORMAT)) {
      GE_IF_BOOL_EXEC(!AttrUtils::SetInt(tmp_op_ptr, ATTR_NAME_FORMAT, ori_format),
                      GELOGE(FAILED, "set ATTR_NAME_FORMAT failed");
                      return FAILED);
      GE_IF_BOOL_EXEC(!AttrUtils::SetInt(tmp_op_ptr, ATTR_NAME_INFERRED_FORMAT, ori_format),
                      GELOGE(FAILED, "set ATTR_NAME_INFERRED_FORMAT failed");
                      return FAILED);
    } else {
      // Parser already fixed the format: keep it and mirror it into the
      // inferred-format attribute.
      int64_t existingFormat = 0;
      GE_RETURN_WITH_LOG_IF_FALSE(AttrUtils::GetInt(tmp_op_ptr, ATTR_NAME_FORMAT, existingFormat),
                                  "Get existing_format attr failed");
      if (!AttrUtils::SetInt(tmp_op_ptr, ATTR_NAME_INFERRED_FORMAT, existingFormat)) {
        GELOGE(FAILED, "set ATTR_NAME_INFERRED_FORMAT failed");
        return FAILED;
      }
    }
  }
  return SUCCESS;
}
  137. bool GetOriginalFormatPass::IsFormatTranspose(const ge::OpDescPtr op_ptr, int32_t ori_format) {
  138. GE_CHK_BOOL_EXEC(op_ptr != nullptr, return false, "opdef is nullptr");
  139. if (op_ptr->GetType() == PERMUTE) {
  140. vector<int32_t> index_list;
  141. GE_IF_BOOL_EXEC(!AttrUtils::GetListInt(op_ptr, PERMUTE_ATTR_ORDER, index_list), return false);
  142. auto index_size = index_list.size();
  143. GE_IF_BOOL_EXEC(static_cast<int32_t>(index_size) != PERMUTE_ORDER_NUM, return false);
  144. int32_t perm_nchw[4] = {0, 2, 3, 1}; // 4 format nums, {0,2,3,1} NCHW -> NHWC
  145. int32_t perm_nhwc[4] = {0, 3, 1, 2}; // 4 format nums, {0,3,1,2} NHWC -> NCHW
  146. bool is_nchw = true;
  147. bool is_nhwc = true;
  148. for (size_t i = 0; i < index_size; ++i) {
  149. is_nchw = (perm_nchw[i] != index_list[i]) ? false : is_nchw;
  150. is_nhwc = (perm_nhwc[i] != index_list[i]) ? false : is_nhwc;
  151. }
  152. bool ret = (is_nchw && ori_format == DOMI_TENSOR_NCHW && !is_nhwc) ||
  153. (is_nhwc && ori_format == DOMI_TENSOR_NHWC && !is_nchw);
  154. return ret;
  155. }
  156. return false;
  157. }
  158. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示