GitOrigin-RevId: e0f97052ff
tags/v0.4.0
@@ -132,6 +132,15 @@ namespace megdnn { | |||||
cb(::megdnn::dtype::Quantized4Asymm) \ | cb(::megdnn::dtype::Quantized4Asymm) \ | ||||
cb(::megdnn::dtype::QuantizedS4) | cb(::megdnn::dtype::QuantizedS4) | ||||
#define MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb) \ | |||||
cb(::megdnn::dtype::QuantizedS32) \ | |||||
cb(::megdnn::dtype::QuantizedS8) \ | |||||
cb(::megdnn::dtype::QuantizedS4) | |||||
#define MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM(cb) \ | |||||
cb(::megdnn::dtype::Quantized8Asymm) \ | |||||
cb(::megdnn::dtype::Quantized4Asymm) | |||||
/*! | /*! | ||||
* \brief a POD representation of a single byte | * \brief a POD representation of a single byte | ||||
* | * | ||||
@@ -604,9 +604,16 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src, | |||||
if (!dst.valid()) { | if (!dst.valid()) { | ||||
dst = supported_dst_dtype.at(0); | dst = supported_dst_dtype.at(0); | ||||
} else { | } else { | ||||
megdnn_assert(vec_contains(supported_dst_dtype, dst), | |||||
"unsupported Conv(%s, %s) -> %s", src.name(), | |||||
filter.name(), dst.name()); | |||||
bool dst_supported = false; | |||||
for (auto&& dt : supported_dst_dtype) { | |||||
if (dtype_almost_equal(dt, dst)) { | |||||
dst_supported = true; | |||||
break; | |||||
} | |||||
} | |||||
MEGDNN_MARK_USED_VAR(dst_supported); | |||||
megdnn_assert(dst_supported, "unsupported Conv(%s, %s) -> %s", | |||||
src.name(), filter.name(), dst.name()); | |||||
} | } | ||||
megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32 | megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32 | ||||
#if !MEGDNN_DISABLE_FLOAT16 | #if !MEGDNN_DISABLE_FLOAT16 | ||||
@@ -245,6 +245,25 @@ float megdnn::mul_scale(DType lhs, DType rhs) { | |||||
} | } | ||||
// clang-format on | // clang-format on | ||||
bool megdnn::dtype_almost_equal(DType lhs, DType rhs) { | |||||
if (lhs.enumv() != rhs.enumv()) | |||||
return false; | |||||
if (lhs.category() != DTypeCategory::QUANTIZED) | |||||
return true; | |||||
#define cb(dt) \ | |||||
if (lhs.enumv() == DTypeTrait<dt>::enumv) \ | |||||
return almost_equal(lhs.param<dt>().scale, rhs.param<dt>().scale); | |||||
MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb) | |||||
#undef cb | |||||
#define cb(dt) \ | |||||
if (lhs.enumv() == DTypeTrait<dt>::enumv) \ | |||||
return almost_equal(lhs.param<dt>().scale, rhs.param<dt>().scale) && \ | |||||
lhs.param<dt>().zero_point == rhs.param<dt>().zero_point; | |||||
MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM(cb) | |||||
#undef cb | |||||
megdnn_assert_internal(false); | |||||
} | |||||
template <> | template <> | ||||
uint8_t megdnn::convert<dt_quint4, uint8_t>(dt_quint4 src, uint8_t dst, | uint8_t megdnn::convert<dt_quint4, uint8_t>(dt_quint4 src, uint8_t dst, | ||||
size_t offset) { | size_t offset) { | ||||
@@ -434,6 +434,20 @@ int8_t convert<dt_qint4, int8_t>(dt_qint4 src, int8_t dst, size_t offset); | |||||
template <> | template <> | ||||
dt_qint4 convert<int8_t, dt_qint4>(int8_t src, dt_qint4 dst, size_t offset); | dt_qint4 convert<int8_t, dt_qint4>(int8_t src, dt_qint4 dst, size_t offset); | ||||
/*! | |||||
* \brief check float equal within given ULP(unit in the last place) | |||||
*/ | |||||
template <class T> | |||||
static inline | |||||
typename std::enable_if<!std::numeric_limits<T>::is_integer, bool>::type | |||||
almost_equal(T x, T y, int unit_last_place = 1) { | |||||
return std::abs(x - y) < (std::numeric_limits<T>::epsilon() * | |||||
std::abs(x + y) * unit_last_place) || | |||||
std::abs(x - y) < std::numeric_limits<T>::min(); | |||||
} | |||||
bool dtype_almost_equal(DType lhs, DType rhs); | |||||
/** | /** | ||||
* \brief N-dimensional index space | * \brief N-dimensional index space | ||||
*/ | */ | ||||