|
|
@@ -12,6 +12,7 @@ |
|
|
|
#include "src/common/utils.h" |
|
|
|
#include "src/cuda/utils.h" |
|
|
|
#if CUDA_VERSION >= 10010 |
|
|
|
|
|
|
|
namespace megdnn { |
|
|
|
namespace cuda { |
|
|
|
static cudaDataType_t to_cuda_dtype(DType tp) { |
|
|
@@ -31,6 +32,22 @@ static cudaDataType_t to_cuda_dtype(DType tp) { |
|
|
|
"dtype must be float16/float32/int8/qs8/int32")); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
static cublasComputeType_t to_cublas_compute_type(DType tp) { |
|
|
|
switch (tp.enumv()) { |
|
|
|
case DTypeEnum::Float16: |
|
|
|
return CUBLAS_COMPUTE_16F; |
|
|
|
case DTypeEnum::Float32: |
|
|
|
return CUBLAS_COMPUTE_32F; |
|
|
|
case DTypeEnum::Int32: |
|
|
|
case DTypeEnum::QuantizedS32: |
|
|
|
return CUBLAS_COMPUTE_32I; |
|
|
|
default: |
|
|
|
megdnn_throw(megdnn_mangle( |
|
|
|
"dtype must be float16/float32/int32/Qs32")); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
static const char* cuda_type_to_str(cudaDataType_t tp) { |
|
|
|
switch (tp) { |
|
|
|
case CUDA_R_16F: |
|
|
@@ -46,6 +63,7 @@ static const char* cuda_type_to_str(cudaDataType_t tp) { |
|
|
|
megdnn_mangle("dtype must be float16/float32/int8/int32")); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
static size_t cuda_dtype_size(cudaDataType_t dt) { |
|
|
|
switch (dt) { |
|
|
|
case CUDA_R_8I: |
|
|
@@ -60,6 +78,7 @@ static size_t cuda_dtype_size(cudaDataType_t dt) { |
|
|
|
megdnn_mangle("dtype must be float16/float32/int8/int32")); |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
|
CUBLASLTMatmulDesc::~CUBLASLTMatmulDesc() { |
|
|
|
if (matmul_desc) |
|
|
|
cublas_check(cublasLtMatmulDescDestroy(matmul_desc)); |
|
|
@@ -86,9 +105,10 @@ void CUBLASLTMatmulDesc::set(const SizeArgs& args, bool batched) { |
|
|
|
uint32_t pm = CUBLAS_POINTER_MODE_DEVICE; |
|
|
|
dt_b = to_cuda_dtype(args.layout_b.dtype); |
|
|
|
dt_a = to_cuda_dtype(args.layout_a.dtype); |
|
|
|
dt_compute = dt_c = to_cuda_dtype(args.layout_c.dtype); |
|
|
|
dt_c = to_cuda_dtype(args.layout_c.dtype); |
|
|
|
dt_compute = to_cublas_compute_type(args.layout_c.dtype); |
|
|
|
megdnn_assert(dt_a == dt_b, "matrix A and B should have same precision"); |
|
|
|
cublas_check(cublasLtMatmulDescCreate(&matmul_desc, dt_compute)); |
|
|
|
cublas_check(cublasLtMatmulDescCreate(&matmul_desc, dt_compute, dt_c)); |
|
|
|
cublas_check(cublasLtMatmulDescSetAttribute( |
|
|
|
matmul_desc, CUBLASLT_MATMUL_DESC_POINTER_MODE, &pm, sizeof(pm))); |
|
|
|
|
|
|
@@ -100,7 +120,7 @@ void CUBLASLTMatmulDesc::set(const SizeArgs& args, bool batched) { |
|
|
|
* So we calculate C^t = B^t * A^t by cublas. Here the transpose symbol |
|
|
|
* implies row-major to column-major conversion |
|
|
|
*/ |
|
|
|
if (dt_compute == CUDA_R_32I) { |
|
|
|
if (dt_c == CUDA_R_32I) { |
|
|
|
/** |
|
|
|
* \NOTE: To use IMMA kernels, use computeType = CUDA_R_32I and |
|
|
|
* CUBLASLT_ORDER_COL32 for matrices A,C,D and |
|
|
@@ -209,7 +229,7 @@ void CUBLASLTMatmulDesc::set(const SizeArgs& args, bool batched) { |
|
|
|
bool CUBLASLTMatmulDesc::is_available(const SizeArgs& args, size_t ws_limit) { |
|
|
|
bool support; |
|
|
|
cublasLtMatmulAlgo_t algo; |
|
|
|
switch (dt_compute) { |
|
|
|
switch (dt_c) { |
|
|
|
case CUDA_R_16F: |
|
|
|
support = (dt_a == CUDA_R_16F); |
|
|
|
break; |
|
|
@@ -239,17 +259,17 @@ WorkspaceBundle CUBLASLTMatmulDesc::get_workspace_bundle( |
|
|
|
cublasLtMatmulHeuristicResult_t result{}; |
|
|
|
status = cublasLtMatmulAlgoCheck( |
|
|
|
cublasLt_handle, matmul_desc, |
|
|
|
dt_compute == CUDA_R_32I ? layout_trans_b : layout_b, |
|
|
|
dt_compute == CUDA_R_32I ? layout_trans_a : layout_a, |
|
|
|
dt_compute == CUDA_R_32I ? layout_trans_c : layout_c, |
|
|
|
dt_compute == CUDA_R_32I ? layout_trans_c : layout_c, &algo, |
|
|
|
dt_c == CUDA_R_32I ? layout_trans_b : layout_b, |
|
|
|
dt_c == CUDA_R_32I ? layout_trans_a : layout_a, |
|
|
|
dt_c == CUDA_R_32I ? layout_trans_c : layout_c, |
|
|
|
dt_c == CUDA_R_32I ? layout_trans_c : layout_c, &algo, |
|
|
|
&result); |
|
|
|
// return empty WorkspaceBundle if cublasLtMatmulAlgoCheck() failed |
|
|
|
if (status != CUBLAS_STATUS_SUCCESS) |
|
|
|
return {nullptr, {}}; |
|
|
|
algo_workspace_size = result.workspaceSize; |
|
|
|
return {nullptr, |
|
|
|
(dt_compute == CUDA_R_32I) |
|
|
|
(dt_c == CUDA_R_32I) |
|
|
|
? SmallVector<size_t>{algo_workspace_size, workspace_b, |
|
|
|
workspace_a, workspace_c} |
|
|
|
: SmallVector<size_t>{algo_workspace_size}}; |
|
|
@@ -273,7 +293,7 @@ bool CUBLASLTMatmulDesc::get_algorithm_heuristic(const SizeArgs& args, |
|
|
|
* \Note: algo_ws_limit must be zero if cublasLtGetVersion() <= 10100 |
|
|
|
*/ |
|
|
|
// algo_ws_limit = 0; |
|
|
|
if (dt_compute == CUDA_R_32I) { |
|
|
|
if (dt_c == CUDA_R_32I) { |
|
|
|
//[FIXME]: cublasLt(Version 10020) produce wrong result when k in |
|
|
|
//[64*n+1 , 64*n+32] for small matrix |
|
|
|
|
|
|
@@ -291,10 +311,10 @@ bool CUBLASLTMatmulDesc::get_algorithm_heuristic(const SizeArgs& args, |
|
|
|
sizeof(algo_ws_limit))); |
|
|
|
status = cublasLtMatmulAlgoGetHeuristic( |
|
|
|
cublasLt_handle, matmul_desc, |
|
|
|
dt_compute == CUDA_R_32I ? layout_trans_b : layout_b, |
|
|
|
dt_compute == CUDA_R_32I ? layout_trans_a : layout_a, |
|
|
|
dt_compute == CUDA_R_32I ? layout_trans_c : layout_c, |
|
|
|
dt_compute == CUDA_R_32I ? layout_trans_c : layout_c, algo_pref, 1, |
|
|
|
dt_c == CUDA_R_32I ? layout_trans_b : layout_b, |
|
|
|
dt_c == CUDA_R_32I ? layout_trans_a : layout_a, |
|
|
|
dt_c == CUDA_R_32I ? layout_trans_c : layout_c, |
|
|
|
dt_c == CUDA_R_32I ? layout_trans_c : layout_c, algo_pref, 1, |
|
|
|
&algo_result, &return_algo_count); |
|
|
|
if (status == CUBLAS_STATUS_SUCCESS && return_algo_count > 0 && |
|
|
|
// perform cublasLtAlgoCheck() to make sure the algo is correct |
|
|
|