@@ -88,7 +88,11 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const { | |||||
#if CUDART_VERSION >= 9010 | #if CUDART_VERSION >= 9010 | ||||
auto io16_c32 = [&]() { | auto io16_c32 = [&]() { | ||||
#if CUDART_VERSION >= 11000 | |||||
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH)); | |||||
#else | |||||
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); | cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); | ||||
#endif | |||||
auto zero = handle->zero_device(); | auto zero = handle->zero_device(); | ||||
auto one = handle->one_device(); | auto one = handle->one_device(); | ||||
cublas_check(cublasGemmBatchedEx( | cublas_check(cublasGemmBatchedEx( | ||||
@@ -104,7 +108,11 @@ void BatchedMatrixMulForwardImpl::AlgoCublas::exec(const ExecArgs& args) const { | |||||
#if CUDART_VERSION >= 9000 | #if CUDART_VERSION >= 9000 | ||||
auto io16_c16 = [&]() { | auto io16_c16 = [&]() { | ||||
#if CUDART_VERSION >= 11000 | |||||
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH)); | |||||
#else | |||||
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); | cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); | ||||
#endif | |||||
auto zero = handle->zero_device_h(); | auto zero = handle->zero_device_h(); | ||||
auto one = handle->one_device_h(); | auto one = handle->one_device_h(); | ||||
cublas_check(cublasHgemmBatched( | cublas_check(cublasHgemmBatched( | ||||
@@ -124,7 +124,7 @@ void BatchedMatrixMulForwardImpl::AlgoCublasLt::exec(const ExecArgs& args) const | |||||
batched_igemm(); | batched_igemm(); | ||||
} else if (desc.dt_compute == CUBLAS_COMPUTE_16F) { | } else if (desc.dt_compute == CUBLAS_COMPUTE_16F) { | ||||
batched_hgemm(); | batched_hgemm(); | ||||
} else if (desc.dt_compute == CUBLAS_COMPUTE_32F) { | |||||
} else if (desc.dt_compute == CUBLAS_COMPUTE_32F_FAST_TF32) { | |||||
batched_sgemm(); | batched_sgemm(); | ||||
} else { | } else { | ||||
megdnn_throw("compute_type must be int32/float16/float32"); | megdnn_throw("compute_type must be int32/float16/float32"); | ||||
@@ -49,18 +49,26 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const { | |||||
auto sgemm = [&]() { | auto sgemm = [&]() { | ||||
auto zero = handle->zero_device(); | auto zero = handle->zero_device(); | ||||
auto one = handle->one_device(); | auto one = handle->one_device(); | ||||
#if CUDART_VERSION >= 11000 | |||||
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH)); | |||||
#endif | |||||
cublas_check(cublasSgemm( | cublas_check(cublasSgemm( | ||||
cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, | cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, | ||||
param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one, | param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one, | ||||
args.tensor_b.ptr<dt_float32>(), args.tensor_b.layout.stride[0], | args.tensor_b.ptr<dt_float32>(), args.tensor_b.layout.stride[0], | ||||
args.tensor_a.ptr<dt_float32>(), args.tensor_a.layout.stride[0], zero, | args.tensor_a.ptr<dt_float32>(), args.tensor_a.layout.stride[0], zero, | ||||
args.tensor_c.ptr<dt_float32>(), args.tensor_c.layout.stride[0])); | args.tensor_c.ptr<dt_float32>(), args.tensor_c.layout.stride[0])); | ||||
#if CUDART_VERSION >= 11000 | |||||
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH)); | |||||
#endif | |||||
}; | }; | ||||
auto sgemm_ex = [&]() { | auto sgemm_ex = [&]() { | ||||
auto zero = handle->zero_device(); | auto zero = handle->zero_device(); | ||||
auto one = handle->one_device(); | auto one = handle->one_device(); | ||||
#if CUDART_VERSION >= 9000 | |||||
#if CUDART_VERSION >= 11000 | |||||
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH)); | |||||
#elif CUDART_VERSION >= 9000 | |||||
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); | cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); | ||||
#endif | #endif | ||||
auto sgemm_ex_err = cublasSgemmEx( | auto sgemm_ex_err = cublasSgemmEx( | ||||
@@ -78,7 +86,9 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const { | |||||
}; | }; | ||||
auto hgemm = [&]() { | auto hgemm = [&]() { | ||||
#if CUDART_VERSION >= 9000 | |||||
#if CUDART_VERSION >= 11000 | |||||
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH)); | |||||
#elif CUDART_VERSION >= 9000 | |||||
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); | cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); | ||||
#endif | #endif | ||||
auto one_half = handle->one_device_h(); | auto one_half = handle->one_device_h(); | ||||
@@ -28,7 +28,7 @@ static cublasComputeType_t to_cublas_compute_type(DType tp) { | |||||
case DTypeEnum::Float16: | case DTypeEnum::Float16: | ||||
return CUBLAS_COMPUTE_16F; | return CUBLAS_COMPUTE_16F; | ||||
case DTypeEnum::Float32: | case DTypeEnum::Float32: | ||||
return CUBLAS_COMPUTE_32F; | |||||
return CUBLAS_COMPUTE_32F_FAST_TF32; | |||||
case DTypeEnum::Int32: | case DTypeEnum::Int32: | ||||
case DTypeEnum::QuantizedS32: | case DTypeEnum::QuantizedS32: | ||||
return CUBLAS_COMPUTE_32I; | return CUBLAS_COMPUTE_32I; | ||||
@@ -107,7 +107,7 @@ void MatrixMulForwardImpl::AlgoCuBlasLt::exec(const ExecArgs& args) const { | |||||
case CUBLAS_COMPUTE_16F: | case CUBLAS_COMPUTE_16F: | ||||
hgemm(); | hgemm(); | ||||
break; | break; | ||||
case CUBLAS_COMPUTE_32F: | |||||
case CUBLAS_COMPUTE_32F_FAST_TF32: | |||||
sgemm(); | sgemm(); | ||||
break; | break; | ||||
case CUBLAS_COMPUTE_32I: | case CUBLAS_COMPUTE_32I: | ||||