Browse Source

fix(dnn): fix data type check for quantized convolution

GitOrigin-RevId: e0f97052ff
tags/v0.4.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
8ac7333378
4 changed files with 52 additions and 3 deletions
  1. +9
    -0
      dnn/include/megdnn/dtype.h
  2. +10
    -3
      dnn/src/common/convolution.cpp
  3. +19
    -0
      dnn/src/common/utils.cpp
  4. +14
    -0
      dnn/src/common/utils.h

+ 9
- 0
dnn/include/megdnn/dtype.h View File

@@ -132,6 +132,15 @@ namespace megdnn {
cb(::megdnn::dtype::Quantized4Asymm) \ cb(::megdnn::dtype::Quantized4Asymm) \
cb(::megdnn::dtype::QuantizedS4) cb(::megdnn::dtype::QuantizedS4)


#define MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb) \
cb(::megdnn::dtype::QuantizedS32) \
cb(::megdnn::dtype::QuantizedS8) \
cb(::megdnn::dtype::QuantizedS4)

#define MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM(cb) \
cb(::megdnn::dtype::Quantized8Asymm) \
cb(::megdnn::dtype::Quantized4Asymm)

/*! /*!
* \brief a POD representation of a single byte * \brief a POD representation of a single byte
* *


+ 10
- 3
dnn/src/common/convolution.cpp View File

@@ -604,9 +604,16 @@ void ConvolutionBase<Parameter>::check_or_deduce_dtype_fwd(DType src,
if (!dst.valid()) { if (!dst.valid()) {
dst = supported_dst_dtype.at(0); dst = supported_dst_dtype.at(0);
} else { } else {
megdnn_assert(vec_contains(supported_dst_dtype, dst),
"unsupported Conv(%s, %s) -> %s", src.name(),
filter.name(), dst.name());
bool dst_supported = false;
for (auto&& dt : supported_dst_dtype) {
if (dtype_almost_equal(dt, dst)) {
dst_supported = true;
break;
}
}
MEGDNN_MARK_USED_VAR(dst_supported);
megdnn_assert(dst_supported, "unsupported Conv(%s, %s) -> %s",
src.name(), filter.name(), dst.name());
} }
megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32 megdnn_assert(param().compute_mode != Param::ComputeMode::FLOAT32
#if !MEGDNN_DISABLE_FLOAT16 #if !MEGDNN_DISABLE_FLOAT16


+ 19
- 0
dnn/src/common/utils.cpp View File

@@ -245,6 +245,25 @@ float megdnn::mul_scale(DType lhs, DType rhs) {
} }
// clang-format on // clang-format on


bool megdnn::dtype_almost_equal(DType lhs, DType rhs) {
if (lhs.enumv() != rhs.enumv())
return false;
if (lhs.category() != DTypeCategory::QUANTIZED)
return true;
#define cb(dt) \
if (lhs.enumv() == DTypeTrait<dt>::enumv) \
return almost_equal(lhs.param<dt>().scale, rhs.param<dt>().scale);
MEGDNN_FOREACH_QUANTIZED_DTYPE_SYMM(cb)
#undef cb
#define cb(dt) \
if (lhs.enumv() == DTypeTrait<dt>::enumv) \
return almost_equal(lhs.param<dt>().scale, rhs.param<dt>().scale) && \
lhs.param<dt>().zero_point == rhs.param<dt>().zero_point;
MEGDNN_FOREACH_QUANTIZED_DTYPE_ASYMM(cb)
#undef cb
megdnn_assert_internal(false);
}

template <> template <>
uint8_t megdnn::convert<dt_quint4, uint8_t>(dt_quint4 src, uint8_t dst, uint8_t megdnn::convert<dt_quint4, uint8_t>(dt_quint4 src, uint8_t dst,
size_t offset) { size_t offset) {


+ 14
- 0
dnn/src/common/utils.h View File

@@ -434,6 +434,20 @@ int8_t convert<dt_qint4, int8_t>(dt_qint4 src, int8_t dst, size_t offset);
template <> template <>
dt_qint4 convert<int8_t, dt_qint4>(int8_t src, dt_qint4 dst, size_t offset); dt_qint4 convert<int8_t, dt_qint4>(int8_t src, dt_qint4 dst, size_t offset);


/*!
* \brief check float equal within given ULP(unit in the last place)
*/
template <class T>
static inline
typename std::enable_if<!std::numeric_limits<T>::is_integer, bool>::type
almost_equal(T x, T y, int unit_last_place = 1) {
return std::abs(x - y) < (std::numeric_limits<T>::epsilon() *
std::abs(x + y) * unit_last_place) ||
std::abs(x - y) < std::numeric_limits<T>::min();
}

bool dtype_almost_equal(DType lhs, DType rhs);

/** /**
* \brief N-dimensional index space * \brief N-dimensional index space
*/ */


Loading…
Cancel
Save