Browse Source

fix(dnn): drop batched matmul cublas algo when batch is 1

GitOrigin-RevId: 71126a27b0
release-1.7
Megvii Engine Team 3 years ago
parent
commit
849f0ece9d
1 changed files with 9 additions and 1 deletions
  1. +9
    -1
      dnn/src/cuda/batched_matrix_mul/cublas.cpp

+ 9
- 1
dnn/src/cuda/batched_matrix_mul/cublas.cpp View File

@@ -22,7 +22,15 @@ bool BatchedMatrixMulForwardImpl::AlgoCublas::is_available(const SizeArgs& args)
auto dtype = args.layout_a.dtype;
auto&& param = args.opr->param();
auto&& handle = concrete_handle(args.opr->handle());
if (dtype == dtype::Float32())
// fix: cublasSgemmBatched with versions prior to 11.1 has some error when batch = 1
// and matricA's width > 8191 .So temporarily drop this algo when
// args.layout_a.shape[2] <= 8191 || args.layout_a.shape[0] != 1
if (dtype == dtype::Float32()
#if CUBLAS_VERSION < 11200
&& (args.layout_a.shape[args.opr->param().transposeA ? 1 : 2] <= 8191 ||
args.layout_a.shape[0] != 1)
#endif
)
return true;
if (dtype != dtype::Float16())
return false;


Loading…
Cancel
Save