You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

gather_v2_kernel.cc 20 kB

5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
5 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. /**
  2. * Copyright 2020 Huawei Technologies Co., Ltd
  3. *
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. */
  16. #include "host_kernels/gather_v2_kernel.h"
  17. #include <memory>
  18. #include <set>
  19. #include "common/fp16_t.h"
  20. #include "common/ge_inner_error_codes.h"
  21. #include "common/op/ge_op_utils.h"
  22. #include "common/types.h"
  23. #include "common/util.h"
  24. #include "framework/common/debug/ge_log.h"
  25. #include "host_kernels/kernel_utils.h"
  26. #include "graph/utils/type_utils.h"
  27. #include "inc/kernel_factory.h"
  28. namespace ge {
  29. namespace {
  30. const size_t kGatherV2InputIndexZero = 0;
  31. const size_t kGatherV2InputIndexOne = 1;
  32. const size_t kGatherV2InputIndexTwo = 2;
  33. const size_t kGatherV2InputIndexThree = 3;
  34. const size_t kGatherV2DimOne = 1;
  35. const size_t kGatherV2InpotNum = 3;
  36. const size_t kMaxIndicatesDims = 1; // only support scalar and 1 dims indicates_
  37. const std::set<DataType> supported_type = {DT_FLOAT16, DT_DOUBLE, DT_INT8, DT_INT16, DT_INT16, DT_INT32,
  38. DT_INT64, DT_UINT8, DT_UINT16, DT_UINT32, DT_UINT64};
  39. } // namespace
  40. template <typename T>
  41. Status GatherV2Kernel::ProcessAxis0(ConstGeTensorPtr tensor_x, GeTensorPtr output) {
  42. Status ret = SUCCESS;
  43. T *data_ptr_x = reinterpret_cast<T *>(const_cast<unsigned char *>(tensor_x->GetData().data()));
  44. T *data_ptr_y = reinterpret_cast<T *>(const_cast<unsigned char *>(output->GetData().data()));
  45. // index is valid, and no bigger than kGatherV2InputIndexZero
  46. size_t output_size = output->GetData().size();
  47. for (int64_t i = 0; i < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexZero); i++) {
  48. T *data_ptr_x_tmp = data_ptr_x + indicates_[i] * xstride_[kGatherV2InputIndexZero];
  49. T *data_ptr_y_tmp = data_ptr_y + i * ystride_[kGatherV2InputIndexZero];
  50. size_t size = sizeof(T) * xstride_[kGatherV2InputIndexZero];
  51. if (data_ptr_y_tmp - data_ptr_y < 0) {
  52. GELOGE(PARAM_INVALID, "ptr_y - ptr_y_tmp less than zero");
  53. return PARAM_INVALID;
  54. }
  55. size_t offset_size = (data_ptr_y_tmp - data_ptr_y) * sizeof(T);
  56. auto ret_mem = memcpy_s(reinterpret_cast<void *>(data_ptr_y_tmp), output_size - offset_size,
  57. reinterpret_cast<void *>(data_ptr_x_tmp), size);
  58. if (ret_mem != 0) {
  59. GELOGE(MEMALLOC_FAILED, "memcpy failed!");
  60. return MEMALLOC_FAILED;
  61. }
  62. }
  63. return ret;
  64. }
  65. template <typename T>
  66. Status GatherV2Kernel::ProcessAxis1(ConstGeTensorPtr tensor_x, GeTensorPtr output) {
  67. Status ret = SUCCESS;
  68. T *data_ptr_x = reinterpret_cast<T *>(const_cast<unsigned char *>(tensor_x->GetData().data()));
  69. T *data_ptr_y = reinterpret_cast<T *>(const_cast<unsigned char *>(output->GetData().data()));
  70. // index is valid, and no bigger than kGatherV2InputIndexOne
  71. size_t output_size = output->GetData().size();
  72. for (int64_t i = 0; i < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexZero); i++) {
  73. T *data_ptr_x_i = data_ptr_x + i * xstride_[kGatherV2InputIndexZero];
  74. T *data_ptr_y_i = data_ptr_y + i * ystride_[kGatherV2InputIndexZero];
  75. for (int64_t j = 0; j < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexOne); j++) {
  76. T *data_ptr_x_tmp = data_ptr_x_i + indicates_[j] * xstride_[kGatherV2InputIndexOne];
  77. T *data_ptr_y_tmp = data_ptr_y_i + j * ystride_[kGatherV2InputIndexOne];
  78. size_t size = sizeof(T) * xstride_[kGatherV2InputIndexOne];
  79. if (data_ptr_y_tmp - data_ptr_y < 0) {
  80. GELOGE(PARAM_INVALID, "ptr_y - ptr_y_tmp less than zero");
  81. return PARAM_INVALID;
  82. }
  83. size_t offset_size = (data_ptr_y_tmp - data_ptr_y) * sizeof(T);
  84. auto ret_mem = memcpy_s(reinterpret_cast<void *>(data_ptr_y_tmp), output_size - offset_size,
  85. reinterpret_cast<void *>(data_ptr_x_tmp), size);
  86. if (ret_mem != 0) {
  87. GELOGE(MEMALLOC_FAILED, "memcpy failed!");
  88. return MEMALLOC_FAILED;
  89. }
  90. }
  91. }
  92. return ret;
  93. }
  94. template <typename T>
  95. Status GatherV2Kernel::ProcessAxis2(ConstGeTensorPtr tensor_x, GeTensorPtr output) {
  96. Status ret = SUCCESS;
  97. T *data_ptr_x = reinterpret_cast<T *>(const_cast<unsigned char *>(tensor_x->GetData().data()));
  98. T *data_ptr_y = reinterpret_cast<T *>(const_cast<unsigned char *>(output->GetData().data()));
  99. // index is valid, and no bigger than kGatherV2InputIndexTwo
  100. size_t output_size = output->GetData().size();
  101. for (int64_t i = 0; i < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexZero); i++) {
  102. T *data_ptr_x_i = data_ptr_x + i * xstride_[kGatherV2InputIndexZero];
  103. T *data_ptr_y_i = data_ptr_y + i * ystride_[kGatherV2InputIndexZero];
  104. for (int64_t j = 0; j < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexOne); j++) {
  105. T *data_ptr_x_j = data_ptr_x_i + j * xstride_[kGatherV2InputIndexOne];
  106. T *data_ptr_y_j = data_ptr_y_i + j * ystride_[kGatherV2InputIndexOne];
  107. for (int64_t m = 0; m < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexTwo); m++) {
  108. T *data_ptr_x_tmp = data_ptr_x_j + indicates_[m] * xstride_[kGatherV2InputIndexTwo];
  109. T *data_ptr_y_tmp = data_ptr_y_j + m * ystride_[kGatherV2InputIndexTwo];
  110. size_t size = sizeof(T) * xstride_[kGatherV2InputIndexTwo];
  111. if (data_ptr_y_tmp - data_ptr_y < 0) {
  112. GELOGE(PARAM_INVALID, "ptr_y - ptr_y_tmp less than zero");
  113. return PARAM_INVALID;
  114. }
  115. size_t offset_size = (data_ptr_y_tmp - data_ptr_y) * sizeof(T);
  116. auto ret_mem = memcpy_s(reinterpret_cast<void *>(data_ptr_y_tmp), output_size - offset_size,
  117. reinterpret_cast<void *>(data_ptr_x_tmp), size);
  118. if (ret_mem != 0) {
  119. GELOGE(MEMALLOC_FAILED, "memcpy failed!");
  120. return MEMALLOC_FAILED;
  121. }
  122. }
  123. }
  124. }
  125. return ret;
  126. }
  127. template <typename T>
  128. Status GatherV2Kernel::ProcessAxis3(ConstGeTensorPtr tensor_x, GeTensorPtr output) {
  129. Status ret = SUCCESS;
  130. T *data_ptr_x = reinterpret_cast<T *>(const_cast<unsigned char *>(tensor_x->GetData().data()));
  131. T *data_ptr_y = reinterpret_cast<T *>(const_cast<unsigned char *>(output->GetData().data()));
  132. // index is valid, and no bigger than kGatherV2InputIndexThree
  133. size_t output_size = output->GetData().size();
  134. for (int64_t i = 0; i < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexZero); i++) {
  135. T *data_ptr_x_i = data_ptr_x + i * xstride_[kGatherV2InputIndexZero];
  136. T *data_ptr_y_i = data_ptr_y + i * ystride_[kGatherV2InputIndexZero];
  137. for (int64_t j = 0; j < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexOne); j++) {
  138. T *data_ptr_x_j = data_ptr_x_i + j * xstride_[kGatherV2InputIndexOne];
  139. T *data_ptr_y_j = data_ptr_y_i + j * ystride_[kGatherV2InputIndexOne];
  140. for (int64_t m = 0; m < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexTwo); m++) {
  141. T *data_ptr_x_m = data_ptr_x_j + m * xstride_[kGatherV2InputIndexTwo];
  142. T *data_ptr_y_m = data_ptr_y_j + m * ystride_[kGatherV2InputIndexTwo];
  143. for (int64_t n = 0; n < output->GetTensorDesc().GetShape().GetDim(kGatherV2InputIndexThree); n++) {
  144. T *data_ptr_x_tmp = data_ptr_x_m + indicates_[n] * xstride_[kGatherV2InputIndexThree];
  145. T *data_ptr_y_tmp = data_ptr_y_m + n * ystride_[kGatherV2InputIndexThree];
  146. size_t size = sizeof(T) * xstride_[kGatherV2InputIndexThree];
  147. if (data_ptr_y_tmp - data_ptr_y < 0) {
  148. GELOGE(PARAM_INVALID, "ptr_y - ptr_y_tmp less than zero");
  149. return PARAM_INVALID;
  150. }
  151. size_t offset_size = (data_ptr_y_tmp - data_ptr_y) * sizeof(T);
  152. auto ret_mem = memcpy_s(reinterpret_cast<void *>(data_ptr_y_tmp), output_size - offset_size,
  153. reinterpret_cast<void *>(data_ptr_x_tmp), size);
  154. if (ret_mem != 0) {
  155. GELOGE(MEMALLOC_FAILED, "memcpy failed!");
  156. return MEMALLOC_FAILED;
  157. }
  158. }
  159. }
  160. }
  161. }
  162. return ret;
  163. }
// Allocates the output buffer, attaches it to the output tensor, and runs the
// per-axis gather loop.
// @param data_num  number of elements the output tensor holds; must be > 0
// @param tensor_x  params tensor being gathered from
// @param axis      gather axis; only 0..3 are handled, anything else yields NOT_CHANGED
// @param output    tensor whose desc (shape/dtype) was prepared by the caller
// @return SUCCESS, PARAM_INVALID, MEMALLOC_FAILED, INTERNAL_ERROR or NOT_CHANGED
template <typename T>
Status GatherV2Kernel::GenData(const int64_t data_num, ConstGeTensorPtr tensor_x, int64_t axis, GeTensorPtr output) {
  if (data_num <= 0) {
    return PARAM_INVALID;
  }
  // Guard the sizeof(T) * data_num allocation below against int64 overflow.
  if (!CheckInt64MulOverflow(data_num, sizeof(T))) {
    GELOGE(PARAM_INVALID, "Int64MulOverflow, data_num:%ld, type_len:%zu.", data_num, sizeof(T));
    return PARAM_INVALID;
  }
  // Zero-initialized scratch buffer; freed at scope exit. SetData appears to
  // take a copy of the bytes (presumably — TODO confirm against GeTensor API).
  std::unique_ptr<T[]> buf(new (std::nothrow) T[data_num]());
  if (buf == nullptr) {
    GELOGE(MEMALLOC_FAILED, "New sizeof(T) * data_num(%zu) memory failed", static_cast<size_t>(sizeof(T) * data_num));
    return MEMALLOC_FAILED;
  }
  // Bind the buffer to the output tensor before the gather loops write into it.
  GE_IF_BOOL_EXEC(
    output->SetData(reinterpret_cast<uint8_t *>(buf.get()), static_cast<size_t>(data_num * sizeof(T))) != GRAPH_SUCCESS,
    GELOGE(INTERNAL_ERROR, "set data failed");
    return INTERNAL_ERROR);
  Status ret = SUCCESS;
  // Dispatch on the (already-normalized, non-negative) gather axis.
  switch (axis) {
    case 0:
      ret = ProcessAxis0<T>(tensor_x, output);
      break;
    case 1:
      ret = ProcessAxis1<T>(tensor_x, output);
      break;
    case 2:
      ret = ProcessAxis2<T>(tensor_x, output);
      break;
    case 3:
      ret = ProcessAxis3<T>(tensor_x, output);
      break;
    default:
      GELOGI("Only support 4 dims and below but input axis is %ld", axis);
      return NOT_CHANGED;
  }
  return ret;
}
  202. Status GatherV2Kernel::CalcStride(std::vector<int64_t> &stride, std::vector<int64_t> dims) {
  203. if (stride.size() != dims.size() || dims.size() == 0) {
  204. return PARAM_INVALID;
  205. }
  206. int i = static_cast<int>(dims.size() - kGatherV2DimOne);
  207. stride[static_cast<size_t>(i)] = static_cast<int64_t>(kGatherV2DimOne);
  208. i--;
  209. while (i >= 0) {
  210. size_t index = static_cast<size_t>(i) + kGatherV2DimOne;
  211. if (!CheckInt64MulOverflow(stride[index], dims[index])) {
  212. GELOGE(PARAM_INVALID, "Int64MulOverflow, data_num(%ld) type_len(%ld)", stride[index], dims[index]);
  213. return PARAM_INVALID;
  214. }
  215. stride[static_cast<size_t>(i)] = stride[index] * dims[index];
  216. i--;
  217. }
  218. return SUCCESS;
  219. }
  220. Status GatherV2Kernel::Process(int64_t axis, DataType data_type, ConstGeTensorPtr input_tensor_ptr,
  221. GeTensorPtr output_ptr) {
  222. Status ret = SUCCESS;
  223. int64_t data_num = output_ptr->GetTensorDesc().GetShape().GetShapeSize();
  224. switch (data_type) {
  225. case DT_FLOAT16:
  226. ret = GenData<fp16_t>(data_num, input_tensor_ptr, axis, output_ptr);
  227. break;
  228. case DT_DOUBLE:
  229. ret = GenData<double>(data_num, input_tensor_ptr, axis, output_ptr);
  230. break;
  231. case DT_INT8:
  232. ret = GenData<int8_t>(data_num, input_tensor_ptr, axis, output_ptr);
  233. break;
  234. case DT_INT16:
  235. ret = GenData<int16_t>(data_num, input_tensor_ptr, axis, output_ptr);
  236. break;
  237. case DT_INT32:
  238. ret = GenData<int32_t>(data_num, input_tensor_ptr, axis, output_ptr);
  239. break;
  240. case DT_INT64:
  241. ret = GenData<int64_t>(data_num, input_tensor_ptr, axis, output_ptr);
  242. break;
  243. case DT_UINT8:
  244. ret = GenData<uint8_t>(data_num, input_tensor_ptr, axis, output_ptr);
  245. break;
  246. case DT_UINT16:
  247. ret = GenData<uint16_t>(data_num, input_tensor_ptr, axis, output_ptr);
  248. break;
  249. case DT_UINT32:
  250. ret = GenData<uint32_t>(data_num, input_tensor_ptr, axis, output_ptr);
  251. break;
  252. case DT_UINT64:
  253. ret = GenData<uint64_t>(data_num, input_tensor_ptr, axis, output_ptr);
  254. break;
  255. default:
  256. GELOGI("GatherV2Kernel does not support this Data type:%s", TypeUtils::DataTypeToSerialString(data_type).c_str());
  257. return NOT_CHANGED;
  258. }
  259. return ret;
  260. }
  261. Status GatherV2Kernel::SaveIndicesByDataType(ConstGeTensorPtr indices_tensor_ptr, GeShape &x_shape,
  262. GeShape &indices_shape, DataType indices_data_type, size_t axis) {
  263. if (indices_data_type == DT_INT32) {
  264. auto indices_ptr = const_cast<int32_t *>(reinterpret_cast<const int32_t *>(indices_tensor_ptr->GetData().data()));
  265. for (int64_t i = 0; i < indices_shape.GetShapeSize(); i++) {
  266. if (*(indices_ptr + i) < 0 || *(indices_ptr + i) >= x_shape.GetDim(axis)) {
  267. GELOGW("indices %ld value is not in range [0, %ld)", i, x_shape.GetDim(axis));
  268. return NOT_CHANGED;
  269. }
  270. indicates_.push_back(*(indices_ptr + i));
  271. }
  272. } else {
  273. // int64
  274. auto indices_ptr = const_cast<int64_t *>(reinterpret_cast<const int64_t *>(indices_tensor_ptr->GetData().data()));
  275. for (int64_t i = 0; i < indices_shape.GetShapeSize(); i++) {
  276. if (*(indices_ptr + i) < 0 || *(indices_ptr + i) >= x_shape.GetDim(axis)) {
  277. GELOGW("indices %ld value is not in range [0, %ld)", i, x_shape.GetDim(axis));
  278. return NOT_CHANGED;
  279. }
  280. indicates_.push_back(*(indices_ptr + i));
  281. }
  282. }
  283. return SUCCESS;
  284. }
  285. Status GatherV2Kernel::Check(const OpDescPtr &op_desc_ptr, const vector<ConstGeTensorPtr> &input,
  286. vector<GeTensorPtr> &v_output) const {
  287. if (op_desc_ptr == nullptr) {
  288. GELOGW("input opdesc is nullptr.");
  289. return NOT_CHANGED;
  290. }
  291. if (input.size() != kGatherV2InpotNum) {
  292. GELOGW("The number of input for GatherV2 must be %zu.", kGatherV2InpotNum);
  293. return NOT_CHANGED;
  294. }
  295. bool is_null = (input[kGatherV2InputIndexZero] == nullptr || input[kGatherV2InputIndexOne] == nullptr ||
  296. input[kGatherV2InputIndexTwo] == nullptr);
  297. if (is_null) {
  298. GELOGW("some input is nullptr.");
  299. return NOT_CHANGED;
  300. }
  301. ConstGeTensorPtr tensor0 = input.at(kGatherV2InputIndexZero);
  302. ConstGeTensorPtr tensor1 = input.at(kGatherV2InputIndexOne);
  303. ConstGeTensorPtr tensor2 = input.at(kGatherV2InputIndexTwo);
  304. bool size_is_zero =
  305. ((tensor0->GetData().size() == 0) || (tensor1->GetData().size() == 0) || (tensor2->GetData().size() == 0));
  306. if (size_is_zero) {
  307. GELOGW("some input size is zero.");
  308. return NOT_CHANGED;
  309. }
  310. auto indices_shape = tensor1->GetTensorDesc().GetShape();
  311. auto axis_shape = tensor2->GetTensorDesc().GetShape();
  312. // axis must be scalar
  313. if (axis_shape.GetDimNum() != 0) {
  314. GELOGW("axis must be scalar but its shape is %zu", axis_shape.GetDimNum());
  315. return NOT_CHANGED;
  316. }
  317. auto axis_data_type = tensor2->GetTensorDesc().GetDataType();
  318. bool is_valid_axis_data_type = axis_data_type == DT_INT32 || axis_data_type == DT_INT64;
  319. if (!is_valid_axis_data_type) {
  320. GELOGW("axis datatype must be DT_INT32 or DT_INT64");
  321. return NOT_CHANGED;
  322. }
  323. // check indices data_type && dims && every element
  324. auto indices_data_type = tensor1->GetTensorDesc().GetDataType();
  325. bool is_valid_indices_data_type = indices_data_type == DT_INT32 || indices_data_type == DT_INT64;
  326. if (!is_valid_indices_data_type) {
  327. GELOGW("indices datatype must be DT_INT32 or DT_INT64");
  328. return NOT_CHANGED;
  329. }
  330. if (indices_shape.GetDimNum() > kMaxIndicatesDims) {
  331. GELOGW("indices input only support 0 or 1 dims");
  332. return NOT_CHANGED;
  333. }
  334. return SUCCESS;
  335. }
  336. void GatherV2Kernel::DebugPrint(int64_t axis, const GeShape &x_shape, const GeShape &indices_shape,
  337. const std::vector<int64_t> &y_shape) {
  338. GELOGD("GatherV2Kernel axis:%ld x_shape:%zu indices_shape:%zu y_shape:%zu", axis, x_shape.GetDimNum(),
  339. indices_shape.GetDimNum(), y_shape.size());
  340. for (size_t i = 0; i < x_shape.GetDimNum(); i++) {
  341. GELOGD("GatherV2Kernel x_shape[%zu]: %ld", i, x_shape.GetDim(i));
  342. }
  343. for (size_t i = 0; i < indices_shape.GetDimNum(); i++) {
  344. GELOGD("GatherV2Kernel indices_shape[%zu]: %ld", i, indices_shape.GetDim(i));
  345. }
  346. for (size_t i = 0; i < y_shape.size(); i++) {
  347. GELOGD("GatherV2Kernel y_shape[%zu]: %ld", i, y_shape[i]);
  348. }
  349. for (auto ele : indicates_) {
  350. GELOGD("GatherV2Kernel indices:%ld", ele);
  351. }
  352. }
// Entry point of the constant-folding kernel. Validates inputs, normalizes the
// axis, validates/caches indices, computes the output shape and strides, runs
// the typed gather, and appends the folded tensor to v_output.
// Returns NOT_CHANGED whenever the node cannot be folded (so the graph is left
// untouched) and an error Status only on internal failures.
Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector<ConstGeTensorPtr> &input,
                               vector<GeTensorPtr> &v_output) {
  GELOGI("Enter GatherV2Kernel Process.");
  Status ret = Check(op_desc_ptr, input, v_output);
  if (ret != SUCCESS) {
    GELOGW("param check failed.");
    return NOT_CHANGED;
  }
  GELOGI("GatherV2Kernel[%s] start Process.", op_desc_ptr->GetName().c_str());
  // tensor0 = x (params), tensor1 = indices, tensor2 = axis (scalar).
  ConstGeTensorPtr tensor0 = input.at(kGatherV2InputIndexZero);
  ConstGeTensorPtr tensor1 = input.at(kGatherV2InputIndexOne);
  ConstGeTensorPtr tensor2 = input.at(kGatherV2InputIndexTwo);
  auto x_shape = tensor0->GetTensorDesc().GetShape();
  auto indices_shape = tensor1->GetTensorDesc().GetShape();
  auto axis_data_type = tensor2->GetTensorDesc().GetDataType();
  // Read the scalar axis from raw tensor bytes; Check() guarantees int32/int64.
  int64_t axis = axis_data_type == DT_INT32
                     ? *(const_cast<int32_t *>(reinterpret_cast<const int32_t *>(tensor2->GetData().data())))
                     : *(const_cast<int64_t *>(reinterpret_cast<const int64_t *>(tensor2->GetData().data())));
  // Normalize a negative axis into [0, rank).
  axis = axis >= 0 ? axis : axis + x_shape.GetDimNum();
  // check axis value
  if (axis < 0 || (axis + 1) > static_cast<int64_t>(x_shape.GetDimNum())) {
    GELOGW("axis is invalid");
    return NOT_CHANGED;
  }
  auto indices_data_type = tensor1->GetTensorDesc().GetDataType();
  // Range-check every index and cache the values in indicates_ for the gather loops.
  ret = SaveIndicesByDataType(tensor1, x_shape, indices_shape, indices_data_type, static_cast<size_t>(axis));
  if (ret != SUCCESS) {
    GELOGW("Save indeices by data type failed!");
    return ret;
  }
  // check input data type
  auto x_data_type = tensor0->GetTensorDesc().GetDataType();
  if (supported_type.find(x_data_type) == supported_type.end()) {
    GELOGI("GatherV2Kernel does not support this Data type:%s", TypeUtils::DataTypeToSerialString(x_data_type).c_str());
    return NOT_CHANGED;
  }
  // calc output shape: x dims before axis ++ indices dims ++ x dims after axis
  std::vector<int64_t> y_shape;
  for (size_t i = 0; i < static_cast<size_t>(axis); i++) {
    y_shape.push_back(x_shape.GetDim(i));
  }
  for (size_t i = 0; i < indices_shape.GetDimNum(); i++) {
    y_shape.push_back(indices_shape.GetDim(i));
  }
  for (size_t i = static_cast<size_t>(axis) + 1; i < x_shape.GetDimNum(); i++) {
    y_shape.push_back(x_shape.GetDim(i));
  }
  GeTensorPtr output_ptr = MakeShared<GeTensor>(op_desc_ptr->GetOutputDesc(0));
  if (output_ptr == nullptr) {
    GELOGW("make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str());
    return NOT_CHANGED;
  }
  output_ptr->MutableTensorDesc().SetShape(GeShape(y_shape));
  output_ptr->MutableTensorDesc().SetDataType(x_data_type);
  // added for debug
  DebugPrint(axis, x_shape, indices_shape, y_shape);
  // calc stride: the ProcessAxisN loops read xstride_/ystride_ members.
  std::vector<int64_t> xstride(x_shape.GetDimNum());
  std::vector<int64_t> ystride(y_shape.size());
  xstride_ = xstride;
  ystride_ = ystride;
  auto ret_x = CalcStride(xstride_, x_shape.GetDims());
  auto ret_y = CalcStride(ystride_, y_shape);
  ret = (ret_x == SUCCESS && ret_y == SUCCESS) ? SUCCESS : NOT_CHANGED;
  if (ret != SUCCESS) {
    GELOGE(ret, "CalcStride Failed");
    return ret;
  }
  // Run the typed gather into output_ptr's buffer.
  ret = Process(axis, x_data_type, tensor0, output_ptr);
  if (ret != SUCCESS) {
    GELOGE(ret, "GenData failed, data_type: %s", TypeUtils::DataTypeToSerialString(x_data_type).c_str());
    return ret;
  }
  GELOGI("GatherV2Kernel Process Success.");
  v_output.push_back(output_ptr);
  return SUCCESS;
}
// Register this kernel with the factory so constant folding can look it up by the GATHERV2 op type.
REGISTER_KERNEL(GATHERV2, GatherV2Kernel);
  431. } // namespace ge

图引擎模块(GE)是MindSpore的一个子模块,其代码由C++实现,位于前端模块ME和底层硬件之间,起到承接作用。图引擎模块以ME下发的图作为输入,然后进行一系列的深度图优化操作,最后输出一张可以在底层硬件上高效运行的图。GE针对昇腾AI处理器的硬件结构特点,做了特定的优化工作,以此来充分发挥出昇腾AI处理器的强大算力。在进行模型训练/推理时,GE会被自动调用而用户并不感知。GE主要由GE API和GE Core两部分组成,详细的架构图如下所示