
fix(dnn/cuda): fix cutlass matmul splitk limit

GitOrigin-RevId: fc9a7c638c
release-1.4
Megvii Engine Team 4 years ago
Parent commit: 4b2b623b8b
1 changed file with 15 additions and 7 deletions
  1. dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp (+15, -7)

dnn/src/cuda/matrix_mul/cutlass_float32_simt_split_k.cpp (+15, -7)

@@ -23,12 +23,20 @@ using namespace cutlass_wrapper;
 bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available(
         const SizeArgs& args) const {
     auto&& param = args.opr->param();
-    int n = args.layout_c.shape[1],
+    int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
         k = args.layout_a.shape[param.transposeA ? 0 : 1];
-    return args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
-           args.layout_a.dtype == dtype::Float32() &&
-           args.layout_b.dtype == dtype::Float32() &&
-           args.layout_c.dtype == dtype::Float32() && k > n;
+    bool available =
+            args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
+            args.layout_a.dtype == dtype::Float32() &&
+            args.layout_b.dtype == dtype::Float32() &&
+            args.layout_c.dtype == dtype::Float32() && k > n;
+    auto&& device_prop = cuda::current_device_prop();
+    int y_grid_limit = device_prop.maxGridSize[1];
+    // limit y grid
+    available &= ((m + m_algo_param.threadblock_m - 1) /
+                          m_algo_param.threadblock_m <=
+                  y_grid_limit);
+    return available;
 }
 
 size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes(
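A note on the check added above: the number of threadblocks needed along M, computed with a ceiling division by m_algo_param.threadblock_m, is compared against the device's maxGridSize[1], i.e. the hardware cap on gridDim.y (commonly 65535). Below is a minimal standalone sketch of the same test, using the plain CUDA runtime API instead of MegEngine's cuda::current_device_prop() helper; threadblock_m is an illustrative stand-in for m_algo_param.threadblock_m.

#include <cuda_runtime.h>

// Sketch of the y-grid limit check: can the M dimension be covered by
// threadblocks of height threadblock_m without exceeding gridDim.y?
bool fits_y_grid(int m, int threadblock_m) {
    cudaDeviceProp prop{};
    cudaGetDeviceProperties(&prop, /*device=*/0);
    // Threadblocks needed along M, rounded up.
    int blocks_m = (m + threadblock_m - 1) / threadblock_m;
    // gridDim.y must not exceed the device's maxGridSize[1].
    return blocks_m <= prop.maxGridSize[1];
}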
@@ -36,7 +44,7 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes(
     auto&& param = args.opr->param();
     int m = args.layout_c.shape[0], n = args.layout_c.shape[1],
         k = args.layout_a.shape[param.transposeA ? 0 : 1];
-    int split_k_slices = k / n;
+    int split_k_slices = std::max(1, k / n);
     return args.layout_c.dtype.size(m * n * split_k_slices);
 }

@@ -49,7 +57,7 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec(
     int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1],
         k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1];
     GemmCoord problem_size{m, n, k};
-    int split_k_slices = k / n;
+    int split_k_slices = std::max(1, k / n);
     auto&& stream = cuda_stream(args.opr->handle());
     int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr);
     return cutlass_matrix_mul_float32_simt(
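The other change clamps the split-k factor in both get_workspace_in_bytes and exec: split_k_slices comes from the integer division k / n, which truncates to zero whenever k < n, and std::max(1, k / n) guarantees at least one slice and therefore a non-empty workspace. A small self-contained illustration follows, with hypothetical shapes; the sizeof(float) factor mirrors layout_c.dtype.size() for Float32.

#include <algorithm>
#include <cassert>
#include <cstddef>

int main() {
    int m = 128, n = 256, k = 64;             // hypothetical shapes with k < n
    int naive_slices = k / n;                 // truncates to 0
    int split_k_slices = std::max(1, k / n);  // clamped to at least one slice
    assert(naive_slices == 0 && split_k_slices == 1);

    // Workspace holds one float32 partial result of C per slice, as in
    // get_workspace_in_bytes() above.
    std::size_t workspace =
            sizeof(float) * std::size_t(m) * n * split_k_slices;
    assert(workspace == sizeof(float) * 128 * 256);
    return 0;
}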

