You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

bcast.cc 4.6 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
#include "graph/common/bcast.h"

#include <algorithm>
#include <vector>

#include "common/math_util.h"
#include "common/util.h"
  20. using domi::Status;
  21. namespace ge {
  22. Status BCast::GenerateBcastInfo(const kVecInt &sx, const kVecInt &sy) {
  23. if (sx.size() == 0 && sy.size() == 0) {
  24. result_.push_back(1);
  25. x_reshape_.push_back(1);
  26. x_bcast_.push_back(1);
  27. y_reshape_.push_back(1);
  28. y_bcast_.push_back(1);
  29. } else {
  30. kVecInt x = sx;
  31. kVecInt y = sy;
  32. Reverse(x);
  33. Reverse(y);
  34. ExtendTensorDim(x, y);
  35. GE_RETURN_WITH_LOG_IF_ERROR(SetShapeDifferentInfo(x, y), "GenerateBcastInfo failed.");
  36. }
  37. ReverseAllIntermediateShapes();
  38. return domi::SUCCESS;
  39. }
  40. Status BCast::SetShapeDifferentInfo(const kVecInt &x, const kVecInt &y) {
  41. const int64_t n = x.size();
  42. for (int64_t i = 0; i < n; ++i) {
  43. const int64_t x_i = x[i];
  44. GE_CHECK_GE(x_i, 0);
  45. const int64_t y_i = y[i];
  46. GE_CHECK_GE(y_i, 0);
  47. int64_t output_i = 0;
  48. int64_t x_bcast_i = 0;
  49. int64_t y_bcast_i = 0;
  50. if (x_i == y_i) {
  51. output_i = x_i;
  52. x_bcast_i = 1;
  53. y_bcast_i = 1;
  54. if (x_i == 1) {
  55. grad_x_reduce_idx_.push_back(n - 1 - i);
  56. grad_y_reduce_idx_.push_back(n - 1 - i);
  57. }
  58. } else if (x_i == 1) {
  59. output_i = y_i;
  60. x_bcast_i = y_i;
  61. y_bcast_i = 1;
  62. grad_x_reduce_idx_.push_back(n - 1 - i);
  63. } else if (y_i == 1) {
  64. output_i = x_i;
  65. x_bcast_i = 1;
  66. y_bcast_i = x_i;
  67. grad_y_reduce_idx_.push_back(n - 1 - i);
  68. } else {
  69. GELOGE(domi::PARAM_INVALID,
  70. "SetShapeDifferentInfo failed. Two tensor shapes are not compatible "
  71. "according to the broadcasting rule.");
  72. return domi::PARAM_INVALID;
  73. }
  74. output_.push_back(output_i);
  75. result_.push_back(output_i);
  76. x_reshape_.push_back(x_i);
  77. x_bcast_.push_back(x_bcast_i);
  78. y_reshape_.push_back(y_i);
  79. y_bcast_.push_back(y_bcast_i);
  80. }
  81. return domi::SUCCESS;
  82. }
  83. void BCast::ExtendTensorDim(kVecInt &v_x, kVecInt &v_y) {
  84. if (v_x.size() > v_y.size()) {
  85. v_y.resize(v_x.size(), 1);
  86. } else {
  87. v_x.resize(v_y.size(), 1);
  88. }
  89. }
  90. BCast::kVecInt BCast::TransShapeToDimVec(const GeTensorDesc &shape) {
  91. const size_t dim_num = shape.GetShape().GetDimNum();
  92. BCast::kVecInt ret(dim_num);
  93. for (size_t i = 0; i < dim_num; ++i) {
  94. ret[i] = shape.GetShape().GetDim(i);
  95. }
  96. return ret;
  97. }
// In-place reversal of a dim vector (switches between most-significant-first
// and least-significant-first orderings).
void BCast::Reverse(kVecInt &shape) { std::reverse(shape.begin(), shape.end()); }
  99. void BCast::ReverseAllIntermediateShapes() {
  100. // Reverse all intermediate shape params
  101. Reverse(x_reshape_);
  102. Reverse(x_bcast_);
  103. Reverse(y_reshape_);
  104. Reverse(y_bcast_);
  105. Reverse(result_);
  106. Reverse(output_);
  107. Reverse(grad_x_reduce_idx_);
  108. Reverse(grad_y_reduce_idx_);
  109. }
// For every element of the broadcast output, compute the flat index of the
// input element it maps to, appending into x_indexes / y_indexes. A dim of 1
// on an input always maps to source index 0 along that axis.
// NOTE(review): indexes are appended without clearing the caller's vectors —
// presumably callers pass them in empty; confirm at call sites.
void BCast::BCastIndexes(kVecInt &x_indexes, kVecInt &y_indexes) {
  // Temporarily switch the cached shapes to least-significant-dim-first order.
  Reverse(x_reshape_);
  Reverse(y_reshape_);
  Reverse(output_);
  // Process 0-th dimension
  int64_t x_dim = 1;
  int64_t y_dim = 1;
  int64_t out_dim = 1;
  // If x and y are both scalar, then output_ is empty
  if (!output_.empty()) {
    x_dim = x_reshape_.at(0);
    y_dim = y_reshape_.at(0);
    out_dim = output_.at(0);
  }
  // Running element strides of the dimensions processed so far.
  int64_t x_bias = x_dim;
  int64_t y_bias = y_dim;
  // Seed the index tables from the innermost dimension.
  for (int64_t i = 0; i < out_dim; i++) {
    x_indexes.push_back(x_dim == 1 ? 0 : i);
    y_indexes.push_back(y_dim == 1 ? 0 : i);
  }
  // Process the remaining dimensions
  for (size_t i = 1; i < output_.size(); i++) {
    x_dim = x_reshape_.at(i);  // i-th dimension of x.
    y_dim = y_reshape_.at(i);  // i-th dimension of y.
    out_dim = output_.at(i);   // i-th dimension of output_.
    int64_t stride = x_indexes.size();
    // Replicate the index block built so far once per extra slice of this
    // dimension, offsetting by the accumulated stride unless broadcast.
    for (int64_t j = 1; j < out_dim; j++) {
      for (int64_t k = 0; k < stride; k++) {
        x_indexes.push_back(x_indexes.at(k) + (x_dim == 1 ? 0 : (j * x_bias)));
        y_indexes.push_back(y_indexes.at(k) + (y_dim == 1 ? 0 : (j * y_bias)));
      }
    }
    x_bias *= x_dim;
    y_bias *= y_dim;
  }
  // Restore the cached shapes to their public ordering.
  Reverse(x_reshape_);
  Reverse(y_reshape_);
  Reverse(output_);
}
  149. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示