You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

bcast.cc 4.8 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "common/bcast.h"
  17. #include <vector>
  18. #include "common/math_util.h"
  19. #include "common/util.h"
  20. using domi::Status;
  21. namespace ge {
  22. Status BCast::GenerateBcastInfo(const kVecInt &sx, const kVecInt &sy) {
  23. if (sx.size() == 0 && sy.size() == 0) {
  24. result_.push_back(1);
  25. x_reshape_.push_back(1);
  26. x_bcast_.push_back(1);
  27. y_reshape_.push_back(1);
  28. y_bcast_.push_back(1);
  29. } else {
  30. kVecInt x = sx;
  31. kVecInt y = sy;
  32. Reverse(x);
  33. Reverse(y);
  34. ExtendTensorDim(x, y);
  35. GE_RETURN_WITH_LOG_IF_ERROR(SetShapeDifferentInfo(x, y), "[Set][ShapeDifferentInfo] GenerateBcastInfo failed.");
  36. }
  37. ReverseAllIntermediateShapes();
  38. return domi::SUCCESS;
  39. }
  40. Status BCast::SetShapeDifferentInfo(const kVecInt &x, const kVecInt &y) {
  41. const int64_t n = x.size();
  42. for (int64_t i = 0; i < n; ++i) {
  43. const int64_t x_i = x[i];
  44. GE_CHECK_GE(x_i, 0);
  45. const int64_t y_i = y[i];
  46. GE_CHECK_GE(y_i, 0);
  47. int64_t output_i = 0;
  48. int64_t x_bcast_i = 0;
  49. int64_t y_bcast_i = 0;
  50. if (x_i == y_i) {
  51. output_i = x_i;
  52. x_bcast_i = 1;
  53. y_bcast_i = 1;
  54. if (x_i == 1) {
  55. grad_x_reduce_idx_.push_back(n - 1 - i);
  56. grad_y_reduce_idx_.push_back(n - 1 - i);
  57. }
  58. } else if (x_i == 1) {
  59. output_i = y_i;
  60. x_bcast_i = y_i;
  61. y_bcast_i = 1;
  62. grad_x_reduce_idx_.push_back(n - 1 - i);
  63. } else if (y_i == 1) {
  64. output_i = x_i;
  65. x_bcast_i = 1;
  66. y_bcast_i = x_i;
  67. grad_y_reduce_idx_.push_back(n - 1 - i);
  68. } else {
  69. REPORT_INNER_ERROR("E19999", "SetShapeDifferentInfo failed. Two tensor shapes are not compatible "
  70. "according to the broadcasting rule.");
  71. GELOGE(domi::PARAM_INVALID,
  72. "[Check][Param] SetShapeDifferentInfo failed. Two tensor shapes are not compatible "
  73. "according to the broadcasting rule.");
  74. return domi::PARAM_INVALID;
  75. }
  76. output_.push_back(output_i);
  77. result_.push_back(output_i);
  78. x_reshape_.push_back(x_i);
  79. x_bcast_.push_back(x_bcast_i);
  80. y_reshape_.push_back(y_i);
  81. y_bcast_.push_back(y_bcast_i);
  82. }
  83. return domi::SUCCESS;
  84. }
  85. void BCast::ExtendTensorDim(kVecInt &v_x, kVecInt &v_y) {
  86. if (v_x.size() > v_y.size()) {
  87. v_y.resize(v_x.size(), 1);
  88. } else {
  89. v_x.resize(v_y.size(), 1);
  90. }
  91. }
  92. BCast::kVecInt BCast::TransShapeToDimVec(const GeTensorDesc &shape) {
  93. const size_t dim_num = shape.GetShape().GetDimNum();
  94. BCast::kVecInt ret(dim_num);
  95. for (size_t i = 0; i < dim_num; ++i) {
  96. ret[i] = shape.GetShape().GetDim(i);
  97. }
  98. return ret;
  99. }
  100. void BCast::Reverse(kVecInt &shape) { std::reverse(shape.begin(), shape.end()); }
  101. void BCast::ReverseAllIntermediateShapes() {
  102. // Reverse all intermediate shape params
  103. Reverse(x_reshape_);
  104. Reverse(x_bcast_);
  105. Reverse(y_reshape_);
  106. Reverse(y_bcast_);
  107. Reverse(result_);
  108. Reverse(output_);
  109. Reverse(grad_x_reduce_idx_);
  110. Reverse(grad_y_reduce_idx_);
  111. }
  112. void BCast::BCastIndexes(kVecInt &x_indexes, kVecInt &y_indexes) {
  113. Reverse(x_reshape_);
  114. Reverse(y_reshape_);
  115. Reverse(output_);
  116. // Process 0-th dimension
  117. int64_t x_dim = 1;
  118. int64_t y_dim = 1;
  119. int64_t out_dim = 1;
  120. // If x and y are both scalar, then output_ is empty
  121. if (!output_.empty()) {
  122. x_dim = x_reshape_.at(0);
  123. y_dim = y_reshape_.at(0);
  124. out_dim = output_.at(0);
  125. }
  126. int64_t x_bias = x_dim;
  127. int64_t y_bias = y_dim;
  128. for (int64_t i = 0; i < out_dim; i++) {
  129. x_indexes.push_back(x_dim == 1 ? 0 : i);
  130. y_indexes.push_back(y_dim == 1 ? 0 : i);
  131. }
  132. // Process the remaining dimensions
  133. for (size_t i = 1; i < output_.size(); i++) {
  134. x_dim = x_reshape_.at(i); // i-th dimension of x.
  135. y_dim = y_reshape_.at(i); // i-th dimension of y.
  136. out_dim = output_.at(i); // i-th dimension of output_.
  137. int64_t stride = x_indexes.size();
  138. for (int64_t j = 1; j < out_dim; j++) {
  139. for (int64_t k = 0; k < stride; k++) {
  140. x_indexes.push_back(x_indexes.at(k) + (x_dim == 1 ? 0 : (j * x_bias)));
  141. y_indexes.push_back(y_indexes.at(k) + (y_dim == 1 ? 0 : (j * y_bias)));
  142. }
  143. }
  144. x_bias *= x_dim;
  145. y_bias *= y_dim;
  146. }
  147. Reverse(x_reshape_);
  148. Reverse(y_reshape_);
  149. Reverse(output_);
  150. }
  151. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。