diff --git a/dnn/include/megdnn/dtype.h b/dnn/include/megdnn/dtype.h index aae14fec..df3319f0 100644 --- a/dnn/include/megdnn/dtype.h +++ b/dnn/include/megdnn/dtype.h @@ -132,6 +132,15 @@ namespace megdnn { cb(::megdnn::dtype::Quantized4Asymm) \ cb(::megdnn::dtype::QuantizedS4) +#define MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb) \ + cb(::megdnn::dtype::QuantizedS32) \ + cb(::megdnn::dtype::QuantizedS8) \ + cb(::megdnn::dtype::QuantizedS4) + +#define MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM(cb) \ + cb(::megdnn::dtype::Quantized8Asymm) \ + cb(::megdnn::dtype::Quantized4Asymm) + /*! * \brief a POD representation of a single byte * diff --git a/dnn/src/common/convolution.cpp b/dnn/src/common/convolution.cpp index 41b1249a..0a89010f 100644 --- a/dnn/src/common/convolution.cpp +++ b/dnn/src/common/convolution.cpp @@ -604,9 +604,16 @@ void ConvolutionBase::check_or_deduce_dtype_fwd(DType src, if (!dst.valid()) { dst = supported_dst_dtype.at(0); } else { - megdnn_assert(vec_contains(supported_dst_dtype, dst), - "unsupported Conv(%s, %s) -> %s", src.name(), - filter.name(), dst.name()); + bool dst_supported = false; + for (auto&& dt : supported_dst_dtype) { + if (dtype_almost_equal(dt, dst)) { + dst_supported = true; + break; + } + } + MEGDNN_MARK_USED_VAR(dst_supported); + megdnn_assert(dst_supported, "unsupported Conv(%s, %s) -> %s", + src.name(), filter.name(), dst.name()); } megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32 #if !MEGDNN_DISABLE_FLOAT16 diff --git a/dnn/src/common/utils.cpp b/dnn/src/common/utils.cpp index 371afbb5..2dd2ca70 100644 --- a/dnn/src/common/utils.cpp +++ b/dnn/src/common/utils.cpp @@ -245,6 +245,25 @@ float megdnn::mul_scale(DType lhs, DType rhs) { } // clang-format on +bool megdnn::dtype_almost_equal(DType lhs, DType rhs) { + if (lhs.enumv() != rhs.enumv()) + return false; + if (lhs.category() != DTypeCategory::QUANTIZED) + return true; +#define cb(dt) \ + if (lhs.enumv() == DTypeTrait
::enumv) \ + return almost_equal(lhs.param
().scale, rhs.param
().scale); + MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb) +#undef cb +#define cb(dt) \ + if (lhs.enumv() == DTypeTrait
::enumv) \ + return almost_equal(lhs.param
().scale, rhs.param
().scale) && \ + lhs.param
().zero_point == rhs.param
().zero_point; + MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM(cb) +#undef cb + megdnn_assert_internal(false); +} + template <> uint8_t megdnn::convert(dt_quint4 src, uint8_t dst, size_t offset) { diff --git a/dnn/src/common/utils.h b/dnn/src/common/utils.h index a12f2a53..1061ab47 100644 --- a/dnn/src/common/utils.h +++ b/dnn/src/common/utils.h @@ -434,6 +434,20 @@ int8_t convert(dt_qint4 src, int8_t dst, size_t offset); template <> dt_qint4 convert(int8_t src, dt_qint4 dst, size_t offset); +/*! + * \brief check float equal within given ULP(unit in the last place) + */ +template +static inline + typename std::enable_if::is_integer, bool>::type + almost_equal(T x, T y, int unit_last_place = 1) { + return std::abs(x - y) < (std::numeric_limits::epsilon() * + std::abs(x + y) * unit_last_place) || + std::abs(x - y) < std::numeric_limits::min(); +} + +bool dtype_almost_equal(DType lhs, DType rhs); + /** * \brief N-dimensional index space */