
tensor_utils.cc 14 kB

/**
 * Copyright 2019-2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "graph/utils/tensor_utils.h"

#include <cmath>
#include <cstdlib>
#include <string>

#include "debug/ge_log.h"
#include "framework/common/debug/ge_log.h"
#include "graph/ge_tensor.h"
#include "graph/types.h"
#include "graph/utils/type_utils.h"

namespace ge {
namespace {
// When the nc1hwc0 dim size is 5, calculate the element count directly.
const uint32_t kNc1hwc0CalcByDimsSize = 5;
// Element count of an unknown shape
const int64_t kElementCntUnknownShape = -1;
// Memory size of an unknown shape
const int64_t kMemSizeUnknownShape = -1;
// NCHW and NHWC dim size must be 4
const uint32_t kDimSize4d = 4;
// C1HWNCoC0 dim size must be 6
const uint32_t kDimSizeC1hwncoc0 = 6;
// Cube size is 16
const uint32_t kTheCubeSize = 16;
// Default c0 size equals the cube size.
const uint32_t kC0SizeDefault = kTheCubeSize;
// C0 size for int8-sized types is 32
const uint32_t kC0SizeInt8 = 32;
// NCHW dim N index
const int32_t kNchwDimIdxN = 0;
// NCHW dim C index
const int32_t kNchwDimIdxC = 1;
// NCHW dim H index
const int32_t kNchwDimIdxH = 2;
// NCHW dim W index
const int32_t kNchwDimIdxW = 3;
const int kDataMemAlignSize = 32;
const int kNum2 = 2;
}  // namespace

///
/// Check whether a * b overflows.
/// @param a multiplier
/// @param b multiplicand
/// @return true: overflow
///         false: no overflow
///
static bool CheckMultiplyOverflowInt64(const int64_t &a, const int64_t &b) {
  if (a > 0) {
    if (b > 0) {
      if (a > (INT64_MAX / b)) {
        return true;
      }
    } else {
      if (b < (INT64_MIN / a)) {
        return true;
      }
    }
  } else {
    if (b > 0) {
      if (a < (INT64_MIN / b)) {
        return true;
      }
    } else {
      if ((a != 0) && (b < (INT64_MAX / a))) {
        return true;
      }
    }
  }
  return false;
}
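// Example for CheckMultiplyOverflowInt64: with a = 2^32 and b = 2^32, the first
// branch detects a > INT64_MAX / b (2^32 > 2^31 - 1), so the product would
// overflow and the function returns true before any wraparound can happen.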
///
/// Calculate the element count from the dims directly.
/// @param dims dim info
/// @param element_cnt element count
/// @return GRAPH_SUCCESS: success
///         other: failed
///
static graphStatus CalcElementCntByDims(const std::vector<int64_t> &dims, int64_t &element_cnt) {
  element_cnt = 1;
  for (int64_t dim : dims) {
    if (CheckMultiplyOverflowInt64(element_cnt, dim)) {
      GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed: overflow when multiplying %ld and %ld.", element_cnt, dim);
      return GRAPH_FAILED;
    }
    element_cnt *= dim;
  }
  return GRAPH_SUCCESS;
}

///
/// Calculate the element count for fixed-size dims.
/// @param dims dim info
/// @param fixed_dim_size fixed dim size
/// @param element_cnt element count
/// @return GRAPH_SUCCESS: success
///         other: failed
///
static graphStatus CalcElementCntOfFixedDims(const std::vector<int64_t> &dims, Format format, uint32_t fixed_dim_size,
                                             int64_t &element_cnt) {
  if (dims.size() != fixed_dim_size) {
    GELOGW("Format %d(%s) needs dim size=%u but got %zu, calculating as ND.", format,
           TypeUtils::FormatToSerialString(format).c_str(), fixed_dim_size, dims.size());
  }
  return CalcElementCntByDims(dims, element_cnt);
}

///
/// Get the dim c0 size by data type.
/// @param data_type data type
/// @return c0 size
///
static uint32_t GetDimC0(const DataType &data_type) {
  bool is_int8_size = (data_type == DT_INT8) || (data_type == DT_UINT8) || (data_type == DT_DUAL_SUB_UINT8) ||
                      (data_type == DT_DUAL_SUB_INT8) || (data_type == DT_BOOL) || (data_type == DT_QINT8);
  return is_int8_size ? kC0SizeInt8 : kC0SizeDefault;
}

///
/// Calculate the nc1hwc0 element count.
/// @param dims dim info
/// @param data_type data type
/// @param element_cnt element count
/// @return GRAPH_SUCCESS: success
///         other: failed
///
static graphStatus CalcElementCntOfNc1hwc0(const std::vector<int64_t> &dims, DataType data_type, int64_t &element_cnt) {
  // When the nc1hwc0 dim size is 5, there is no need to split dim c.
  if (dims.size() == kNc1hwc0CalcByDimsSize) {
    return CalcElementCntByDims(dims, element_cnt);
  } else if (dims.size() != kDimSize4d) {
    GELOGE(GRAPH_FAILED, "CalcElementCntOfNc1hwc0 failed as dims.size=%zu is not %u or %u.", dims.size(), kDimSize4d,
           kNc1hwc0CalcByDimsSize);
    return GRAPH_FAILED;
  }

  auto c0 = static_cast<int64_t>(GetDimC0(data_type));
  // Nc1hwc0 dims are given in NCHW order; the dim c index is 1.
  auto c1 = static_cast<int64_t>(std::ceil(dims[kNchwDimIdxC] * 1.0 / c0));
  // The storage dims split c into c1 and c0.
  std::vector<int64_t> store_dims = {dims[kNchwDimIdxN], c1, dims[kNchwDimIdxH], dims[kNchwDimIdxW], c0};
  return CalcElementCntByDims(store_dims, element_cnt);
}
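// Example for CalcElementCntOfNc1hwc0: 4-D dims {2, 3, 5, 5} with DT_FLOAT16
// give c0 = 16 and c1 = ceil(3 / 16) = 1, so the storage dims become
// {2, 1, 5, 5, 16} and the element count is 2 * 1 * 5 * 5 * 16 = 800.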
///
/// Calculate the FractalZ element count.
/// @param dims dim info
/// @param data_type data type
/// @param element_cnt element count
/// @return GRAPH_SUCCESS: success
///         other: failed
///
static graphStatus CalcElementCntOfFractalZ(const std::vector<int64_t> &dims, DataType data_type,
                                            int64_t &element_cnt) {
  static char *parser_priority = std::getenv("PARSER_PRIORITY");
  if (parser_priority != nullptr && string(parser_priority) == "cce") {
    if (dims.size() != kDimSize4d) {
      GELOGE(GRAPH_FAILED, "CalcElementCntOfFractalZ failed as dims.size=%zu is not %u.", dims.size(), kDimSize4d);
      return GRAPH_FAILED;
    }
    auto c0 = static_cast<int64_t>(GetDimC0(data_type));
    // FractalZ dims are given in NCHW order; the dim c index is 1.
    auto c1 = static_cast<int64_t>(std::ceil(dims[kNchwDimIdxC] * 1.0 / c0));
    // Spread NC1HWC0 as a two-dimensional array: n is the column dimension
    // and C1HWC0 is the row dimension.
    std::vector<int64_t> r_count_vec = {c1, dims[kNchwDimIdxH], dims[kNchwDimIdxW], c0};

    int64_t r_count = 1;
    graphStatus graph_status = CalcElementCntByDims(r_count_vec, r_count);
    if (graph_status != GRAPH_SUCCESS) {
      GELOGE(graph_status, "Calc [%ld, %ld, %ld, %ld] element count failed.", c1, dims[kNchwDimIdxH],
             dims[kNchwDimIdxW], c0);
      return graph_status;
    }

    // Cube count in n
    auto nc_cnt = static_cast<int64_t>(std::ceil(dims[kNchwDimIdxN] * 1.0 / kTheCubeSize));
    // Cube count in the vertical direction (C1HWC0)
    int64_t vc_cnt = r_count / c0;
    // Element count in each cube
    int64_t cube_elem_cnt = c0 * kTheCubeSize;

    if (CheckMultiplyOverflowInt64(nc_cnt, vc_cnt)) {
      GELOGE(GRAPH_FAILED, "The multiplication of %ld and %ld overflows.", nc_cnt, vc_cnt);
      return GRAPH_FAILED;
    }
    // Number of cube reads needed
    int64_t c_cnt = nc_cnt * vc_cnt;
    if (CheckMultiplyOverflowInt64(c_cnt, cube_elem_cnt)) {
      GELOGE(GRAPH_FAILED, "The multiplication of %ld and %ld overflows.", c_cnt, cube_elem_cnt);
      return GRAPH_FAILED;
    }
    // Element count after fractal arrangement
    element_cnt = c_cnt * cube_elem_cnt;
    return GRAPH_SUCCESS;
  } else {
    return CalcElementCntByDims(dims, element_cnt);
  }
}
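// Example for CalcElementCntOfFractalZ with PARSER_PRIORITY=cce, dims
// {32, 3, 5, 5} and DT_FLOAT16: c0 = 16, c1 = 1, r_count = 1 * 5 * 5 * 16 = 400,
// nc_cnt = ceil(32 / 16) = 2, vc_cnt = 400 / 16 = 25, cube_elem_cnt = 256,
// so element_cnt = 2 * 25 * 256 = 12800.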
///
/// Calculate the tensor element count.
/// @param dims dim info
/// @param format tensor format
/// @param data_type data type
/// @param element_cnt element count
/// @return GRAPH_SUCCESS: success
///         other: failed
///
static graphStatus CalcTensorElementCnt(const std::vector<int64_t> &dims, Format format, DataType data_type,
                                        int64_t &element_cnt) {
  const string format_str = TypeUtils::FormatToSerialString(format);
  // Check dims
  for (size_t i = 0; i < dims.size(); ++i) {
    int64_t dim = dims[i];
    if (dim < 0) {
      GELOGI("It's an unknown shape, as dims[%zu]=%ld is negative, format=%d(%s).", i, dim, format, format_str.c_str());
      element_cnt = kElementCntUnknownShape;
      return GRAPH_SUCCESS;
    } else if (dim == 0) {
      GELOGI("No need to calc the element count, as dims[%zu]=%ld, format=%d(%s).", i, dim, format, format_str.c_str());
      element_cnt = 0;
      return GRAPH_SUCCESS;
    }
  }

  graphStatus graph_status;
  switch (format) {
    case FORMAT_ND:
    case FORMAT_MD:
      graph_status = CalcElementCntByDims(dims, element_cnt);
      break;
    case FORMAT_NCHW:
    case FORMAT_HWCN:
    case FORMAT_NHWC:
    case FORMAT_CHWN:
      graph_status = CalcElementCntOfFixedDims(dims, format, kDimSize4d, element_cnt);
      break;
    case FORMAT_C1HWNCoC0:
      graph_status = CalcElementCntOfFixedDims(dims, format, kDimSizeC1hwncoc0, element_cnt);
      break;
    case FORMAT_NC1HWC0:
      graph_status = CalcElementCntOfNc1hwc0(dims, data_type, element_cnt);
      break;
    case FORMAT_FRACTAL_Z:
      graph_status = CalcElementCntOfFractalZ(dims, data_type, element_cnt);
      break;
    case FORMAT_NC1HWC0_C04:
    case FORMAT_FRACTAL_NZ:
    case FORMAT_FRACTAL_ZZ:
    case FORMAT_NDHWC:
    case FORMAT_NCDHW:
    case FORMAT_DHWCN:
    case FORMAT_DHWNC:
    case FORMAT_FRACTAL_Z_3D:
    case FORMAT_FRACTAL_Z_3D_TRANSPOSE:
    case FORMAT_NDC1HWC0:
    case FORMAT_FRACTAL_Z_C04:
    case FORMAT_FRACTAL_ZN_LSTM:
      graph_status = CalcElementCntByDims(dims, element_cnt);
      break;
    default:
      GELOGE(GRAPH_FAILED, "Unsupported format, format=%d(%s).", format, format_str.c_str());
      graph_status = GRAPH_FAILED;
      break;
  }

  const string type_str = TypeUtils::DataTypeToSerialString(data_type);
  if (graph_status == GRAPH_SUCCESS) {
    GELOGD(
        "CalcTensorElementCnt end, format=%d(%s),"
        " data_type=%d(%s), element_cnt=%ld.",
        format, format_str.c_str(), data_type, type_str.c_str(), element_cnt);
  } else {
    GELOGE(GRAPH_FAILED, "CalcTensorElementCnt failed, format=%d(%s), data_type=%d(%s).", format, format_str.c_str(),
           data_type, type_str.c_str());
  }
  return graph_status;
}
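// Example for CalcTensorElementCnt: FORMAT_NHWC dims {1, 224, 224, 3} are
// dispatched to CalcElementCntOfFixedDims and yield 1 * 224 * 224 * 3 = 150528
// elements, while any negative dim short-circuits to kElementCntUnknownShape.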
///
/// Calculate the tensor memory size.
/// @param shape tensor shape
/// @param format tensor format
/// @param data_type tensor data type
/// @param mem_size -1 means unknown shape; any other value is the mem size
/// @return GRAPH_SUCCESS: success, other: failed
///
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus TensorUtils::CalcTensorMemSize(const GeShape &shape,
                                                                                          Format format,
                                                                                          DataType data_type,
                                                                                          int64_t &mem_size) {
  const string format_str = TypeUtils::FormatToSerialString(format);
  const string type_str = TypeUtils::DataTypeToSerialString(data_type);

  uint32_t type_size = 0;
  bool result = TypeUtils::GetDataTypeLength(data_type, type_size);
  if (!result) {
    GELOGE(GRAPH_FAILED, "GetDataTypeLength failed, data_type=%d(%s).", data_type, type_str.c_str());
    return GRAPH_FAILED;
  }

  std::vector<int64_t> dims = shape.GetDims();
  int64_t element_cnt = 0;
  graphStatus status = CalcTensorElementCnt(dims, format, data_type, element_cnt);
  if (status != GRAPH_SUCCESS) {
    GELOGE(status, "CalcTensorElementCnt failed, status=%u format=%d(%s) data_type=%d(%s).", status, format,
           format_str.c_str(), data_type, type_str.c_str());
    return status;
  }
  // Support unknown shapes
  if (element_cnt < 0) {
    mem_size = kMemSizeUnknownShape;
    GELOGD(
        "element_cnt is unknown. "
        "format=%d(%s), data_type=%d(%s), mem_size=%ld",
        format, format_str.c_str(), data_type, type_str.c_str(), mem_size);
    return GRAPH_SUCCESS;
  }

  auto type_size_int64 = static_cast<int64_t>(type_size);
  if (CheckMultiplyOverflowInt64(element_cnt, type_size_int64)) {
    GELOGE(GRAPH_FAILED, "CalcTensorMemSize overflow when multiplying %ld and %ld, format=%d(%s), data_type=%d(%s).",
           element_cnt, type_size_int64, format, format_str.c_str(), data_type, type_str.c_str());
    return GRAPH_FAILED;
  }
  mem_size = element_cnt * type_size_int64;
  GELOGD(
      "CalcTensorMemSize end, "
      "format=%d(%s), data_type=%d(%s), mem_size=%ld",
      format, format_str.c_str(), data_type, type_str.c_str(), mem_size);
  return GRAPH_SUCCESS;
}
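// Example for CalcTensorMemSize: shape {2, 3, 5, 5} in FORMAT_NCHW with
// DT_FLOAT16 gives element_cnt = 150 and type_size = 2, so mem_size = 300 bytes.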
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
TensorUtils::GetTensorMemorySizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) {
  graphStatus graph_status = GetTensorSizeInBytes(desc_temp, size_temp);
  if (graph_status != GRAPH_SUCCESS) {
    return GRAPH_FAILED;
  }
  // Add 32 bytes of padding, then align up to a multiple of 32 bytes;
  // a size of 0 becomes 32 bytes.
  if (size_temp > (INT64_MAX - kNum2 * kDataMemAlignSize)) {
    GELOGW("The updated mem size %ld is bigger than INT64_MAX", size_temp);
  } else {
    size_temp = ((size_temp + kNum2 * kDataMemAlignSize - 1) / kDataMemAlignSize) * kDataMemAlignSize;
  }
  return GRAPH_SUCCESS;
}
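// Example for GetTensorMemorySizeInBytes: size_temp = 300 becomes
// ((300 + 2 * 32 - 1) / 32) * 32 = (363 / 32) * 32 = 11 * 32 = 352 bytes.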
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
TensorUtils::GetTensorSizeInBytes(const GeTensorDesc &desc_temp, int64_t &size_temp) {
  GeShape output_shape = desc_temp.GetShape();
  Format format = desc_temp.GetFormat();
  DataType data_type = desc_temp.GetDataType();
  int64_t output_mem_size = 0;
  graphStatus graph_status = CalcTensorMemSize(output_shape, format, data_type, output_mem_size);
  if (graph_status != GRAPH_SUCCESS) {
    GELOGE(GRAPH_FAILED, "CalcTensorMemSize failed!");
    return GRAPH_FAILED;
  }
  if (output_mem_size < 0) {
    GELOGE(GRAPH_FAILED, "After calculating the tensor memory size, output_mem_size = %ld is out of the data range [0, %ld].",
           output_mem_size, INT64_MAX);
    return GRAPH_FAILED;
  }
  size_temp = output_mem_size;
  return GRAPH_SUCCESS;
}
}  // namespace ge
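For context, here is a minimal usage sketch of the API defined above. It assumes the GraphEngine headers are on the include path; the shape, format, and data type are illustrative values, and the GeTensorDesc(shape, format, data_type) constructor is assumed from graph/ge_tensor.h:

#include <cstdio>
#include "graph/ge_tensor.h"
#include "graph/utils/tensor_utils.h"

int main() {
  // Illustrative tensor: NCHW {2, 3, 5, 5}, float16.
  ge::GeTensorDesc desc(ge::GeShape({2, 3, 5, 5}), ge::FORMAT_NCHW, ge::DT_FLOAT16);
  int64_t size = 0;
  // 150 elements * 2 bytes = 300 bytes, padded and 32-byte aligned to 352.
  if (ge::TensorUtils::GetTensorMemorySizeInBytes(desc, size) == ge::GRAPH_SUCCESS) {
    std::printf("aligned tensor size: %lld bytes\n", static_cast<long long>(size));
  }
  return 0;
}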

The Graph Engine (GE) module is a submodule of MindSpore, implemented in C++. It sits between the front-end module (ME) and the underlying hardware, acting as a bridge between them. GE takes the graph delivered by ME as input, performs a series of deep graph-optimization operations, and finally outputs a graph that can run efficiently on the underlying hardware. GE applies optimizations tailored to the hardware architecture of the Ascend AI processor in order to fully exploit its computing power. During model training and inference, GE is invoked automatically and is not visible to the user. GE consists mainly of two parts, GE API and GE Core; the detailed architecture diagram is shown below.