You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

gather_v2_kernel.cc 20 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "host_kernels/gather_v2_kernel.h"
  17. #include <memory>
  18. #include <set>
  19. #include "common/fp16_t.h"
  20. #include "common/ge_inner_error_codes.h"
  21. #include "common/op/ge_op_utils.h"
  22. #include "common/types.h"
  23. #include "common/util.h"
  24. #include "framework/common/debug/ge_log.h"
  25. #include "host_kernels/kernel_utils.h"
  26. #include "graph/utils/type_utils.h"
  27. #include "inc/kernel_factory.h"
  28. namespace ge {
  29. namespace {
  30. const size_t kGatherV2InputIndexZero = 0;
  31. const size_t kGatherV2InputIndexOne = 1;
  32. const size_t kGatherV2InputIndexTwo = 2;
  33. const size_t kGatherV2InputIndexThree = 3;
  34. const size_t kGatherV2DimOne = 1;
  35. const size_t kGatherV2InpotNum = 3;
  36. const size_t kMaxIndicatesDims = 1; // only support scalar and 1 dims indicates_
  37. const std::set<DataType> supported_type = {DT_FLOAT16, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT16, DT_INT32,
  38. DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64};
  39. const int64_t DIM_AXIS_0 = 0;
  40. const int64_t DIM_AXIS_1 = 1;
  41. const int64_t DIM_AXIS_2 = 2;
  42. const int64_t DIM_AXIS_3 = 3;
  43. } // namespace
  44. template <typename T>
  45. Status GatherV2Kernel::ProcessAxis0(ConstGeTensorPtr tensor_x, GeTensorPtr output) {
  46. Status ret = SUCCESS;
  47. T *data_ptr_x = reinterpret_cast<T *>(const_cast<unsigned char *>(tensor_x->GetData().data()));
  48. T *data_ptr_y = reinterpret_cast<T *>(const_cast<unsigned char *>(output->GetData().data()));
  49. // index is valid, and no bigger than kGatherV2InputIndexZero
  50. size_t output_size = output->GetData().size();
  51. for (int64_t i = 0; i < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexZero); i++) {
  52. T *data_ptr_x_tmp = data_ptr_x + indicates_[i] * xstride_[kGatherV2InputIndexZero];
  53. T *data_ptr_y_tmp = data_ptr_y + i * ystride_[kGatherV2InputIndexZero];
  54. size_t size = sizeof(T) * xstride_[kGatherV2InputIndexZero];
  55. if (data_ptr_y_tmp - data_ptr_y < 0) {
  56. GELOGE(PARAM_INVALID, "ptr_y - ptr_y_tmp less than zero");
  57. return PARAM_INVALID;
  58. }
  59. size_t offset_size = (data_ptr_y_tmp - data_ptr_y) * sizeof(T);
  60. auto ret_mem = memcpy_s(reinterpret_cast<void *>(data_ptr_y_tmp), output_size - offset_size,
  61. reinterpret_cast<void *>(data_ptr_x_tmp), size);
  62. if (ret_mem != 0) {
  63. GELOGE(MEMALLOC_FAILED, "memcpy failed!");
  64. return MEMALLOC_FAILED;
  65. }
  66. }
  67. return ret;
  68. }
  69. template <typename T>
  70. Status GatherV2Kernel::ProcessAxis1(ConstGeTensorPtr tensor_x, GeTensorPtr output) {
  71. Status ret = SUCCESS;
  72. T *data_ptr_x = reinterpret_cast<T *>(const_cast<unsigned char *>(tensor_x->GetData().data()));
  73. T *data_ptr_y = reinterpret_cast<T *>(const_cast<unsigned char *>(output->GetData().data()));
  74. // index is valid, and no bigger than kGatherV2InputIndexOne
  75. size_t output_size = output->GetData().size();
  76. for (int64_t i = 0; i < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexZero); i++) {
  77. T *data_ptr_x_i = data_ptr_x + i * xstride_[kGatherV2InputIndexZero];
  78. T *data_ptr_y_i = data_ptr_y + i * ystride_[kGatherV2InputIndexZero];
  79. for (int64_t j = 0; j < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexOne); j++) {
  80. T *data_ptr_x_tmp = data_ptr_x_i + indicates_[j] * xstride_[kGatherV2InputIndexOne];
  81. T *data_ptr_y_tmp = data_ptr_y_i + j * ystride_[kGatherV2InputIndexOne];
  82. size_t size = sizeof(T) * xstride_[kGatherV2InputIndexOne];
  83. if (data_ptr_y_tmp - data_ptr_y < 0) {
  84. GELOGE(PARAM_INVALID, "ptr_y - ptr_y_tmp less than zero");
  85. return PARAM_INVALID;
  86. }
  87. size_t offset_size = (data_ptr_y_tmp - data_ptr_y) * sizeof(T);
  88. auto ret_mem = memcpy_s(reinterpret_cast<void *>(data_ptr_y_tmp), output_size - offset_size,
  89. reinterpret_cast<void *>(data_ptr_x_tmp), size);
  90. if (ret_mem != 0) {
  91. GELOGE(MEMALLOC_FAILED, "memcpy failed!");
  92. return MEMALLOC_FAILED;
  93. }
  94. }
  95. }
  96. return ret;
  97. }
  98. template <typename T>
  99. Status GatherV2Kernel::ProcessAxis2(ConstGeTensorPtr tensor_x, GeTensorPtr output) {
  100. Status ret = SUCCESS;
  101. T *data_ptr_x = reinterpret_cast<T *>(const_cast<unsigned char *>(tensor_x->GetData().data()));
  102. T *data_ptr_y = reinterpret_cast<T *>(const_cast<unsigned char *>(output->GetData().data()));
  103. // index is valid, and no bigger than kGatherV2InputIndexTwo
  104. size_t output_size = output->GetData().size();
  105. for (int64_t i = 0; i < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexZero); i++) {
  106. T *data_ptr_x_i = data_ptr_x + i * xstride_[kGatherV2InputIndexZero];
  107. T *data_ptr_y_i = data_ptr_y + i * ystride_[kGatherV2InputIndexZero];
  108. for (int64_t j = 0; j < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexOne); j++) {
  109. T *data_ptr_x_j = data_ptr_x_i + j * xstride_[kGatherV2InputIndexOne];
  110. T *data_ptr_y_j = data_ptr_y_i + j * ystride_[kGatherV2InputIndexOne];
  111. for (int64_t m = 0; m < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexTwo); m++) {
  112. T *data_ptr_x_tmp = data_ptr_x_j + indicates_[m] * xstride_[kGatherV2InputIndexTwo];
  113. T *data_ptr_y_tmp = data_ptr_y_j + m * ystride_[kGatherV2InputIndexTwo];
  114. size_t size = sizeof(T) * xstride_[kGatherV2InputIndexTwo];
  115. if (data_ptr_y_tmp - data_ptr_y < 0) {
  116. GELOGE(PARAM_INVALID, "ptr_y - ptr_y_tmp less than zero");
  117. return PARAM_INVALID;
  118. }
  119. size_t offset_size = (data_ptr_y_tmp - data_ptr_y) * sizeof(T);
  120. auto ret_mem = memcpy_s(reinterpret_cast<void *>(data_ptr_y_tmp), output_size - offset_size,
  121. reinterpret_cast<void *>(data_ptr_x_tmp), size);
  122. if (ret_mem != 0) {
  123. GELOGE(MEMALLOC_FAILED, "memcpy failed!");
  124. return MEMALLOC_FAILED;
  125. }
  126. }
  127. }
  128. }
  129. return ret;
  130. }
  131. template <typename T>
  132. Status GatherV2Kernel::ProcessAxis3(ConstGeTensorPtr tensor_x, GeTensorPtr output) {
  133. Status ret = SUCCESS;
  134. T *data_ptr_x = reinterpret_cast<T *>(const_cast<unsigned char *>(tensor_x->GetData().data()));
  135. T *data_ptr_y = reinterpret_cast<T *>(const_cast<unsigned char *>(output->GetData().data()));
  136. // index is valid, and no bigger than kGatherV2InputIndexThree
  137. size_t output_size = output->GetData().size();
  138. for (int64_t i = 0; i < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexZero); i++) {
  139. T *data_ptr_x_i = data_ptr_x + i * xstride_[kGatherV2InputIndexZero];
  140. T *data_ptr_y_i = data_ptr_y + i * ystride_[kGatherV2InputIndexZero];
  141. for (int64_t j = 0; j < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexOne); j++) {
  142. T *data_ptr_x_j = data_ptr_x_i + j * xstride_[kGatherV2InputIndexOne];
  143. T *data_ptr_y_j = data_ptr_y_i + j * ystride_[kGatherV2InputIndexOne];
  144. for (int64_t m = 0; m < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexTwo); m++) {
  145. T *data_ptr_x_m = data_ptr_x_j + m * xstride_[kGatherV2InputIndexTwo];
  146. T *data_ptr_y_m = data_ptr_y_j + m * ystride_[kGatherV2InputIndexTwo];
  147. for (int64_t n = 0; n < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexThree); n++) {
  148. T *data_ptr_x_tmp = data_ptr_x_m + indicates_[n] * xstride_[kGatherV2InputIndexThree];
  149. T *data_ptr_y_tmp = data_ptr_y_m + n * ystride_[kGatherV2InputIndexThree];
  150. size_t size = sizeof(T) * xstride_[kGatherV2InputIndexThree];
  151. if (data_ptr_y_tmp - data_ptr_y < 0) {
  152. GELOGE(PARAM_INVALID, "ptr_y - ptr_y_tmp less than zero");
  153. return PARAM_INVALID;
  154. }
  155. size_t offset_size = (data_ptr_y_tmp - data_ptr_y) * sizeof(T);
  156. auto ret_mem = memcpy_s(reinterpret_cast<void *>(data_ptr_y_tmp), output_size - offset_size,
  157. reinterpret_cast<void *>(data_ptr_x_tmp), size);
  158. if (ret_mem != 0) {
  159. GELOGE(MEMALLOC_FAILED, "memcpy failed!");
  160. return MEMALLOC_FAILED;
  161. }
  162. }
  163. }
  164. }
  165. }
  166. return ret;
  167. }
/// @brief Allocates the output buffer, attaches it to the output tensor and
/// dispatches to the ProcessAxisN implementation matching the gather axis.
/// @param data_num  Number of elements in the output (must be > 0).
/// @param tensor_x  Input data tensor to gather from.
/// @param axis      Gather axis; only 0..3 are handled, others -> NOT_CHANGED.
/// @param output    Output tensor; receives a copy of the generated buffer.
/// @return SUCCESS, PARAM_INVALID, MEMALLOC_FAILED, INTERNAL_ERROR or NOT_CHANGED.
template <typename T>
Status GatherV2Kernel::GenData(const int64_t data_num, ConstGeTensorPtr tensor_x, int64_t axis, GeTensorPtr output) {
  if (data_num <= 0) {
    return PARAM_INVALID;
  }
  // Guard the byte-size multiplication below against int64 overflow.
  if (!CheckInt64MulOverflow(data_num, sizeof(T))) {
    GELOGE(PARAM_INVALID, "Int64MulOverflow, data_num:%ld, type_len:%zu.", data_num, sizeof(T));
    return PARAM_INVALID;
  }
  // Value-initialized buffer; nothrow new so allocation failure is a nullptr,
  // not an exception.
  std::unique_ptr<T[]> buf(new (std::nothrow) T[data_num]());
  if (buf == nullptr) {
    GELOGE(MEMALLOC_FAILED, "New sizeof(T) * data_num(%zu) memory failed", static_cast<size_t>(sizeof(T) * data_num));
    return MEMALLOC_FAILED;
  }
  // NOTE(review): buf is freed when this function returns, so SetData is
  // presumably copying the bytes into the tensor — confirm against GeTensor.
  GE_IF_BOOL_EXEC(
      output->SetData(reinterpret_cast<uint8_t *>(buf.get()), static_cast<size_t>(data_num * sizeof(T))) != GRAPH_SUCCESS,
      GELOGE(INTERNAL_ERROR, "set data failed");
      return INTERNAL_ERROR);
  Status ret = SUCCESS;
  // Each supported axis has a dedicated nested-loop implementation.
  switch (axis) {
    case DIM_AXIS_0:
      ret = ProcessAxis0<T>(tensor_x, output);
      break;
    case DIM_AXIS_1:
      ret = ProcessAxis1<T>(tensor_x, output);
      break;
    case DIM_AXIS_2:
      ret = ProcessAxis2<T>(tensor_x, output);
      break;
    case DIM_AXIS_3:
      ret = ProcessAxis3<T>(tensor_x, output);
      break;
    default:
      GELOGI("Only support 4 dims and below but input axis is %ld", axis);
      return NOT_CHANGED;
  }
  return ret;
}
  206. Status GatherV2Kernel::CalcStride(std::vector<int64_t> &stride, std::vector<int64_t> dims) {
  207. if (stride.size() != dims.size() || dims.size() == 0) {
  208. return PARAM_INVALID;
  209. }
  210. int i = static_cast<int>(dims.size() - kGatherV2DimOne);
  211. stride[static_cast<size_t>(i)] = static_cast<int64_t>(kGatherV2DimOne);
  212. i--;
  213. while (i >= 0) {
  214. size_t index = static_cast<size_t>(i) + kGatherV2DimOne;
  215. if (!CheckInt64MulOverflow(stride[index], dims[index])) {
  216. GELOGE(PARAM_INVALID, "Int64MulOverflow, data_num(%ld) type_len(%ld)", stride[index], dims[index]);
  217. return PARAM_INVALID;
  218. }
  219. stride[static_cast<size_t>(i)] = stride[index] * dims[index];
  220. i--;
  221. }
  222. return SUCCESS;
  223. }
/// @brief Dispatches GenData<T> on the element type of the input tensor.
/// Data types listed here must stay in sync with the supported_type set;
/// anything else returns NOT_CHANGED so graph optimization leaves the node.
/// @param axis              Gather axis, already normalized to >= 0.
/// @param data_type         Element type of the input data tensor.
/// @param input_tensor_ptr  Input data tensor.
/// @param output_ptr        Pre-shaped output tensor to fill.
Status GatherV2Kernel::Process(int64_t axis, DataType data_type, ConstGeTensorPtr input_tensor_ptr,
                               GeTensorPtr output_ptr) {
  Status ret = SUCCESS;
  // Element count comes from the already-computed output shape.
  int64_t data_num = output_ptr->GetTensorDesc().GetShape().GetShapeSize();
  switch (data_type) {
    case DT_FLOAT16:
      ret = GenData<fp16_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_DOUBLE:
      ret = GenData<double>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_INT8:
      ret = GenData<int8_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_INT16:
      ret = GenData<int16_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_INT32:
      ret = GenData<int32_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_INT64:
      ret = GenData<int64_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_UINT8:
      ret = GenData<uint8_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_UINT16:
      ret = GenData<uint16_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_UINT32:
      ret = GenData<uint32_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    case DT_UINT64:
      ret = GenData<uint64_t>(data_num, input_tensor_ptr, axis, output_ptr);
      break;
    default:
      GELOGI("GatherV2Kernel does not support this Data type:%s", TypeUtils::DataTypeToSerialString(data_type).c_str());
      return NOT_CHANGED;
  }
  return ret;
}
  265. Status GatherV2Kernel::SaveIndicesByDataType(ConstGeTensorPtr indices_tensor_ptr, GeShape &x_shape,
  266. GeShape &indices_shape, DataType indices_data_type, size_t axis) {
  267. if (indices_data_type == DT_INT32) {
  268. auto indices_ptr = const_cast<int32_t *>(reinterpret_cast<const int32_t *>(indices_tensor_ptr->GetData().data()));
  269. for (int64_t i = 0; i < indices_shape.GetShapeSize(); i++) {
  270. if (*(indices_ptr + i) < 0 || *(indices_ptr + i) >= x_shape.GetDim(axis)) {
  271. GELOGW("indices %ld value is not in range [0, %ld).", i, x_shape.GetDim(axis));
  272. return NOT_CHANGED;
  273. }
  274. indicates_.push_back(*(indices_ptr + i));
  275. }
  276. } else {
  277. // int64
  278. auto indices_ptr = const_cast<int64_t *>(reinterpret_cast<const int64_t *>(indices_tensor_ptr->GetData().data()));
  279. for (int64_t i = 0; i < indices_shape.GetShapeSize(); i++) {
  280. if (*(indices_ptr + i) < 0 || *(indices_ptr + i) >= x_shape.GetDim(axis)) {
  281. GELOGW("indices %ld value is not in range [0, %ld).", i, x_shape.GetDim(axis));
  282. return NOT_CHANGED;
  283. }
  284. indicates_.push_back(*(indices_ptr + i));
  285. }
  286. }
  287. return SUCCESS;
  288. }
/// @brief Validates the op descriptor and the three inputs (data, indices,
/// axis) before folding: non-null, non-empty, axis is an int32/int64 scalar,
/// indices are int32/int64 with at most 1 dimension.
/// @return SUCCESS when foldable, NOT_CHANGED otherwise (node is left as-is).
/// NOTE(review): v_output is unused here; kept for signature compatibility.
Status GatherV2Kernel::Check(const OpDescPtr &op_desc_ptr, const vector<ConstGeTensorPtr> &input,
                             vector<GeTensorPtr> &v_output) const {
  if (op_desc_ptr == nullptr) {
    GELOGW("input opdesc is nullptr.");
    return NOT_CHANGED;
  }
  if (input.size() != kGatherV2InpotNum) {
    GELOGW("The number of input for GatherV2 must be %zu.", kGatherV2InpotNum);
    return NOT_CHANGED;
  }
  bool is_null = (input[kGatherV2InputIndexZero] == nullptr || input[kGatherV2InputIndexOne] == nullptr ||
                  input[kGatherV2InputIndexTwo] == nullptr);
  if (is_null) {
    GELOGW("some input is nullptr.");
    return NOT_CHANGED;
  }
  // tensor0 = data, tensor1 = indices, tensor2 = axis.
  ConstGeTensorPtr tensor0 = input.at(kGatherV2InputIndexZero);
  ConstGeTensorPtr tensor1 = input.at(kGatherV2InputIndexOne);
  ConstGeTensorPtr tensor2 = input.at(kGatherV2InputIndexTwo);
  bool size_is_zero =
      ((tensor0->GetData().size() == 0) || (tensor1->GetData().size() == 0) || (tensor2->GetData().size() == 0));
  if (size_is_zero) {
    GELOGW("some input size is zero.");
    return NOT_CHANGED;
  }
  auto indices_shape = tensor1->GetTensorDesc().GetShape();
  auto axis_shape = tensor2->GetTensorDesc().GetShape();
  // axis must be scalar
  if (axis_shape.GetDimNum() != 0) {
    GELOGW("axis must be scalar but its shape is %zu", axis_shape.GetDimNum());
    return NOT_CHANGED;
  }
  auto axis_data_type = tensor2->GetTensorDesc().GetDataType();
  bool is_valid_axis_data_type = axis_data_type == DT_INT32 || axis_data_type == DT_INT64;
  if (!is_valid_axis_data_type) {
    GELOGW("axis datatype must be DT_INT32 or DT_INT64");
    return NOT_CHANGED;
  }
  // check indices data_type && dims && every element
  auto indices_data_type = tensor1->GetTensorDesc().GetDataType();
  bool is_valid_indices_data_type = indices_data_type == DT_INT32 || indices_data_type == DT_INT64;
  if (!is_valid_indices_data_type) {
    GELOGW("indices datatype must be DT_INT32 or DT_INT64.");
    return NOT_CHANGED;
  }
  if (indices_shape.GetDimNum() > kMaxIndicatesDims) {
    GELOGW("indices input only support 0 or 1 dims.");
    return NOT_CHANGED;
  }
  return SUCCESS;
}
  340. void GatherV2Kernel::DebugPrint(int64_t axis, const GeShape &x_shape, const GeShape &indices_shape,
  341. const std::vector<int64_t> &y_shape) {
  342. GELOGD("GatherV2Kernel axis:%ld x_shape:%zu indices_shape:%zu y_shape:%zu.", axis, x_shape.GetDimNum(),
  343. indices_shape.GetDimNum(), y_shape.size());
  344. for (size_t i = 0; i < x_shape.GetDimNum(); i++) {
  345. GELOGD("GatherV2Kernel x_shape[%zu]: %ld.", i, x_shape.GetDim(i));
  346. }
  347. for (size_t i = 0; i < indices_shape.GetDimNum(); i++) {
  348. GELOGD("GatherV2Kernel indices_shape[%zu]: %ld.", i, indices_shape.GetDim(i));
  349. }
  350. for (size_t i = 0; i < y_shape.size(); i++) {
  351. GELOGD("GatherV2Kernel y_shape[%zu]: %ld.", i, y_shape[i]);
  352. }
  353. for (auto ele : indicates_) {
  354. GELOGD("GatherV2Kernel indices:%ld.", ele);
  355. }
  356. }
  357. Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector<ConstGeTensorPtr> &input,
  358. vector<GeTensorPtr> &v_output) {
  359. GELOGI("Enter GatherV2Kernel Process");
  360. Status ret = Check(op_desc_ptr, input, v_output);
  361. if (ret != SUCCESS) {
  362. GELOGW("param check failed");
  363. return NOT_CHANGED;
  364. }
  365. GELOGI("GatherV2Kernel[%s] start Process", op_desc_ptr->GetName().c_str());
  366. ConstGeTensorPtr tensor0 = input.at(kGatherV2InputIndexZero);
  367. ConstGeTensorPtr tensor1 = input.at(kGatherV2InputIndexOne);
  368. ConstGeTensorPtr tensor2 = input.at(kGatherV2InputIndexTwo);
  369. auto x_shape = tensor0->GetTensorDesc().GetShape();
  370. auto indices_shape = tensor1->GetTensorDesc().GetShape();
  371. auto axis_data_type = tensor2->GetTensorDesc().GetDataType();
  372. int64_t axis = axis_data_type == DT_INT32
  373. ? *(const_cast<int32_t *>(reinterpret_cast<const int32_t *>(tensor2->GetData().data())))
  374. : *(const_cast<int64_t *>(reinterpret_cast<const int64_t *>(tensor2->GetData().data())));
  375. axis = axis >= 0 ? axis : axis + x_shape.GetDimNum();
  376. // check axis value
  377. if (axis < 0 || (axis + 1) > static_cast<int64_t>(x_shape.GetDimNum())) {
  378. GELOGW("axis is invalid!");
  379. return NOT_CHANGED;
  380. }
  381. auto indices_data_type = tensor1->GetTensorDesc().GetDataType();
  382. ret = SaveIndicesByDataType(tensor1, x_shape, indices_shape, indices_data_type, static_cast<size_t>(axis));
  383. if (ret != SUCCESS) {
  384. GELOGW("Save indeices by data type failed!");
  385. return ret;
  386. }
  387. // check input data type
  388. auto x_data_type = tensor0->GetTensorDesc().GetDataType();
  389. if (supported_type.find(x_data_type) == supported_type.end()) {
  390. GELOGI("GatherV2Kernel does not support this Data type:%s.",
  391. TypeUtils::DataTypeToSerialString(x_data_type).c_str());
  392. return NOT_CHANGED;
  393. }
  394. // calc output shape
  395. std::vector<int64_t> y_shape;
  396. for (size_t i = 0; i < static_cast<size_t>(axis); i++) {
  397. y_shape.push_back(x_shape.GetDim(i));
  398. }
  399. for (size_t i = 0; i < indices_shape.GetDimNum(); i++) {
  400. y_shape.push_back(indices_shape.GetDim(i));
  401. }
  402. for (size_t i = static_cast<size_t>(axis) + 1; i < x_shape.GetDimNum(); i++) {
  403. y_shape.push_back(x_shape.GetDim(i));
  404. }
  405. GeTensorPtr output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(0));
  406. if (output_ptr == nullptr) {
  407. GELOGW("make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str());
  408. return NOT_CHANGED;
  409. }
  410. output_ptr->MutableTensorDesc().SetShape(GeShape(y_shape));
  411. output_ptr->MutableTensorDesc().SetDataType(x_data_type);
  412. // added for debug
  413. DebugPrint(axis, x_shape, indices_shape, y_shape);
  414. // calc stride
  415. std::vector<int64_t> xstride(x_shape.GetDimNum());
  416. std::vector<int64_t> ystride(y_shape.size());
  417. xstride_ = xstride;
  418. ystride_ = ystride;
  419. auto ret_x = CalcStride(xstride_, x_shape.GetDims());
  420. auto ret_y = CalcStride(ystride_, y_shape);
  421. ret = (ret_x == SUCCESS && ret_y == SUCCESS) ? SUCCESS : NOT_CHANGED;
  422. if (ret != SUCCESS) {
  423. GELOGE(ret, "CalcStride Failed");
  424. return ret;
  425. }
  426. ret = Process(axis, x_data_type, tensor0, output_ptr);
  427. if (ret != SUCCESS) {
  428. GELOGE(ret, "GenData failed, data_type: %s", TypeUtils::DataTypeToSerialString(x_data_type).c_str());
  429. return ret;
  430. }
  431. GELOGI("GatherV2Kernel Process Success.");
  432. v_output.push_back(output_ptr);
  433. return SUCCESS;
  434. }
// Register this kernel with the factory so GATHERV2 nodes with constant
// inputs can be folded on host during graph optimization.
REGISTER_KERNEL(GATHERV2, GatherV2Kernel);
}  // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用,用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示。