|
@@ -22,7 +22,15 @@ bool BatchedMatrixMulForwardImpl::AlgoCublas::is_available(const SizeArgs& args) |
|
|
auto dtype = args.layout_a.dtype; |
|
|
auto dtype = args.layout_a.dtype; |
|
|
auto&& param = args.opr->param(); |
|
|
auto&& param = args.opr->param(); |
|
|
auto&& handle = concrete_handle(args.opr->handle()); |
|
|
auto&& handle = concrete_handle(args.opr->handle()); |
|
|
if (dtype == dtype::Float32()) |
|
|
|
|
|
|
|
|
// fix: cublasSgemmBatched in versions prior to 11.1 has an error when batch = 1 |
|
|
|
|
|
// and matrix A's width > 8191. So temporarily keep this algo only when |
|
|
|
|
|
// args.layout_a.shape[transposeA ? 1 : 2] <= 8191 || args.layout_a.shape[0] != 1 |
|
|
|
|
|
if (dtype == dtype::Float32() |
|
|
|
|
|
#if CUBLAS_VERSION < 11200 |
|
|
|
|
|
&& (args.layout_a.shape[args.opr->param().transposeA ? 1 : 2] <= 8191 || |
|
|
|
|
|
args.layout_a.shape[0] != 1) |
|
|
|
|
|
#endif |
|
|
|
|
|
) |
|
|
return true; |
|
|
return true; |
|
|
if (dtype != dtype::Float16()) |
|
|
if (dtype != dtype::Float16()) |
|
|
return false; |
|
|
return false; |
|
|