
cublas.cpp 5.8 kB

/**
 * \file dnn/src/cuda/batched_matrix_mul/cublas.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./algo.h"
#include "./helper.cuh"
#include "src/common/utils.cuh"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"

using namespace megdnn;
using namespace cuda;
using namespace batched_matrix_mul;
bool BatchedMatrixMulForwardImpl::AlgoCublas::is_available(const SizeArgs& args) const {
    auto dtype = args.layout_a.dtype;
    auto&& param = args.opr->param();
    auto&& handle = concrete_handle(args.opr->handle());
    // fix: cublasSgemmBatched in cuBLAS versions prior to 11.1 can produce wrong
    // results when batch == 1 and matrix A's width > 8191, so temporarily report
    // this algo as available only when
    // args.layout_a.shape[2] <= 8191 || args.layout_a.shape[0] != 1
    if (dtype == dtype::Float32()
#if CUBLAS_VERSION < 11200
        && (args.layout_a.shape[args.opr->param().transposeA ? 1 : 2] <= 8191 ||
            args.layout_a.shape[0] != 1)
#endif
    )
        return true;

    if (dtype != dtype::Float16())
        return false;
    else {
        auto&& cuda_cap = handle->device_prop();
        if (param.compute_mode == Param::ComputeMode::FLOAT32) {
#if CUDART_VERSION >= 9010
            return cuda_cap.major >= 5;
#else
            MEGDNN_MARK_USED_VAR(cuda_cap);
            return false;
#endif
        } else {
#if CUDART_VERSION >= 9000
            return cuda_cap.major >= 6;
#else
            MEGDNN_MARK_USED_VAR(cuda_cap);
            return false;
#endif
        }
    }
}
size_t BatchedMatrixMulForwardImpl::AlgoCublas::get_workspace_in_bytes(
        const SizeArgs& args) const {
    return args.layout_a.shape[0] * 3 * sizeof(uintptr_t);
}
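// Note: the workspace computed above is laid out as three consecutive arrays of
// `batch` device pointers (one array each for A, B and C). exec() below fills
// them with arange() and hands them to the batched cuBLAS entry points, which
// take arrays of per-matrix pointers rather than a single strided buffer.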
void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const {
    auto param = args.opr->param();
    auto dtype = args.layout_a.dtype;
    auto handle = concrete_handle(args.opr->handle());
    auto cublas_handle = handle->cublas_handle();
    auto stream = cuda_stream(handle);
    auto batch = args.layout_a.shape[0];
    auto m = args.layout_c.shape[1], n = args.layout_c.shape[2];
    auto k = args.layout_a.shape[param.transposeA ? 1 : 2];
    auto workspace = args.workspace;

    // Partition the workspace into three arrays of `batch` device pointers.
    uintptr_t* As = static_cast<uintptr_t*>(
            static_cast<void*>(workspace.raw_ptr + 0 * batch * sizeof(uintptr_t)));
    uintptr_t* Bs = static_cast<uintptr_t*>(
            static_cast<void*>(workspace.raw_ptr + 1 * batch * sizeof(uintptr_t)));
    uintptr_t* Cs = static_cast<uintptr_t*>(
            static_cast<void*>(workspace.raw_ptr + 2 * batch * sizeof(uintptr_t)));

    // Fill each pointer array with the address of every matrix in the batch:
    // base pointer plus i * stride[0] * dtype.size() for i in [0, batch).
    arange<uintptr_t>(
            As, reinterpret_cast<uintptr_t>(args.tensor_a.raw_ptr()),
            args.layout_a.stride[0] * dtype.size(), batch, stream);
    arange<uintptr_t>(
            Bs, reinterpret_cast<uintptr_t>(args.tensor_b.raw_ptr()),
            args.layout_b.stride[0] * dtype.size(), batch, stream);
    arange<uintptr_t>(
            Cs, reinterpret_cast<uintptr_t>(args.tensor_c.raw_ptr()),
            args.layout_c.stride[0] * dtype.size(), batch, stream);

    // fp32 input, fp32 compute
    auto io32_c32 = [&]() {
        auto zero = handle->zero_device();
        auto one = handle->one_device();
        cublas_check(cublasSgemmBatched(
                cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
                param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one,
                reinterpret_cast<const dt_float32**>(Bs), args.layout_b.stride[1],
                reinterpret_cast<const dt_float32**>(As), args.layout_a.stride[1], zero,
                reinterpret_cast<dt_float32**>(Cs), args.layout_c.stride[1], batch));
    };

#if CUDART_VERSION >= 9010
    // fp16 input, fp32 compute
    auto io16_c32 = [&]() {
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
        auto zero = handle->zero_device();
        auto one = handle->one_device();
        cublas_check(cublasGemmBatchedEx(
                cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
                param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one,
                reinterpret_cast<const void**>(Bs), CUDA_R_16F, args.layout_b.stride[1],
                reinterpret_cast<const void**>(As), CUDA_R_16F, args.layout_a.stride[1],
                zero, reinterpret_cast<void**>(Cs), CUDA_R_16F, args.layout_c.stride[1],
                batch, CUDA_R_32F, CUBLAS_GEMM_DEFAULT));
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH));
    };
#endif

#if CUDART_VERSION >= 9000
    // fp16 input, fp16 compute
    auto io16_c16 = [&]() {
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
        auto zero = handle->zero_device_h();
        auto one = handle->one_device_h();
        cublas_check(cublasHgemmBatched(
                cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
                param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one,
                reinterpret_cast<const __half**>(Bs), args.layout_b.stride[1],
                reinterpret_cast<const __half**>(As), args.layout_a.stride[1], zero,
                reinterpret_cast<__half**>(Cs), args.layout_c.stride[1], batch));
        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH));
    };
#endif

    // Dispatch on the input dtype and the requested compute mode.
    if (dtype == dtype::Float32()) {
        io32_c32();
    } else {
        if (param.compute_mode == Param::ComputeMode::FLOAT32) {
#if CUDART_VERSION >= 9010
            io16_c32();
#endif
        } else {
#if CUDART_VERSION >= 9000
            io16_c16();
#endif
        }
    }
}
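For readers unfamiliar with the pointer-array convention of the batched cuBLAS entry points, the sketch below shows a self-contained cublasSgemmBatched call. It is not MegEngine code: it builds the per-matrix device pointer arrays on the host and copies them over (the implementation above instead fills them directly on the device with arange()), it uses plain column-major matrices without the argument swap performed above (the code above passes Bs before As and the dimensions as (n, m, k), the standard way of computing a row-major product with a column-major BLAS), and all names (sgemm_batched_example, dA, dAs, ...) are illustrative.

// Standalone sketch of the pointer-array batched GEMM convention
// (illustrative names; error handling omitted for brevity).
#include <vector>
#include <cublas_v2.h>
#include <cuda_runtime.h>

void sgemm_batched_example(int batch, int m, int n, int k) {
    cublasHandle_t handle;
    cublasCreate(&handle);

    // One contiguous device buffer per operand, one matrix per batch element
    // (column-major, no transpose).
    float *dA, *dB, *dC;
    cudaMalloc(&dA, sizeof(float) * batch * m * k);
    cudaMalloc(&dB, sizeof(float) * batch * k * n);
    cudaMalloc(&dC, sizeof(float) * batch * m * n);

    // Build the arrays of per-matrix device pointers on the host...
    std::vector<const float*> hA(batch), hB(batch);
    std::vector<float*> hC(batch);
    for (int i = 0; i < batch; ++i) {
        hA[i] = dA + i * m * k;
        hB[i] = dB + i * k * n;
        hC[i] = dC + i * m * n;
    }
    // ...and copy them to the device: cublasSgemmBatched expects the pointer
    // arrays themselves to live in device memory, which is exactly what the
    // workspace in the file above provides.
    const float **dAs, **dBs;
    float **dCs;
    cudaMalloc(&dAs, sizeof(float*) * batch);
    cudaMalloc(&dBs, sizeof(float*) * batch);
    cudaMalloc(&dCs, sizeof(float*) * batch);
    cudaMemcpy(dAs, hA.data(), sizeof(float*) * batch, cudaMemcpyHostToDevice);
    cudaMemcpy(dBs, hB.data(), sizeof(float*) * batch, cudaMemcpyHostToDevice);
    cudaMemcpy(dCs, hC.data(), sizeof(float*) * batch, cudaMemcpyHostToDevice);

    // C_i = 1.0f * A_i * B_i + 0.0f * C_i for every i in [0, batch).
    const float one = 1.f, zero = 0.f;
    cublasSgemmBatched(
            handle, CUBLAS_OP_N, CUBLAS_OP_N, m, n, k, &one, dAs, m, dBs, k,
            &zero, dCs, m, batch);

    cudaFree(dAs);
    cudaFree(dBs);
    cudaFree(dCs);
    cudaFree(dA);
    cudaFree(dB);
    cudaFree(dC);
    cublasDestroy(handle);
}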

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
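To confirm that a GPU and driver are actually visible before running GPU programs, a short check against the CUDA runtime API can be used (an illustrative sketch, not MegEngine-specific):

// Minimal GPU visibility check via the CUDA runtime API (illustrative sketch,
// not part of MegEngine).
#include <cstdio>
#include <cuda_runtime.h>

int main() {
    int count = 0;
    cudaError_t err = cudaGetDeviceCount(&count);
    if (err != cudaSuccess || count == 0) {
        std::printf("no usable CUDA device (%s)\n", cudaGetErrorString(err));
        return 1;
    }
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    std::printf("found %d CUDA device(s); device 0: %s (compute capability %d.%d)\n",
                count, prop.name, prop.major, prop.minor);
    return 0;
}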