You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

bcast.h 7.1 kB

5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235
  1. /**
  2. * Copyright 2019-2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #ifndef GE_GRAPH_COMMON_BCAST_H_
  17. #define GE_GRAPH_COMMON_BCAST_H_
  18. #include <stdint.h>
  19. #include <functional>
  20. #include <vector>
  21. #include "common/debug/log.h"
  22. #include "common/types.h"
  23. #include "framework/common/debug/ge_log.h"
  24. #include "framework/common/ge_inner_error_codes.h"
  25. #include "graph/attr_value.h"
  26. #include "graph/ge_tensor.h"
  27. #include "graph/utils/tensor_adapter.h"
  28. namespace ge {
  // Minimum number of input tensors required by BCast::BCastCompute and
  // BCast::BCastComputeCheck (despite the name, it is compared against input.size()).
  static constexpr size_t kMinDimNum = 2;
  30. class BCast {
  31. public:
  32. ///
  33. /// @ingroup domi_calibration
  34. /// @brief define kVecInt
  35. ///
  36. typedef std::vector<int64_t> kVecInt;
  37. ///
  38. /// @ingroup domi_calibration
  39. /// @brief constructor
  40. ///
  41. BCast() {}
  42. ///
  43. /// @ingroup domi_calibration
  44. /// @brief destructor
  45. ///
  46. ~BCast() {}
  47. ///
  48. /// @ingroup domi_calibration
  49. /// @brief Not optimize intermediate shapes
  50. /// @decrease dims, more efficient, set by user
  51. /// @param [in] x first Tensor dim
  52. /// @param [in] y second Tensor dim
  53. /// @return SUCCESS broadcast message successfully generated
  54. /// @return other broadcast message failed to generate
  55. ///
  56. ge::Status GenerateBcastInfo(const kVecInt &x, const kVecInt &y);
  57. ///
  58. /// @ingroup domi_calibration
  59. /// @brief get x_reshape
  60. ///
  61. const kVecInt &GetXReshape() const { return x_reshape_; }
  62. ///
  63. /// @ingroup domi_calibration
  64. /// @brief get x_bcast
  65. ///
  66. const kVecInt &GetXBcast() const { return x_bcast_; }
  67. ///
  68. /// @ingroup domi_calibration
  69. /// @brief get y_reshape
  70. ///
  71. const kVecInt &GetYReshape() const { return y_reshape_; }
  72. ///
  73. /// @ingroup domi_calibration
  74. /// @brief get y_bcast
  75. ///
  76. const kVecInt &GetYBcast() const { return y_bcast_; }
  77. ///
  78. /// @ingroup domi_calibration
  79. /// @brief get result_shape
  80. ///
  81. const kVecInt &GetResultShape() const { return result_; }
  82. ///
  83. /// @ingroup domi_calibration
  84. /// @brief get result_shape
  85. ///
  86. const kVecInt &GetOutputShape() const { return output_; }
  87. const kVecInt &GetGradXReduceIdx() const { return grad_x_reduce_idx_; }
  88. const kVecInt &GetGradYReduceIdx() const { return grad_y_reduce_idx_; }
  89. ///
  90. /// @ingroup domi_calibration
  91. /// @brief convert TensorDescriptor to kVecInt
  92. /// @param [in] shape Tensor descriptor
  93. /// @return kVecInt dim info
  94. ///
  95. static kVecInt TransShapeToDimVec(const GeTensorDesc &shape);
  96. void BCastIndexes(kVecInt &x_indexes, kVecInt &y_indexes);
  97. template <typename InT, typename OutT>
  98. Status BCastCompute(const std::vector<ConstGeTensorPtr> &input, std::vector<OutT> &v_output,
  99. const std::function<OutT(InT const &, InT const &)> &func) {
  100. Status ret;
  101. if (func == nullptr) {
  102. GELOGE(domi::PARAM_INVALID, "Param func is null");
  103. return domi::PARAM_INVALID;
  104. }
  105. // Min input num is 2
  106. if (input.size() < kMinDimNum) {
  107. GELOGE(domi::PARAM_INVALID, "Input size is smaller than two.");
  108. return domi::PARAM_INVALID;
  109. }
  110. // Only broadcast shape
  111. ret =
  112. GenerateBcastInfo(TransShapeToDimVec(input[0]->GetTensorDesc()), TransShapeToDimVec(input[1]->GetTensorDesc()));
  113. if (ret != domi::SUCCESS) {
  114. GELOGE(ret, "Greater broadcasting failed.");
  115. return ret;
  116. }
  117. kVecInt x_indexes;
  118. kVecInt y_indexes;
  119. BCastIndexes(x_indexes, y_indexes);
  120. const void *x1_data = input[0]->GetData().data();
  121. const void *x2_data = input[1]->GetData().data();
  122. for (size_t i = 0; i < x_indexes.size(); i++) {
  123. int64_t x_index = x_indexes[i];
  124. int64_t y_index = y_indexes[i];
  125. auto value = func((*(reinterpret_cast<const InT *>(x1_data) + x_index)),
  126. (*(reinterpret_cast<const InT *>(x2_data) + y_index)));
  127. v_output.push_back(value);
  128. }
  129. return domi::SUCCESS;
  130. }
  131. template <typename InT, typename OutT>
  132. Status BCastComputeCheck(const std::vector<ConstGeTensorPtr> &input, std::vector<OutT> &v_output,
  133. const std::function<OutT(InT const &, InT const &, DataType &type, Status &)> &func) {
  134. if (func == nullptr) {
  135. GELOGE(PARAM_INVALID, "Param func is null");
  136. return PARAM_INVALID;
  137. }
  138. // Min input num is 2
  139. if (input.size() < kMinDimNum) {
  140. GELOGE(PARAM_INVALID, "Input size is smaller than two.");
  141. return PARAM_INVALID;
  142. }
  143. // Only broadcast shape
  144. Status ret =
  145. GenerateBcastInfo(TransShapeToDimVec(input[0]->GetTensorDesc()), TransShapeToDimVec(input[1]->GetTensorDesc()));
  146. if (ret != SUCCESS) {
  147. GELOGE(ret, "Greater broadcasting failed.");
  148. return ret;
  149. }
  150. DataType data_type = input[0]->GetTensorDesc().GetDataType();
  151. kVecInt x_indexes;
  152. kVecInt y_indexes;
  153. BCastIndexes(x_indexes, y_indexes);
  154. const void *x1_data = input[0]->GetData().data();
  155. const void *x2_data = input[1]->GetData().data();
  156. for (size_t i = 0; i < x_indexes.size(); i++) {
  157. int64_t x_index = x_indexes[i];
  158. int64_t y_index = y_indexes[i];
  159. auto value = func((*(reinterpret_cast<const InT *>(x1_data) + x_index)),
  160. (*(reinterpret_cast<const InT *>(x2_data) + y_index)), data_type, ret);
  161. if (ret != SUCCESS) {
  162. GELOGE(ret, "BCastComputeCheck func execute failed, datatype is %d.", data_type);
  163. return ret;
  164. }
  165. v_output.push_back(value);
  166. }
  167. return SUCCESS;
  168. }
  169. private:
  170. ///
  171. /// @ingroup domi_calibration
  172. /// @brief reverse elements in kVecInt
  173. /// @param [in] shape dim info
  174. /// @return null
  175. ///
  176. static void Reverse(kVecInt &shape);
  177. ///
  178. /// @ingroup domi_calibration
  179. /// @brief two Tensor with different shape, set broadcast info
  180. /// @param [in] x first input Tensor dim info
  181. /// @param [in] y second input Tensor dim info
  182. /// @return null
  183. ///
  184. ge::Status SetShapeDifferentInfo(const kVecInt &x, const kVecInt &y);
  185. ///
  186. /// @ingroup domi_calibration
  187. /// @brief extend Tensor dim
  188. /// @param [in] x first input Tensor dim info
  189. /// @param [in] y second input Tensor dim info
  190. /// @return null
  191. ///
  192. void ExtendTensorDim(kVecInt &x, kVecInt &y);
  193. ///
  194. /// @ingroup domi_calibration
  195. /// @brief reverse all intermediate shape params
  196. /// @param [in] void
  197. /// @return null
  198. ///
  199. void ReverseAllIntermediateShapes();
  200. kVecInt x_reshape_;
  201. kVecInt x_bcast_;
  202. kVecInt y_reshape_;
  203. kVecInt y_bcast_;
  204. kVecInt result_;
  205. kVecInt output_;
  206. kVecInt grad_x_reduce_idx_;
  207. kVecInt grad_y_reduce_idx_;
  208. };
  209. } // namespace ge
  210. #endif // GE_GRAPH_COMMON_BCAST_H_

图引擎模块（GE）是MindSpore的一个子模块，其代码由C++实现，位于前端模块ME和底层硬件之间，起到承接作用。图引擎模块以ME下发的图作为输入，然后进行一系列的深度图优化操作，最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点，做了特定的优化工作，以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时，GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成，详细的架构图如下所示。