From 4b2b623b8b61d1dbaabb8f276c1862f224043c98 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 31 Mar 2021 17:58:34 +0800 Subject: [PATCH] fix(dnn/cuda): fix cutlass matmul splitk limit GitOrigin-RevId: fc9a7c638ca8087fd42d5386163e6d820214484c --- .../matrix_mul/cutlass_float32_simt_split_k.cpp | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp b/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp index 02d028da..5d68dfec 100644 --- a/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp +++ b/dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp @@ -23,12 +23,20 @@ using namespace cutlass_wrapper; bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( const SizeArgs& args) const { auto&& param = args.opr->param(); - int n = args.layout_c.shape[1], + int m = args.layout_c.shape[0], n = args.layout_c.shape[1], k = args.layout_a.shape[param.transposeA ? 0 : 1]; - return args.opr->param().format == param::MatrixMul::Format::DEFAULT && - args.layout_a.dtype == dtype::Float32() && - args.layout_b.dtype == dtype::Float32() && - args.layout_c.dtype == dtype::Float32() && k > n; + bool available = + args.opr->param().format == param::MatrixMul::Format::DEFAULT && + args.layout_a.dtype == dtype::Float32() && + args.layout_b.dtype == dtype::Float32() && + args.layout_c.dtype == dtype::Float32() && k > n; + auto&& device_prop = cuda::current_device_prop(); + int y_grid_limit = device_prop.maxGridSize[1]; + // limit y grid + available &= ((m + m_algo_param.threadblock_m - 1) / + m_algo_param.threadblock_m <= + y_grid_limit); + return available; } size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes( @@ -36,7 +44,7 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes( auto&& param = args.opr->param(); int m = args.layout_c.shape[0], n = args.layout_c.shape[1], k = args.layout_a.shape[param.transposeA ? 0 : 1]; - int split_k_slices = k / n; + int split_k_slices = std::max(1, k / n); return args.layout_c.dtype.size(m * n * split_k_slices); } @@ -49,7 +57,7 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; GemmCoord problem_size{m, n, k}; - int split_k_slices = k / n; + int split_k_slices = std::max(1, k / n); auto&& stream = cuda_stream(args.opr->handle()); int* workspace = reinterpret_cast(args.workspace.raw_ptr); return cutlass_matrix_mul_float32_simt(