|
@@ -55,12 +55,12 @@ size_t MatrixMulForwardImpl::AlgoFloat16TensorOpSplitK::get_workspace_in_bytes( |
|
|
k = args.layout_a.shape[param.transposeA ? 0 : 1]; |
|
|
k = args.layout_a.shape[param.transposeA ? 0 : 1]; |
|
|
int split_k_slices = std::max(1, k / n); |
|
|
int split_k_slices = std::max(1, k / n); |
|
|
if (!aligned.first) |
|
|
if (!aligned.first) |
|
|
return args.layout_c.dtype.size(m * n * split_k_slices); |
|
|
|
|
|
|
|
|
return sizeof(float) * (m * n * split_k_slices); |
|
|
const auto& layouts = aligned.second; |
|
|
const auto& layouts = aligned.second; |
|
|
int align_m = layouts[2].shape[0], align_n = layouts[2].shape[1], |
|
|
int align_m = layouts[2].shape[0], align_n = layouts[2].shape[1], |
|
|
align_k = layouts[0].shape[1]; |
|
|
align_k = layouts[0].shape[1]; |
|
|
split_k_slices = std::max(1, align_k / align_n); |
|
|
split_k_slices = std::max(1, align_k / align_n); |
|
|
size_t ws_size = args.layout_c.dtype.size(align_m * align_n * split_k_slices); |
|
|
|
|
|
|
|
|
size_t ws_size = sizeof(float) * (align_m * align_n * split_k_slices); |
|
|
for (auto&& ly : layouts) |
|
|
for (auto&& ly : layouts) |
|
|
ws_size += ly.span().dist_byte(); |
|
|
ws_size += ly.span().dist_byte(); |
|
|
return ws_size; |
|
|
return ws_size; |
|
|