|
|
@@ -22,10 +22,20 @@ using namespace cutlass_wrapper; |
|
|
|
|
|
|
|
bool MatrixMulForwardImpl::AlgoFloat32SIMT::is_available( |
|
|
|
const SizeArgs& args) const { |
|
|
|
return args.opr->param().format == param::MatrixMul::Format::DEFAULT && |
|
|
|
args.layout_a.dtype == dtype::Float32() && |
|
|
|
args.layout_b.dtype == dtype::Float32() && |
|
|
|
args.layout_c.dtype == dtype::Float32(); |
|
|
|
bool available = |
|
|
|
args.opr->param().format == param::MatrixMul::Format::DEFAULT && |
|
|
|
args.layout_a.dtype == dtype::Float32() && |
|
|
|
args.layout_b.dtype == dtype::Float32() && |
|
|
|
args.layout_c.dtype == dtype::Float32(); |
|
|
|
int n = args.layout_c.shape[1]; |
|
|
|
auto&& device_prop = cuda::current_device_prop(); |
|
|
|
int y_grid_limit = device_prop.maxGridSize[1]; |
|
|
|
// limit y grid |
|
|
|
available &= ((n + m_algo_param.threadblock_n - 1) / |
|
|
|
m_algo_param.threadblock_n <= |
|
|
|
y_grid_limit); |
|
|
|
|
|
|
|
return available; |
|
|
|
} |
|
|
|
|
|
|
|
size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes( |
|
|
|