Browse Source

fix(python_module): fix conversion between numpy-ndarray and mgb tensor for qint4 and quint4

GitOrigin-RevId: 7450c4f25e
release-1.5
Megvii Engine Team 4 years ago
parent
commit
858261af1f
3 changed files with 26 additions and 8 deletions
  1. +1
    -0
      dnn/include/megdnn/oprs/base.h
  2. +2
    -0
      dnn/src/fallback/convolution/opr_impl.cpp
  3. +23
    -8
      src/core/impl/dtype.cpp

+ 1
- 0
dnn/include/megdnn/oprs/base.h View File

@@ -90,6 +90,7 @@ enum class AlgoDataType : uint32_t {
INT8X8X16 = 1 << 4, INT8X8X16 = 1 << 4,
INT16X16X32 = 1 << 5, INT16X16X32 = 1 << 5,
INT4X4X16 = 1 << 6, INT4X4X16 = 1 << 6,
QINT4x4x32 = 1 << 7,
}; };


/*! /*!


+ 2
- 0
dnn/src/fallback/convolution/opr_impl.cpp View File

@@ -434,6 +434,8 @@ ConvolutionImpl::NCBKernSizeParam::deduce_algo_data_type() const {
} }
} else if (src_type.enumv() == DTypeEnum::Quantized8Asymm) { } else if (src_type.enumv() == DTypeEnum::Quantized8Asymm) {
return ConvolutionImpl::AlgoDataType::QUINT8X8X32; return ConvolutionImpl::AlgoDataType::QUINT8X8X32;
} else if (src_type.enumv() == DTypeEnum::QuantizedS4) {
return ConvolutionImpl::AlgoDataType::QINT4x4x32;
} else { } else {
megdnn_throw(ssprintf("not support data type of %s * %s -> %s\n", megdnn_throw(ssprintf("not support data type of %s * %s -> %s\n",
src_type.name(), filter_type.name(), src_type.name(), filter_type.name(),


+ 23
- 8
src/core/impl/dtype.cpp View File

@@ -14,7 +14,6 @@
#include "megbrain/exception.h" #include "megbrain/exception.h"
#include "megbrain/utils/metahelper.h" #include "megbrain/utils/metahelper.h"
#include "megbrain/utils/arith_helper.h" #include "megbrain/utils/arith_helper.h"
#include "megdnn/dtype.h"


#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
@@ -383,24 +382,40 @@ struct QuantizedLowbitMemcpy<DT, true> {
// cast with bits that 8 % bits == 0 // cast with bits that 8 % bits == 0
static constexpr uint16_t bits = DTypeTrait<DT>::low_bit; static constexpr uint16_t bits = DTypeTrait<DT>::low_bit;
static constexpr uint8_t MASK = (1 << bits) - 1; static constexpr uint8_t MASK = (1 << bits) - 1;
using Trait = QuantizedLowbitTrait<DT>;
static constexpr bool signedness =
std::is_same<DT, dtype::QuantizedS4>::value;


static void byte2compact(void* dest_raw, const void* src_raw, size_t n) { static void byte2compact(void* dest_raw, const void* src_raw, size_t n) {
auto dest = static_cast<uint8_t*>(dest_raw); auto dest = static_cast<uint8_t*>(dest_raw);
auto src = static_cast<const int8_t*>(src_raw); auto src = static_cast<const int8_t*>(src_raw);
memset(dest, 0, divup<size_t>(n * bits, 8)); memset(dest, 0, divup<size_t>(n * bits, 8));
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
int8_t val = src[i] + Trait::SHIFT;
mgb_assert(val >= 0 && val < (1 << bits));
dest[i * bits / 8] |= val << (i * bits % 8);
int8_t val = src[i];
static const auto min_val = DTypeTrait<DT>::min();
static const auto max_val = DTypeTrait<DT>::max();
MGB_MARK_USED_VAR(min_val);
MGB_MARK_USED_VAR(max_val);
mgb_assert(val >= static_cast<int8_t>(min_val) &&
val <= static_cast<int8_t>(max_val),
"data exceeds range(%d,%d) of data type", min_val,
max_val);
dest[i * bits / 8] |= (val & MASK) << (i * bits % 8);
} }
} }
static void compact2byte(void* dest_raw, const void* src_raw, size_t n) { static void compact2byte(void* dest_raw, const void* src_raw, size_t n) {
auto dest = static_cast<int8_t*>(dest_raw);
auto dest = reinterpret_cast<int8_t*>(dest_raw);
auto src = static_cast<const uint8_t*>(src_raw); auto src = static_cast<const uint8_t*>(src_raw);
for (size_t i = 0; i < n; ++i) { for (size_t i = 0; i < n; ++i) {
int8_t val = ((src[i * bits / 8] >> (i * bits % 8)) & MASK);
dest[i] = val - Trait::SHIFT;
uint8_t intermediate =
((src[i * bits / 8] >> (i * bits % 8)) & MASK);
if (signedness) {
int val = (intermediate & uint8_t(1 << (bits - 1)))
? ((int)(intermediate) | ~(int)(MASK))
: (int)(intermediate);
dest[i] = static_cast<int8_t>(val);
} else {
dest[i] = static_cast<int8_t>(intermediate);
}
} }
} }
}; };


Loading…
Cancel
Save