
cublas_lt.cpp

/**
 * \file dnn/src/cuda/batched_matrix_mul/cublas_lt.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */
#include "./algo.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"
#include "src/cuda/matrix_mul/cublasLt_wrapper.h"

using namespace megdnn;
using namespace cuda;

#if CUDA_VERSION >= 10010
//! Translate the operator-level SizeArgs into the cuBLASLt wrapper's SizeArgs.
static inline CUBLASLTMatmulDesc::SizeArgs from_local_size_args(
        const BatchedMatrixMulForwardImpl::AlgoBase::SizeArgs& args) {
    auto&& param = args.opr->param();
    auto&& handle = concrete_handle(args.opr->handle());
    bool transA = param.transposeA;
    bool transB = param.transposeB;
    return {handle, transA, transB, args.layout_a, args.layout_b, args.layout_c};
}
bool BatchedMatrixMulForwardImpl::AlgoCublasLt::is_available(
        const SizeArgs& args) const {
    auto cublasLt_args = from_local_size_args(args);
    auto&& dev_prop = current_device_prop();
    // this algo is only enabled on devices with compute capability >= 7.0
    bool is_dev_support = dev_prop.major >= 7;
    bool res = is_dev_support && CUBLASLTMatmulDesc(cublasLt_args, true)
                                         .is_available(cublasLt_args, INT_MAX);
    return res;
}

size_t BatchedMatrixMulForwardImpl::AlgoCublasLt::get_workspace_in_bytes(
        const SizeArgs& args) const {
    auto cublasLt_args = from_local_size_args(args);
    cublasLtMatmulAlgo_t algo;
    CUBLASLTMatmulDesc desc(cublasLt_args, true);
    desc.get_algorithm_heuristic(cublasLt_args, INT_MAX, algo);
    return desc.get_workspace_bundle(cublasLt_args, algo).total_size_in_bytes();
}
void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec(const ExecArgs& args) const {
    auto cublasLt_args = from_local_size_args(args);
    cublasLtMatmulAlgo_t algo;
    CUBLASLTMatmulDesc desc(cublasLt_args, true);
    desc.get_algorithm_heuristic(cublasLt_args, INT_MAX, algo);
    auto ws_bundle = desc.get_workspace_bundle(cublasLt_args, algo);
    auto&& handle = concrete_handle(args.opr->handle());
    auto&& stream = handle->stream();
    auto&& cublasLt_handle = handle->cublasLt_handle();

    // fp16 batched matmul; B is passed as the first operand and A as the second,
    // matching the layouts created by CUBLASLTMatmulDesc
    auto batched_hgemm = [&]() {
        auto zero_half = handle->zero_device_h();
        auto one_half = handle->one_device_h();
        megdnn_assert(
                ws_bundle.nr_workspace() == 1,
                "workspace bundle size should be 1(ws_algo)");
        cublas_check(cublasLtMatmul(
                cublasLt_handle, desc.matmul_desc, one_half,
                static_cast<const __half*>(args.tensor_b.raw_ptr), desc.layout_b,
                static_cast<const __half*>(args.tensor_a.raw_ptr), desc.layout_a,
                zero_half, static_cast<const __half*>(args.tensor_c.raw_ptr),
                desc.layout_c, static_cast<__half*>(args.tensor_c.raw_ptr),
                desc.layout_c, &algo, ws_bundle.get(0), ws_bundle.get_size(0),
                stream));
    };

    // fp32-accumulated batched matmul; inputs may be fp16 or fp32
    auto batched_sgemm = [&]() {
        auto zero = handle->zero_device();
        auto one = handle->one_device();
        auto dev_b = (desc.dt_b == CUDA_R_16F)
                           ? static_cast<void*>(args.tensor_b.ptr<dt_float16>())
                           : static_cast<void*>(args.tensor_b.ptr<dt_float32>());
        auto dev_a = (desc.dt_a == CUDA_R_16F)
                           ? static_cast<void*>(args.tensor_a.ptr<dt_float16>())
                           : static_cast<void*>(args.tensor_a.ptr<dt_float32>());
        auto dev_c = static_cast<void*>(args.tensor_c.raw_ptr);
        megdnn_assert(
                ws_bundle.nr_workspace() == 1,
                "workspace bundle size should be 1(ws_algo)");
        cublas_check(cublasLtMatmul(
                cublasLt_handle, desc.matmul_desc, one, dev_b, desc.layout_b, dev_a,
                desc.layout_a, zero, dev_c, desc.layout_c, dev_c, desc.layout_c,
                &algo, ws_bundle.get(0), ws_bundle.get_size(0), stream));
    };

    // int32-accumulated batched matmul: transform A/B into the transformed
    // layouts prepared by CUBLASLTMatmulDesc, run the matmul into a workspace,
    // then transform the result back into the output layout
    auto batched_igemm = [&]() {
        auto zero = handle->zero_device();
        auto one = handle->one_device();
        megdnn_assert(
                ws_bundle.nr_workspace() == 4,
                "workspace bundle size should be 4(ws_algo, ws_a, ws_b, ws_c)");
        void* ws_b = ws_bundle.get(1);
        void* ws_a = ws_bundle.get(2);
        void* ws_c = ws_bundle.get(3);
        int32_t pm = CUBLAS_POINTER_MODE_DEVICE;
        cublasOperation_t trans_a = CUBLAS_OP_T, trans_c = CUBLAS_OP_N;
        cublasLtMatrixTransformDesc_t transform_desc = nullptr;
        cublas_check(cublasLtMatrixTransformDescCreate(&transform_desc, CUDA_R_32F));
        cublas_check(cublasLtMatrixTransformDescSetAttribute(
                transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_POINTER_MODE, &pm,
                sizeof(pm)));
        cublas_check(cublasLtMatrixTransform(
                cublasLt_handle, transform_desc, one, args.tensor_b.raw_ptr,
                desc.layout_b, zero, nullptr, nullptr, ws_b, desc.layout_trans_b,
                stream));
        cublas_check(cublasLtMatrixTransformDescSetAttribute(
                transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &trans_a,
                sizeof(trans_a)));
        cublas_check(cublasLtMatrixTransform(
                cublasLt_handle, transform_desc, one, args.tensor_a.raw_ptr,
                desc.layout_a, zero, nullptr, nullptr, ws_a, desc.layout_trans_a,
                stream));
        cublas_check(cublasLtMatmul(
                cublasLt_handle, desc.matmul_desc, one, ws_b, desc.layout_trans_b,
                ws_a, desc.layout_trans_a, zero, ws_c, desc.layout_trans_c, ws_c,
                desc.layout_trans_c, &algo, ws_bundle.get(0), ws_bundle.get_size(0),
                stream));
        cublas_check(cublasLtMatrixTransformDescSetAttribute(
                transform_desc, CUBLASLT_MATRIX_TRANSFORM_DESC_TRANSA, &trans_c,
                sizeof(trans_c)));
        cublas_check(cublasLtMatrixTransform(
                cublasLt_handle, transform_desc, one, ws_c, desc.layout_trans_c, zero,
                nullptr, nullptr, args.tensor_c.raw_ptr, desc.layout_c, stream));
        cublas_check(cublasLtMatrixTransformDescDestroy(transform_desc));
    };

    ws_bundle.set(args.workspace.raw_ptr);
    // dispatch on the compute type chosen by CUBLASLTMatmulDesc
#if CUDA_VERSION >= 11000
    if (desc.dt_compute == CUBLAS_COMPUTE_32I) {
        batched_igemm();
    } else if (desc.dt_compute == CUBLAS_COMPUTE_16F) {
        batched_hgemm();
    } else if (desc.dt_compute == CUBLAS_COMPUTE_32F) {
        batched_sgemm();
    } else {
        megdnn_throw("compute_type must be int32/float16/float32");
    }
#else
    if (desc.dt_compute == CUDA_R_32I) {
        batched_igemm();
    } else if (desc.dt_compute == CUDA_R_16F) {
        batched_hgemm();
    } else if (desc.dt_compute == CUDA_R_32F) {
        batched_sgemm();
    } else {
        megdnn_throw("compute_type must be int32/float16/float32");
    }
#endif
}
#endif
