Browse Source

fix(dnn/cuda): add block size limit for culass gemm algo

GitOrigin-RevId: c0940e4535
tags/v1.3.1
Megvii Engine Team 4 years ago
parent
commit
b717606989
1 changed files with 14 additions and 4 deletions
  1. +14
    -4
      dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp

+ 14
- 4
dnn/src/cuda/matrix_mul/cutlass_float32_simt.cpp View File

@@ -22,10 +22,20 @@ using namespace cutlass_wrapper;

bool MatrixMulForwardImpl::AlgoFloat32SIMT::is_available(
const SizeArgs& args) const {
return args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
args.layout_a.dtype == dtype::Float32() &&
args.layout_b.dtype == dtype::Float32() &&
args.layout_c.dtype == dtype::Float32();
bool available =
args.opr->param().format == param::MatrixMul::Format::DEFAULT &&
args.layout_a.dtype == dtype::Float32() &&
args.layout_b.dtype == dtype::Float32() &&
args.layout_c.dtype == dtype::Float32();
int n = args.layout_c.shape[1];
auto&& device_prop = cuda::current_device_prop();
int y_grid_limit = device_prop.maxGridSize[1];
// limit y grid
available &= ((n + m_algo_param.threadblock_n - 1) /
m_algo_param.threadblock_n <=
y_grid_limit);

return available;
}

size_t MatrixMulForwardImpl::AlgoFloat32SIMT::get_workspace_in_bytes(


Loading…
Cancel
Save