|
|
@@ -40,6 +40,12 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(const SizeArgs& args) cons |
|
|
|
return false; |
|
|
|
} |
|
|
|
|
|
|
|
// In conv_args.init_conv_desc will call cudnnSetTensor4dDescriptorEx(),which can't |
|
|
|
// been supported when total_nr_elems() > 2 ^ 31 |
|
|
|
if (args.src_layout->total_nr_elems() > INT_MAX || |
|
|
|
args.dst_layout->total_nr_elems() > INT_MAX) { |
|
|
|
return false; |
|
|
|
} |
|
|
|
auto dst_layout = *args.dst_layout; |
|
|
|
if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) { |
|
|
|
dst_layout.dtype = DType(); |
|
|
|