diff --git a/dnn/src/cuda/batched_matrix_mul/cublas.cpp b/dnn/src/cuda/batched_matrix_mul/cublas.cpp index b2261fec..7a1a1e43 100644 --- a/dnn/src/cuda/batched_matrix_mul/cublas.cpp +++ b/dnn/src/cuda/batched_matrix_mul/cublas.cpp @@ -22,7 +22,15 @@ bool BatchedMatrixMulForwardImpl::AlgoCublas::is_available(const SizeArgs& args) auto dtype = args.layout_a.dtype; auto&& param = args.opr->param(); auto&& handle = concrete_handle(args.opr->handle()); - if (dtype == dtype::Float32()) + // fix: cublasSgemmBatched with versions prior to 11.1 has some error when batch = 1 + // and matricA's width > 8191 .So temporarily drop this algo when + // args.layout_a.shape[2] <= 8191 || args.layout_a.shape[0] != 1 + if (dtype == dtype::Float32() +#if CUBLAS_VERSION < 11200 + && (args.layout_a.shape[args.opr->param().transposeA ? 1 : 2] <= 8191 || + args.layout_a.shape[0] != 1) +#endif + ) return true; if (dtype != dtype::Float16()) return false;