
feat(dnn): support tf32

GitOrigin-RevId: 9e5871f933
Branch: release-1.10
Author: Megvii Engine Team, 3 years ago
Commit: a0a5fcf182

5 changed files with 23 additions and 5 deletions:

  1. dnn/src/cuda/batched_matrix_mul/cublas.cpp     (+8,  -0)
  2. dnn/src/cuda/batched_matrix_mul/cublas_lt.cpp  (+1,  -1)
  3. dnn/src/cuda/matrix_mul/cublas.cpp             (+12, -2)
  4. dnn/src/cuda/matrix_mul/cublasLt_wrapper.cpp   (+1,  -1)
  5. dnn/src/cuda/matrix_mul/cublas_lt.cpp          (+1,  -1)

dnn/src/cuda/batched_matrix_mul/cublas.cpp  (+8, -0)

@@ -88,7 +88,11 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const {
 #if CUDART_VERSION >= 9010
     auto io16_c32 = [&]() {
+#if CUDART_VERSION >= 11000
+        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
+#else
         cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
+#endif
         auto zero = handle->zero_device();
         auto one = handle->one_device();
         cublas_check(cublasGemmBatchedEx(
@@ -104,7 +108,11 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const {
 #if CUDART_VERSION >= 9000
     auto io16_c16 = [&]() {
+#if CUDART_VERSION >= 11000
+        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
+#else
         cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
+#endif
         auto zero = handle->zero_device_h();
         auto one = handle->one_device_h();
         cublas_check(cublasHgemmBatched(
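Both hunks repeat the same version gate, since CUBLAS_TENSOR_OP_MATH is deprecated starting with CUDA 11 and CUBLAS_TF32_TENSOR_OP_MATH only exists from CUDA 11 on. A sketch of how the gate could be factored into one helper (the set_fast_math name is hypothetical; the commit inlines the #if/#else at each call site):

// Hypothetical consolidation of the gate repeated in both hunks above.
#include <cublas_v2.h>

static inline cublasStatus_t set_fast_math(cublasHandle_t handle) {
#if CUDART_VERSION >= 11000
    // CUDA 11 deprecates CUBLAS_TENSOR_OP_MATH; TF32 mode keeps FP16
    // tensor-core paths enabled, so it covers io16_c32 and io16_c16 too.
    return cublasSetMathMode(handle, CUBLAS_TF32_TENSOR_OP_MATH);
#elif CUDART_VERSION >= 9000
    return cublasSetMathMode(handle, CUBLAS_TENSOR_OP_MATH);
#else
    return CUBLAS_STATUS_SUCCESS;  // older toolkits: no mode to set
#endif
}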


dnn/src/cuda/batched_matrix_mul/cublas_lt.cpp  (+1, -1)

@@ -124,7 +124,7 @@ void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec(const ExecArgs& args) const
         batched_igemm();
     } else if (desc.dt_compute == CUBLAS_COMPUTE_16F) {
         batched_hgemm();
-    } else if (desc.dt_compute == CUBLAS_COMPUTE_32F) {
+    } else if (desc.dt_compute == CUBLAS_COMPUTE_32F_FAST_TF32) {
         batched_sgemm();
     } else {
         megdnn_throw("compute_type must be int32/float16/float32");
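This dispatch compares against the compute type produced by to_cublas_compute_type in cublasLt_wrapper.cpp (changed below), so it has to track the new enum: with the wrapper now returning CUBLAS_COMPUTE_32F_FAST_TF32 for Float32, an unchanged CUBLAS_COMPUTE_32F branch would send float32 matmuls into the megdnn_throw fallback. A minimal sketch of where such a compute type enters the cublasLt API (standalone, CUDA 11+ assumed; make_f32_matmul_desc is hypothetical, and desc.dt_compute above is MegDNN's cached copy of the enum):

// Sketch: a cublasLt matmul descriptor that selects TF32 compute.
#include <cublasLt.h>

cublasLtMatmulDesc_t make_f32_matmul_desc() {
    cublasLtMatmulDesc_t desc = nullptr;
    // CUBLAS_COMPUTE_32F_FAST_TF32: FP32 in/out, TF32 tensor-core compute;
    // the scale type for alpha/beta stays CUDA_R_32F.
    cublasLtMatmulDescCreate(&desc, CUBLAS_COMPUTE_32F_FAST_TF32, CUDA_R_32F);
    return desc;
}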


dnn/src/cuda/matrix_mul/cublas.cpp  (+12, -2)

@@ -49,18 +49,26 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const {
     auto sgemm = [&]() {
         auto zero = handle->zero_device();
         auto one = handle->one_device();
+#if CUDART_VERSION >= 11000
+        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
+#endif
         cublas_check(cublasSgemm(
                 cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N,
                 param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one,
                 args.tensor_b.ptr<dt_float32>(), args.tensor_b.layout.stride[0],
                 args.tensor_a.ptr<dt_float32>(), args.tensor_a.layout.stride[0], zero,
                 args.tensor_c.ptr<dt_float32>(), args.tensor_c.layout.stride[0]));
+#if CUDART_VERSION >= 11000
+        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH));
+#endif
     };

     auto sgemm_ex = [&]() {
         auto zero = handle->zero_device();
         auto one = handle->one_device();
-#if CUDART_VERSION >= 9000
+#if CUDART_VERSION >= 11000
+        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
+#elif CUDART_VERSION >= 9000
         cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
 #endif
         auto sgemm_ex_err = cublasSgemmEx(
@@ -78,7 +86,9 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const {
     };

     auto hgemm = [&]() {
-#if CUDART_VERSION >= 9000
+#if CUDART_VERSION >= 11000
+        cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH));
+#elif CUDART_VERSION >= 9000
         cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH));
 #endif
         auto one_half = handle->one_device_h();
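Note the asymmetry in this file: sgemm restores CUBLAS_DEFAULT_MATH after the call, while sgemm_ex and hgemm leave the mode set, matching the pre-TF32 behavior of those paths. A sketch of an exception-safe alternative for the set/restore pattern (hypothetical guard class, not part of the commit):

// Hypothetical RAII variant of the set/restore pattern used in `sgemm`.
#include <cublas_v2.h>

class CublasMathModeGuard {
    cublasHandle_t m_handle;

public:
    CublasMathModeGuard(cublasHandle_t handle, cublasMath_t mode)
            : m_handle{handle} {
        cublasSetMathMode(m_handle, mode);
    }
    ~CublasMathModeGuard() { cublasSetMathMode(m_handle, CUBLAS_DEFAULT_MATH); }
};

// usage sketch:
//     CublasMathModeGuard guard{cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH};
//     cublas_check(cublasSgemm(/* ... */));  // mode restored at scope exit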


dnn/src/cuda/matrix_mul/cublasLt_wrapper.cpp  (+1, -1)

@@ -28,7 +28,7 @@ static cublasComputeType_t to_cublas_compute_type(DType tp) {
         case DTypeEnum::Float16:
             return CUBLAS_COMPUTE_16F;
         case DTypeEnum::Float32:
-            return CUBLAS_COMPUTE_32F;
+            return CUBLAS_COMPUTE_32F_FAST_TF32;
         case DTypeEnum::Int32:
         case DTypeEnum::QuantizedS32:
             return CUBLAS_COMPUTE_32I;
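Because this mapping is unconditional, every float32 cublasLt matmul now computes in TF32, trading roughly 13 mantissa bits for tensor-core throughput, so results can drift from strict FP32. If an opt-out were wanted, one possible shape is a runtime knob; the MEGDNN_ALLOW_TF32 variable below is hypothetical, not something this commit adds:

// Hypothetical variant with a runtime opt-out (NOT in the commit, which
// switches Float32 to TF32 unconditionally).
#include <cstdlib>
#include <cublasLt.h>

static cublasComputeType_t f32_compute_type() {
    const char* v = std::getenv("MEGDNN_ALLOW_TF32");  // assumed knob
    bool allow = v && v[0] == '1';
    return allow ? CUBLAS_COMPUTE_32F_FAST_TF32 : CUBLAS_COMPUTE_32F;
}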


dnn/src/cuda/matrix_mul/cublas_lt.cpp  (+1, -1)

@@ -107,7 +107,7 @@ void MatrixMulForwardImpl::AlgoCuBlasLt::exec(const ExecArgs& args) const {
         case CUBLAS_COMPUTE_16F:
             hgemm();
             break;
-        case CUBLAS_COMPUTE_32F:
+        case CUBLAS_COMPUTE_32F_FAST_TF32:
             sgemm();
             break;
         case CUBLAS_COMPUTE_32I:
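The same substitution in this dispatch closes the loop: float32 matmuls reach sgemm with TF32 compute selected. To see what the mode costs numerically, here is a standalone sketch (assumes an Ampere GPU and CUDA 11+; not MegEngine code) that runs one GEMM under strict FP32 and under TF32 and prints the largest elementwise difference:

// Standalone comparison of strict FP32 vs TF32 for one SGEMM.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>
#include <cublas_v2.h>
#include <cuda_runtime.h>

int main() {
    const int n = 256;                      // square n x n GEMM, C = A * A
    std::vector<float> h(n * n, 1.0f / 3);  // 1/3 is inexact, so rounding shows
    float *A = nullptr, *C = nullptr;
    cudaMalloc(&A, sizeof(float) * n * n);
    cudaMalloc(&C, sizeof(float) * n * n);
    cudaMemcpy(A, h.data(), sizeof(float) * n * n, cudaMemcpyHostToDevice);

    cublasHandle_t handle;
    cublasCreate(&handle);
    const float one = 1.f, zero = 0.f;
    std::vector<float> ref(n * n), tf32(n * n);

    auto run = [&](cublasMath_t mode, float* dst) {
        cublasSetMathMode(handle, mode);
        cublasSgemm(handle, CUBLAS_OP_N, CUBLAS_OP_N, n, n, n, &one, A, n, A,
                    n, &zero, C, n);
        cudaMemcpy(dst, C, sizeof(float) * n * n, cudaMemcpyDeviceToHost);
    };
    run(CUBLAS_PEDANTIC_MATH, ref.data());         // strict FP32 reference
    run(CUBLAS_TF32_TENSOR_OP_MATH, tf32.data());  // TF32 tensor cores

    double max_diff = 0;
    for (int i = 0; i < n * n; ++i)
        max_diff = std::max(max_diff, (double)std::fabs(ref[i] - tf32[i]));
    std::printf("max |FP32 - TF32| = %g\n", max_diff);

    cublasDestroy(handle);
    cudaFree(A);
    cudaFree(C);
    return 0;
}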

