|
|
@@ -23,12 +23,20 @@ using namespace cutlass_wrapper; |
|
|
|
bool MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::is_available( |
|
|
|
const SizeArgs& args) const { |
|
|
|
auto&& param = args.opr->param(); |
|
|
|
int n = args.layout_c.shape[1], |
|
|
|
int m = args.layout_c.shape[0], n = args.layout_c.shape[1], |
|
|
|
k = args.layout_a.shape[param.transposeA ? 0 : 1]; |
|
|
|
return args.opr->param().format == param::MatrixMul::Format::DEFAULT && |
|
|
|
args.layout_a.dtype == dtype::Float32() && |
|
|
|
args.layout_b.dtype == dtype::Float32() && |
|
|
|
args.layout_c.dtype == dtype::Float32() && k > n; |
|
|
|
bool available = |
|
|
|
args.opr->param().format == param::MatrixMul::Format::DEFAULT && |
|
|
|
args.layout_a.dtype == dtype::Float32() && |
|
|
|
args.layout_b.dtype == dtype::Float32() && |
|
|
|
args.layout_c.dtype == dtype::Float32() && k > n; |
|
|
|
auto&& device_prop = cuda::current_device_prop(); |
|
|
|
int y_grid_limit = device_prop.maxGridSize[1]; |
|
|
|
// limit y grid |
|
|
|
available &= ((m + m_algo_param.threadblock_m - 1) / |
|
|
|
m_algo_param.threadblock_m <= |
|
|
|
y_grid_limit); |
|
|
|
return available; |
|
|
|
} |
|
|
|
|
|
|
|
size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes( |
|
|
@@ -36,7 +44,7 @@ size_t MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::get_workspace_in_bytes( |
|
|
|
auto&& param = args.opr->param(); |
|
|
|
int m = args.layout_c.shape[0], n = args.layout_c.shape[1], |
|
|
|
k = args.layout_a.shape[param.transposeA ? 0 : 1]; |
|
|
|
int split_k_slices = k / n; |
|
|
|
int split_k_slices = std::max(1, k / n); |
|
|
|
return args.layout_c.dtype.size(m * n * split_k_slices); |
|
|
|
} |
|
|
|
|
|
|
@@ -49,7 +57,7 @@ void MatrixMulForwardImpl::AlgoFloat32SIMTSplitK::exec( |
|
|
|
int m = args.tensor_c.layout.shape[0], n = args.tensor_c.layout.shape[1], |
|
|
|
k = args.tensor_a.layout.shape[param.transposeA ? 0 : 1]; |
|
|
|
GemmCoord problem_size{m, n, k}; |
|
|
|
int split_k_slices = k / n; |
|
|
|
int split_k_slices = std::max(1, k / n); |
|
|
|
auto&& stream = cuda_stream(args.opr->handle()); |
|
|
|
int* workspace = reinterpret_cast<int*>(args.workspace.raw_ptr); |
|
|
|
return cutlass_matrix_mul_float32_simt( |
|
|
|