
dynamic_stitch_kernel.cc 9.3 kB

/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "host_kernels/dynamic_stitch_kernel.h"

#include <securec.h>
#include <memory>

#include "common/fp16_t.h"
#include "common/ge_inner_error_codes.h"
#include "common/math/math_util.h"
#include "common/op/ge_op_utils.h"
#include "common/types.h"
#include "framework/common/debug/ge_log.h"
#include "graph/utils/type_utils.h"
#include "inc/kernel_factory.h"

namespace ge {
namespace {
const int kDoubleAttrN = 2;
const int kFirstOutputDescIdx = 0;
const int kMergedShapeSecondDim = 1;
const size_t kNullTensorDimNum = 1;
const int64_t kNullTensorDimValue = 0;
const std::set<DataType> kSupportedTypeSet = {DT_INT8,  DT_UINT8, DT_INT16,   DT_UINT16, DT_INT32,
                                              DT_INT64, DT_BOOL,  DT_FLOAT16, DT_FLOAT,  DT_DOUBLE};
}  // namespace
Status DynamicStitchKernel::Compute(const OpDescPtr op_desc_ptr, const vector<ConstGeTensorPtr> &input,
                                    vector<GeTensorPtr> &v_output) {
  GELOGD("DynamicStitch Kernel in.");
  Status validate_ret = ValidateParams(op_desc_ptr, input);
  if (validate_ret != SUCCESS) {
    GELOGW("Dynamic stitch kernel params validate failed.");
    return NOT_CHANGED;
  }
  // The output desc list is guaranteed non-empty; validated in ValidateParams.
  GeTensorPtr output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(kFirstOutputDescIdx));
  if (output_ptr == nullptr) {
    GELOGW("Fail to malloc output.");
    return NOT_CHANGED;
  }
  Status ret = GenData(input, output_ptr);
  if (ret != SUCCESS) {
    GELOGW("Dynamic stitch folding failed.");
    return NOT_CHANGED;
  }
  v_output.push_back(output_ptr);
  GELOGD("Dynamic stitch end.");
  return SUCCESS;
}

Status DynamicStitchKernel::ValidateParams(const OpDescPtr &op_desc_ptr, const std::vector<ConstGeTensorPtr> &input) {
  if (op_desc_ptr == nullptr) {
    GELOGW("Input op_desc is nullptr.");
    return PARAM_INVALID;
  }
  if (op_desc_ptr->GetOutputsSize() == 0) {
    GELOGW("Current output_desc is empty.");
    return PARAM_INVALID;
  }
  // Validate input.
  // input[0]~input[N-1] are indices, input[N]~input[2N-1] are data.
  if (input.empty()) {
    GELOGI("Input is empty. Ignore dynamic stitch kernel.");
    return NOT_CHANGED;
  }
  for (const auto &in : input) {
    if (in == nullptr) {
      GELOGW("input is nullptr.");
      return PARAM_INVALID;
    }
  }
  // Validate attrs.
  if (!(AttrUtils::GetInt(op_desc_ptr, ATTR_NAME_N, n_))) {
    GELOGW("Attr %s does not exist.", ATTR_NAME_N.c_str());
    return NOT_CHANGED;
  }
  // Validate attr N against input.size().
  if ((kDoubleAttrN * n_) > static_cast<int>(input.size())) {
    GELOGW("Input size %zu does not match attr N (%d). Ignore dynamic stitch kernel.", input.size(), n_);
    return NOT_CHANGED;
  }
  // Validate that the data type is supported.
  DataType data_type = input[n_]->GetTensorDesc().GetDataType();
  if (kSupportedTypeSet.find(data_type) == kSupportedTypeSet.end()) {
    GELOGW("Input data_type %s is not supported. Please check IR definition. Ignore dynamic stitch kernel.",
           TypeUtils::DataTypeToSerialString(data_type).c_str());
    return NOT_CHANGED;
  }
  return SUCCESS;
}
void DynamicStitchKernel::ComputeMergedShape(const vector<ConstGeTensorPtr> &input, GeShape &merged_shape) {
  // Safety note: indices [0, 2*n_) into input are valid and all inputs are non-null, validated in ValidateParams.
  // merged.shape = [max(indices) + 1] (+ [step] if step > 0)
  // 1. Compute the merged first dim, which is the max index plus one.
  int32_t merged_first_dim = 0;
  int64_t indices_shape_size = 0;
  for (int i = 0; i < n_; i++) {
    indices_shape_size = input[i]->GetTensorDesc().GetShape().GetShapeSize();
    // Treat a scalar as a vector of shape [1].
    indices_shape_size = indices_shape_size == 0 ? 1 : indices_shape_size;
    const int32_t *input_indices = reinterpret_cast<const int32_t *>(input[i]->GetData().data());
    for (int64_t j = 0; j < indices_shape_size; j++) {
      merged_first_dim = std::max(merged_first_dim, input_indices[j]);
    }
  }
  // 2. Compute the step, which follows: step = data[t].shape - indices[t].shape.
  size_t indices_dim_num = input[0]->GetTensorDesc().GetShape().GetDimNum();
  GeShape data_shape = input[n_]->GetTensorDesc().GetShape();
  int64_t step = (data_shape.GetDimNum() == indices_dim_num) ? 0 : data_shape.GetDim(indices_dim_num);
  vector<int64_t> merged_dim_vec = {merged_first_dim + 1};
  if (step > 0) {
    merged_dim_vec.emplace_back(step);
    GELOGD("merged_shape is [ %d, %ld ].", merged_first_dim + 1, step);
  } else {
    GELOGD("merged_shape is [ %d ].", merged_first_dim + 1);
  }
  merged_shape = GeShape(merged_dim_vec);
}
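
// Worked example (illustrative, not part of the original source): with n_ = 2,
// indices = { [0, 2], [1] } and data[0].shape = [2, 4], the max index is 2, so
// the merged first dim is 3; indices_dim_num = 1 and data_shape.GetDim(1) = 4
// give step = 4, hence merged_shape = [3, 4].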

Status DynamicStitchKernel::GenData(const vector<ConstGeTensorPtr> &input, GeTensorPtr &output_ptr) {
  // Safety note: indices [0, 2*n_) into input are valid and all inputs are non-null, validated in ValidateParams.
  GeShape merged_shape;
  ComputeMergedShape(input, merged_shape);
  auto data_type = input[n_]->GetTensorDesc().GetDataType();
  // 1. Calculate the output data size.
  auto output_size = merged_shape.GetShapeSize();
  int64_t data_size = GetSizeByDataType(data_type);
  auto step = merged_shape.GetDim(kMergedShapeSecondDim);
  if (!CheckInt64MulOverflow(output_size, data_size) || !CheckInt64MulOverflow(step, data_size)) {
    GELOGW("Check int64 mul overflow failed. Output_size is %ld, data_size is %ld, step is %ld.", output_size,
           data_size, step);
    return NOT_CHANGED;
  }
  auto allowance = output_size * data_size;
  auto data_unit = step > 0 ? step * data_size : data_size;
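  // Worked example (illustrative): for merged_shape = [3, 4] and DT_FLOAT
  // (data_size = 4 bytes), allowance = 3 * 4 * 4 = 48 bytes for the whole
  // output buffer, and data_unit = step * data_size = 16 bytes per stitched row.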
  // 2. Allocate memory for the output.
  std::unique_ptr<uint8_t[]> buf(new (std::nothrow) uint8_t[allowance]);
  if (buf == nullptr) {
    GELOGW("new buffer failed");
    return INTERNAL_ERROR;
  }
  // 3. Copy data from input_data following the order given by input_indices.
  Status stitch_ret = StitchDataFollowIndices(data_unit, input, allowance, buf);
  if (stitch_ret != SUCCESS) {
    GELOGW("Stitch data follow index failed.");
    return NOT_CHANGED;
  }
  output_ptr->MutableTensorDesc().SetDataType(data_type);
  output_ptr->MutableTensorDesc().SetShape(merged_shape);
  Status ret = output_ptr->SetData(buf.get(), allowance);
  if (ret != GRAPH_SUCCESS) {
    GELOGW("set data failed");
    return NOT_CHANGED;
  }
  return SUCCESS;
}

Status DynamicStitchKernel::StitchDataFollowIndices(int64_t data_unit, const vector<ConstGeTensorPtr> &input,
                                                    int64_t allowance, std::unique_ptr<uint8_t[]> &buf) {
  // Safety note: indices [0, 2*n_) into input are valid and all inputs are non-null, validated in ValidateParams.
  int64_t dst_offset = 0;
  int64_t src_offset = 0;
  std::set<int32_t> indices_set;
  for (int i = 0; i < n_; i++) {
    GeShape indices_shape = input[i]->GetTensorDesc().GetShape();
    size_t indices_dim_num = indices_shape.GetDimNum();
    // Skip a null indices tensor.
    if (indices_dim_num == kNullTensorDimNum && indices_shape.GetDim(0) == kNullTensorDimValue) {
      GELOGD("Input indices[%d] has null tensor, skip it.", i);
      continue;
    }
    auto indices_shape_size = indices_shape.GetShapeSize();
    // To normalize the logic, treat a scalar as a vector of shape [1].
    indices_shape_size = (indices_shape_size == 0) ? 1 : indices_shape_size;
    // Every index is less than the output first dim, since it was computed as max(indices) + 1.
    const int32_t *input_indices = reinterpret_cast<const int32_t *>(input[i]->GetData().data());
    const uint8_t *input_data = input[i + n_]->GetData().data();
    for (int64_t j = 0; j < indices_shape_size; j++) {
      // If an index repeats, the new data replaces the old data, so grant extra allowance.
      if (indices_set.find(input_indices[j]) != indices_set.end()) {
        if (ge::CheckInt64AddOverflow(input_indices[j], data_unit) != SUCCESS) {
          GELOGW("Check int64 add overflow failed. Indices is %d, data_unit is %ld.", input_indices[j], data_unit);
          return NOT_CHANGED;
        }
        allowance += data_unit;
      }
      indices_set.insert(input_indices[j]);
      if (!CheckInt64MulOverflow(input_indices[j], data_unit)) {
        GELOGW("Check int64 mul overflow failed. Indices is %d, data_unit is %ld.", input_indices[j], data_unit);
        return NOT_CHANGED;
      }
      dst_offset = input_indices[j] * data_unit;
      src_offset = j * data_unit;
      auto protected_size =
          allowance < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) ? allowance : static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
      auto ret = memcpy_s(buf.get() + dst_offset, protected_size, input_data + src_offset, data_unit);
      if (ret != EOK) {
        GELOGW("Memory copy failed.");
        return NOT_CHANGED;
      }
      allowance -= data_unit;
    }
  }
  return SUCCESS;
}

REGISTER_KERNEL(DYNAMICSTITCH, DynamicStitchKernel);
}  // namespace ge
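
For reference, the following minimal standalone sketch (not part of the source file above; plain std::vector stands in for GE tensors and raw byte buffers) shows the DynamicStitch semantics this kernel folds at graph-build time: row j of data[i] is written to the merged output at position indices[i][j], and on duplicate indices the later write wins, which mirrors the extra allowance granted in StitchDataFollowIndices.

// Illustrative sketch of DynamicStitch semantics; compiles with any C++11 compiler.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Two index tensors and two matching data tensors (attr N = 2).
  std::vector<std::vector<int32_t>> indices = {{0, 2}, {1}};
  std::vector<std::vector<std::vector<float>>> data = {
      {{1.0f, 1.0f}, {3.0f, 3.0f}},  // rows for merged[0] and merged[2]
      {{2.0f, 2.0f}}};               // row for merged[1]

  // merged first dim = max(indices) + 1 = 3; step (row width) = 2.
  std::vector<std::vector<float>> merged(3, std::vector<float>(2, 0.0f));
  for (size_t i = 0; i < indices.size(); ++i) {
    for (size_t j = 0; j < indices[i].size(); ++j) {
      merged[indices[i][j]] = data[i][j];  // later writes win on duplicate indices
    }
  }
  for (const auto &row : merged) {
    std::cout << row[0] << " " << row[1] << "\n";  // prints: 1 1 / 2 2 / 3 3
  }
  return 0;
}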

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module ME and the underlying hardware, acting as the bridge between them. GE takes the graph delivered by ME as input, performs a series of deep graph optimizations, and finally outputs a graph that runs efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its compute power. During model training/inference, GE is invoked automatically and is invisible to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture is shown in the diagram below.