|
|
@@ -49,18 +49,26 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const { |
|
|
|
auto sgemm = [&]() { |
|
|
|
auto zero = handle->zero_device(); |
|
|
|
auto one = handle->one_device(); |
|
|
|
#if CUDART_VERSION >= 11000 |
|
|
|
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH)); |
|
|
|
#endif |
|
|
|
cublas_check(cublasSgemm( |
|
|
|
cublas_handle, param.transposeB ? CUBLAS_OP_T : CUBLAS_OP_N, |
|
|
|
param.transposeA ? CUBLAS_OP_T : CUBLAS_OP_N, n, m, k, one, |
|
|
|
args.tensor_b.ptr<dt_float32>(), args.tensor_b.layout.stride[0], |
|
|
|
args.tensor_a.ptr<dt_float32>(), args.tensor_a.layout.stride[0], zero, |
|
|
|
args.tensor_c.ptr<dt_float32>(), args.tensor_c.layout.stride[0])); |
|
|
|
#if CUDART_VERSION >= 11000 |
|
|
|
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_DEFAULT_MATH)); |
|
|
|
#endif |
|
|
|
}; |
|
|
|
|
|
|
|
auto sgemm_ex = [&]() { |
|
|
|
auto zero = handle->zero_device(); |
|
|
|
auto one = handle->one_device(); |
|
|
|
#if CUDART_VERSION >= 9000 |
|
|
|
#if CUDART_VERSION >= 11000 |
|
|
|
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH)); |
|
|
|
#elif CUDART_VERSION >= 9000 |
|
|
|
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); |
|
|
|
#endif |
|
|
|
auto sgemm_ex_err = cublasSgemmEx( |
|
|
@@ -78,7 +86,9 @@ void MatrixMulForwardImpl::AlgoCuBlas::exec(const ExecArgs& args) const { |
|
|
|
}; |
|
|
|
|
|
|
|
auto hgemm = [&]() { |
|
|
|
#if CUDART_VERSION >= 9000 |
|
|
|
#if CUDART_VERSION >= 11000 |
|
|
|
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TF32_TENSOR_OP_MATH)); |
|
|
|
#elif CUDART_VERSION >= 9000 |
|
|
|
cublas_check(cublasSetMathMode(cublas_handle, CUBLAS_TENSOR_OP_MATH)); |
|
|
|
#endif |
|
|
|
auto one_half = handle->one_device_h(); |
|
|
|