You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

cublasLt_wrapper.cpp 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335
  1. /**
  2. * \file dnn/src/cuda/matrix_mul/cublasLt_wrapper.cpp
  3. * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  4. *
  5. * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  6. *
  7. * Unless required by applicable law or agreed to in writing,
  8. * software distributed under the License is distributed on an
9. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
  10. * implied.
  11. */
  12. #include "src/common/utils.h"
  13. #include "src/cuda/matrix_mul/cublasLt_wrapper.h"
  14. #include "src/cuda/utils.h"
  15. #if CUDA_VERSION >= 10010
  16. namespace megdnn {
  17. namespace cuda {
  18. static cudaDataType_t to_cuda_dtype(DType tp) {
  19. switch (tp.enumv()) {
  20. case DTypeEnum::Float16:
  21. return CUDA_R_16F;
  22. case DTypeEnum::Float32:
  23. return CUDA_R_32F;
  24. case DTypeEnum::Int8:
  25. case DTypeEnum::QuantizedS8:
  26. return CUDA_R_8I;
  27. case DTypeEnum::Int32:
  28. case DTypeEnum::QuantizedS32:
  29. return CUDA_R_32I;
  30. default:
  31. megdnn_throw("dtype must be float16/float32/int8/qs8/int32");
  32. }
  33. }
  34. #if CUDA_VERSION >= 11000
  35. static cublasComputeType_t to_cublas_compute_type(DType tp) {
  36. switch (tp.enumv()) {
  37. case DTypeEnum::Float16:
  38. return CUBLAS_COMPUTE_16F;
  39. case DTypeEnum::Float32:
  40. return CUBLAS_COMPUTE_32F;
  41. case DTypeEnum::Int32:
  42. case DTypeEnum::QuantizedS32:
  43. return CUBLAS_COMPUTE_32I;
  44. default:
  45. megdnn_throw("dtype must be float16/float32/int32/Qs32");
  46. }
  47. }
  48. #endif
  49. static const char* cuda_type_to_str(cudaDataType_t tp) {
  50. switch (tp) {
  51. case CUDA_R_16F:
  52. return "CUDA_R_16F";
  53. case CUDA_R_32F:
  54. return "CUDA_R_32F";
  55. case CUDA_R_8I:
  56. return "CUDA_R_8I";
  57. case CUDA_R_32I:
  58. return "CUDA_R_32I";
  59. default:
  60. megdnn_throw("dtype must be float16/float32/int8/int32");
  61. }
  62. }
  63. static size_t cuda_dtype_size(cudaDataType_t dt) {
  64. switch (dt) {
  65. case CUDA_R_8I:
  66. return 1_z;
  67. case CUDA_R_16F:
  68. return 2_z;
  69. case CUDA_R_32F:
  70. case CUDA_R_32I:
  71. return 4_z;
  72. default:
  73. megdnn_throw("dtype must be float16/float32/int8/int32");
  74. }
  75. }
  76. CUBLASLTMatmulDesc::~CUBLASLTMatmulDesc() {
  77. if (matmul_desc)
  78. cublas_check(cublasLtMatmulDescDestroy(matmul_desc));
  79. if (layout_a)
  80. cublas_check(cublasLtMatrixLayoutDestroy(layout_a));
  81. if (layout_b)
  82. cublas_check(cublasLtMatrixLayoutDestroy(layout_b));
  83. if (layout_c)
  84. cublas_check(cublasLtMatrixLayoutDestroy(layout_c));
  85. if (layout_trans_a)
  86. cublas_check(cublasLtMatrixLayoutDestroy(layout_trans_a));
  87. if (layout_trans_b)
  88. cublas_check(cublasLtMatrixLayoutDestroy(layout_trans_b));
  89. if (layout_trans_c)
  90. cublas_check(cublasLtMatrixLayoutDestroy(layout_trans_c));
  91. }
/**
 * \brief Build the cublasLt matmul descriptor and matrix layouts for the
 * (optionally batched) matrix multiplication described by \p args.
 *
 * Fills matmul_desc and layout_{a,b,c}; on the int8->int32 (IMMA) path it
 * additionally creates layout_trans_{a,b,c} and records the byte sizes
 * workspace_{a,b,c} of the scratch buffers holding the transformed copies.
 *
 * \param args problem description (tensor layouts, transpose flags).
 * \param batched true when the layouts carry a leading batch dimension.
 */
void CUBLASLTMatmulDesc::set(const SizeArgs& args, bool batched) {
    cublasOperation_t trans_a, trans_b;
    // m, n are read from C's shape; with a leading batch dim the matrix
    // dimensions shift by one. k is read from A's shape, honoring transposeA.
    auto m = args.layout_c.shape[batched ? 1 : 0],
         n = args.layout_c.shape[batched ? 2 : 1];
    auto k = batched ? args.layout_a.shape[args.transposeA ? 1 : 2]
                     : args.layout_a.shape[args.transposeA ? 0 : 1];
    int batch = (batched ? args.layout_a.shape[0] : 1);
    // alpha/beta will be read from device memory at launch time.
    uint32_t pm = CUBLAS_POINTER_MODE_DEVICE;
    dt_b = to_cuda_dtype(args.layout_b.dtype);
    dt_a = to_cuda_dtype(args.layout_a.dtype);
    dt_c = to_cuda_dtype(args.layout_c.dtype);
    megdnn_assert(dt_a == dt_b, "matrix A and B should have same precision");
#if CUDA_VERSION >= 11000
    dt_compute = to_cublas_compute_type(args.layout_c.dtype);
    cublas_check(cublasLtMatmulDescCreate(&matmul_desc, dt_compute, dt_c));
#else
    // Before CUDA 11 the compute type is a cudaDataType_t; use C's type.
    dt_compute = dt_c;
    cublas_check(cublasLtMatmulDescCreate(&matmul_desc, dt_compute));
#endif
    cublas_check(cublasLtMatmulDescSetAttribute(
            matmul_desc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pm, sizeof(pm)));
    cublasLtOrder_t order_COL32 = CUBLASLT_ORDER_COL32;
    cublasLtOrder_t order_COL4_4R2_8C = CUBLASLT_ORDER_COL4_4R2_8C;
    /**
     * \NOTE that cublas takes column-major matrices as inputs,
     * but megdnn takes row-major ones.
     * So we calculate C^t = B^t * A^t by cublas. Here the transpose symbol
     * implies row-major to column-major conversion
     */
    if (dt_c == CUDA_R_32I) {
        /**
         * \NOTE: To use IMMA kernels, use computeType = CUDA_R_32I and
         * CUBLASLT_ORDER_COL32 for matrices A,C,D and
         * CUBLASLT_ORDER_COL4_4R2_8C for matrix B.
         */
        int ldbtransform, ldatransform, ldctransform;
        size_t stride_b_trans, stride_a_trans, stride_c_trans;
        // Leading dimensions of the transformed layouts: COL32 packs 32
        // columns per tile; COL4_4R2_8C needs rows rounded up to 8.
        ldbtransform = 32 * n;
        ldatransform = 32 * round_up<int32_t>(m, 8);
        ldctransform = 32 * n;
        // Per-matrix strides (in elements) of the transformed buffers.
        stride_b_trans = round_up<int32_t>(k, 32) / 32 * ldbtransform;
        stride_a_trans = round_up<int32_t>(k, 32) / 32 * ldatransform;
        stride_c_trans = round_up<int32_t>(m, 32) / 32 * ldctransform;
        // In the C^t = B^t * A^t formulation, megdnn's B is cublas' first
        // operand; the IMMA recipe consumes the second operand transposed.
        trans_b = CUBLAS_OP_T;
        cublas_check(cublasLtMatmulDescSetAttribute(
                matmul_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_b,
                sizeof(trans_b)));
        // origin layout
        cublas_check(cublasLtMatrixLayoutCreate(
                &layout_b, dt_b, n, k, args.layout_b.stride[batched ? 1 : 0]));
        cublas_check(cublasLtMatrixLayoutCreate(
                &layout_a, dt_a, k, m, args.layout_a.stride[batched ? 1 : 0]));
        cublas_check(cublasLtMatrixLayoutCreate(
                &layout_c, dt_c, n, m, args.layout_c.stride[batched ? 1 : 0]));
        // transformed layout
        cublas_check(cublasLtMatrixLayoutCreate(
                &layout_trans_b, dt_b, n, k, ldbtransform));
        cublas_check(cublasLtMatrixLayoutCreate(
                &layout_trans_a, dt_a, m, k, ldatransform));
        cublas_check(cublasLtMatrixLayoutCreate(
                &layout_trans_c, dt_c, n, m, ldctransform));
        cublas_check(cublasLtMatrixLayoutSetAttribute(
                layout_trans_b, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32,
                sizeof(order_COL32)));
        cublas_check(cublasLtMatrixLayoutSetAttribute(
                layout_trans_a, CUBLASLT_MATRIX_LAYOUT_ORDER,
                &order_COL4_4R2_8C, sizeof(order_COL4_4R2_8C)));
        cublas_check(cublasLtMatrixLayoutSetAttribute(
                layout_trans_c, CUBLASLT_MATRIX_LAYOUT_ORDER, &order_COL32,
                sizeof(order_COL32)));
        if (batched) {
            // Batched IMMA: record batch count and the element stride
            // between consecutive matrices for each transformed layout.
            cublas_check(cublasLtMatrixLayoutSetAttribute(
                    layout_trans_b, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch,
                    sizeof(batch)));
            cublas_check(cublasLtMatrixLayoutSetAttribute(
                    layout_trans_a, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch,
                    sizeof(batch)));
            cublas_check(cublasLtMatrixLayoutSetAttribute(
                    layout_trans_c, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch,
                    sizeof(batch)));
            cublas_check(cublasLtMatrixLayoutSetAttribute(
                    layout_trans_b, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
                    &stride_b_trans, sizeof(stride_b_trans)));
            cublas_check(cublasLtMatrixLayoutSetAttribute(
                    layout_trans_a, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
                    &stride_a_trans, sizeof(stride_a_trans)));
            cublas_check(cublasLtMatrixLayoutSetAttribute(
                    layout_trans_c, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET,
                    &stride_c_trans, sizeof(stride_c_trans)));
        }
        // Scratch sizes (bytes) for the transformed copies of B, A and C.
        workspace_b = batch * cuda_dtype_size(dt_b) * stride_b_trans;
        workspace_a = batch * cuda_dtype_size(dt_a) * stride_a_trans;
        workspace_c = batch * cuda_dtype_size(dt_c) * stride_c_trans;
    } else {
        trans_b = args.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N;
        trans_a = args.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N;
        // Since we compute C^t = B^t * A^t, megdnn's B is cublas' operand A
        // and vice versa — hence TRANSA takes trans_b and TRANSB trans_a.
        cublas_check(cublasLtMatmulDescSetAttribute(
                matmul_desc, CUBLASLT_MATMUL_DESC_TRANSA, &trans_b,
                sizeof(trans_b)));
        cublas_check(cublasLtMatmulDescSetAttribute(
                matmul_desc, CUBLASLT_MATMUL_DESC_TRANSB, &trans_a,
                sizeof(trans_a)));
        cublas_check(cublasLtMatrixLayoutCreate(
                &layout_b, dt_b, trans_b == CUBLAS_OP_N ? n : k,
                trans_b == CUBLAS_OP_N ? k : n,
                args.layout_b.stride[batched ? 1 : 0]));
        cublas_check(cublasLtMatrixLayoutCreate(
                &layout_a, dt_a, trans_a == CUBLAS_OP_N ? k : m,
                trans_a == CUBLAS_OP_N ? m : k,
                args.layout_a.stride[batched ? 1 : 0]));
        cublas_check(cublasLtMatrixLayoutCreate(
                &layout_c, dt_c, n, m, args.layout_c.stride[batched ? 1 : 0]));
    }
    // Batch attributes on the original layouts; stride[0] is the distance
    // between consecutive matrices. For the non-batched case batch == 1,
    // so these settings are effectively no-ops.
    size_t stride_b = args.layout_b.stride[0];
    size_t stride_a = args.layout_a.stride[0];
    size_t stride_c = args.layout_c.stride[0];
    cublas_check(cublasLtMatrixLayoutSetAttribute(
            layout_b, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch,
            sizeof(batch)));
    cublas_check(cublasLtMatrixLayoutSetAttribute(
            layout_a, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch,
            sizeof(batch)));
    cublas_check(cublasLtMatrixLayoutSetAttribute(
            layout_c, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch,
            sizeof(batch)));
    cublas_check(cublasLtMatrixLayoutSetAttribute(
            layout_b, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_b,
            sizeof(stride_b)));
    cublas_check(cublasLtMatrixLayoutSetAttribute(
            layout_a, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_a,
            sizeof(stride_a)));
    cublas_check(cublasLtMatrixLayoutSetAttribute(
            layout_c, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &stride_c,
            sizeof(stride_c)));
}
  227. bool CUBLASLTMatmulDesc::is_available(const SizeArgs& args, size_t ws_limit) {
  228. bool support;
  229. cublasLtMatmulAlgo_t algo;
  230. switch (dt_c) {
  231. case CUDA_R_16F:
  232. support = (dt_a == CUDA_R_16F);
  233. break;
  234. case CUDA_R_32I: {
  235. support = (dt_a == CUDA_R_8I) &&
  236. (!args.transposeA && !args.transposeB);
  237. break;
  238. }
  239. case CUDA_R_32F:
  240. support = (dt_a == CUDA_R_16F || dt_a == CUDA_R_32F);
  241. break;
  242. case CUDA_R_64F: /* not support? */
  243. default:
  244. support = false;
  245. break;
  246. }
  247. support = support && dt_a == dt_b;
  248. support = support && get_algorithm_heuristic(args, ws_limit, algo);
  249. return support;
  250. }
  251. WorkspaceBundle CUBLASLTMatmulDesc::get_workspace_bundle(
  252. const SizeArgs& args, const cublasLtMatmulAlgo_t& algo) {
  253. size_t algo_workspace_size;
  254. auto&& handle = args.handle;
  255. auto&& cublasLt_handle = handle->cublasLt_handle();
  256. cublasStatus_t status;
  257. cublasLtMatmulHeuristicResult_t result{};
  258. status = cublasLtMatmulAlgoCheck(
  259. cublasLt_handle, matmul_desc,
  260. dt_c == CUDA_R_32I ? layout_trans_b : layout_b,
  261. dt_c == CUDA_R_32I ? layout_trans_a : layout_a,
  262. dt_c == CUDA_R_32I ? layout_trans_c : layout_c,
  263. dt_c == CUDA_R_32I ? layout_trans_c : layout_c, &algo, &result);
  264. // return empty WorkspaceBundle if cublasLtMatmulAlgoCheck() failed
  265. if (status != CUBLAS_STATUS_SUCCESS)
  266. return {nullptr, {}};
  267. algo_workspace_size = result.workspaceSize;
  268. return {nullptr,
  269. (dt_c == CUDA_R_32I)
  270. ? SmallVector<size_t>{algo_workspace_size, workspace_b,
  271. workspace_a, workspace_c}
  272. : SmallVector<size_t>{algo_workspace_size}};
  273. }
/**
 * \brief Ask cublasLt for the best matmul algorithm under a workspace limit.
 *
 * \param args problem description (layouts, transpose flags, handle).
 * \param ws_limit total workspace budget in bytes; the budget offered to
 *        cublasLt is reduced by the transform buffers workspace_{a,b,c}.
 * \param[out] algo receives the chosen algorithm on success.
 * \return true iff a usable algorithm was found.
 */
bool CUBLASLTMatmulDesc::get_algorithm_heuristic(const SizeArgs& args,
                                                 size_t ws_limit,
                                                 cublasLtMatmulAlgo_t& algo) {
    bool result;
    int return_algo_count;
    size_t algo_ws_limit;
    cublasStatus_t status;
    cublasLtMatmulPreference_t algo_pref;
    cublasLtMatmulHeuristicResult_t algo_result{};
    auto&& handle = concrete_handle(args.handle);
    auto&& cublasLt_handle = handle->cublasLt_handle();
    // The transform scratch buffers come out of the caller's budget; only
    // the remainder is handed to cublasLt as algo workspace.
    size_t temp = workspace_b + workspace_a + workspace_c;
    algo_ws_limit = (ws_limit > temp) ? (ws_limit - temp) : 0;
    /**
     * \Note: algo_ws_limit must be zero if cublasLtGetVersion() <= 10100
     */
    // algo_ws_limit = 0;
    if (dt_c == CUDA_R_32I) {
        //[FIXME]: cublasLt(Version 10020) produce wrong result when k in
        //[64*n+1 , 64*n+32] for small matrix
        //[TODO]: check if this bug is fixed in latter cublasLt.
        // k's index in A's shape depends on the batch dim and transposeA.
        size_t k_pos = (is_batched ? 1 : 0) + (args.transposeA ? 0 : 1);
        size_t k = args.layout_a.shape[k_pos];
        // Keep k < 65; for larger k reject the buggy range [64*n+1, 64*n+32]
        // (i.e. accept only k with (k-1)/32 odd).
        bool flt = (k < 65 || ((k - 1) / 32) % 2 == 1);
        if (!flt)
            return false;
    }
    result = false;
    cublas_check(cublasLtMatmulPreferenceCreate(&algo_pref));
    cublas_check(cublasLtMatmulPreferenceSetAttribute(
            algo_pref, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &algo_ws_limit,
            sizeof(algo_ws_limit)));
    // On the IMMA path the heuristic is queried with the transformed layouts.
    status = cublasLtMatmulAlgoGetHeuristic(
            cublasLt_handle, matmul_desc,
            dt_c == CUDA_R_32I ? layout_trans_b : layout_b,
            dt_c == CUDA_R_32I ? layout_trans_a : layout_a,
            dt_c == CUDA_R_32I ? layout_trans_c : layout_c,
            dt_c == CUDA_R_32I ? layout_trans_c : layout_c, algo_pref, 1,
            &algo_result, &return_algo_count);
    if (status == CUBLAS_STATUS_SUCCESS && return_algo_count > 0 &&
        // perform cublasLtAlgoCheck() to make sure the algo is correct
        get_workspace_bundle(args, algo_result.algo).nr_workspace() > 0) {
        result = true;
        algo = algo_result.algo;
    }
    cublas_check(cublasLtMatmulPreferenceDestroy(algo_pref));
    return result;
}
  322. } // namespace cuda
  323. } // namespace megdnn
  324. #endif
  325. // vim: syntax=cpp.doxygen

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台