|
|
@@ -604,9 +604,16 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src, |
|
|
|
if (!dst.valid()) { |
|
|
|
dst = supported_dst_dtype.at(0); |
|
|
|
} else { |
|
|
|
megdnn_assert(vec_contains(supported_dst_dtype, dst), |
|
|
|
"unsupported Conv(%s, %s) -> %s", src.name(), |
|
|
|
filter.name(), dst.name()); |
|
|
|
bool dst_supported = false; |
|
|
|
for (auto&& dt : supported_dst_dtype) { |
|
|
|
if (dtype_almost_equal(dt, dst)) { |
|
|
|
dst_supported = true; |
|
|
|
break; |
|
|
|
} |
|
|
|
} |
|
|
|
MEGDNN_MARK_USED_VAR(dst_supported); |
|
|
|
megdnn_assert(dst_supported, "unsupported Conv(%s, %s) -> %s", |
|
|
|
src.name(), filter.name(), dst.name()); |
|
|
|
} |
|
|
|
megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32 |
|
|
|
#if !MEGDNN_DISABLE_FLOAT16 |
|
|
|