Browse Source

feat(mge):add qint4/quint4 to python

GitOrigin-RevId: f94609db00
tags/v0.3.2
Megvii Engine Team 5 years ago
parent
commit
fc6aa12e2f
4 changed files with 207 additions and 4 deletions
  1. +6
    -1
      dnn/src/fallback/type_cvt/opr_impl.cpp
  2. +101
    -1
      python_module/megengine/_internal/dtype.py
  3. +43
    -2
      python_module/src/cpp/python_helper.cpp
  4. +57
    -0
      src/core/impl/dtype.cpp

+ 6
- 1
dnn/src/fallback/type_cvt/opr_impl.cpp View File

@@ -451,7 +451,12 @@ namespace fallback {


void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) {
check_exec(src.layout, dst.layout); check_exec(src.layout, dst.layout);
if (src.layout.is_contiguous() && dst.layout.is_contiguous()) {
auto is_quantize_lowbit = [](const DType& dt) {
return dt.category() == DTypeCategory::QUANTIZED && dt.is_low_bit();
};
if (src.layout.is_contiguous() && dst.layout.is_contiguous() &&
!is_quantize_lowbit(src.layout.dtype) &&
!is_quantize_lowbit(dst.layout.dtype)) {
MEGDNN_DISPATCH_CPU_KERN_OPR(run_contiguous(src, dst)); MEGDNN_DISPATCH_CPU_KERN_OPR(run_contiguous(src, dst));
} else { } else {
naive::TypeCvtImpl::exec(src, dst); naive::TypeCvtImpl::exec(src, dst);


+ 101
- 1
python_module/megengine/_internal/dtype.py View File

@@ -32,7 +32,7 @@ def get_scale(dtype):
def get_zero_point(dtype): def get_zero_point(dtype):
assert is_quantize(dtype) assert is_quantize(dtype)
metadata = dtype.metadata["mgb_dtype"] metadata = dtype.metadata["mgb_dtype"]
assert metadata["name"] == "Quantized8Asymm"
assert metadata["name"] in ("Quantized8Asymm", "Quantized4Asymm")
return metadata["zero_point"] return metadata["zero_point"]




@@ -79,6 +79,38 @@ def qint32(scale):
) )




def quint4(scale, zero_point):
"""
Consturct a quantized unsigned int4 data type with ``scale`` (float) and
``zero_point`` (uint8). The real value represented by a quint4 data type is
float_val = scale * (uint4_val - zero_point)
"""
int_zp = int(zero_point)
assert int_zp == zero_point, "zero_point should be an integer"
if int_zp < 0 or int_zp > 15:
raise ValueError("zero_point should be within [0, 15] for quint4")
return np.dtype(
np.uint8,
metadata={
"mgb_dtype": {
"name": "Quantized4Asymm",
"scale": float(scale),
"zero_point": int(zero_point),
}
},
)


def qint4(scale):
"""
Construct a quantized int4 data type with ``scale`` (float). The real value
represented by a qint4 data type is float_val = scale * int4_val
"""
return np.dtype(
np.int8, metadata={"mgb_dtype": {"name": "QuantizedS4", "scale": float(scale)}}
)


def convert_to_quint8(arr, q): def convert_to_quint8(arr, q):
""" """
Quantize a float NumPy ndarray into a quint8 one with specified params. Quantize a float NumPy ndarray into a quint8 one with specified params.
@@ -177,3 +209,71 @@ def convert_from_qint32(arr):
), "arr should be a ndarray with qint8 dtype" ), "arr should be a ndarray with qint8 dtype"
scale = arr.dtype.metadata["mgb_dtype"]["scale"] scale = arr.dtype.metadata["mgb_dtype"]["scale"]
return arr.astype(np.float32) * scale return arr.astype(np.float32) * scale


def convert_to_quint4(arr, q):
"""
Quantize a float NumPy ndarray into a quint4 one with specified params.

:param arr: Input ndarray.
:type arr: :class:`np.ndarray`
:param q: Target data type, should be a quint4.
:type q: :class:`np.dtype`
"""
assert isinstance(arr, np.ndarray)
assert (
"mgb_dtype" in q.metadata
and q.metadata["mgb_dtype"]["name"] == "Quantized4Asymm"
), "q should be a quint4 dtype"
scale, zp = q.metadata["mgb_dtype"]["scale"], q.metadata["mgb_dtype"]["zero_point"]
return (np.round(arr / scale) + zp).clip(0, 15).astype(q)


def convert_from_quint4(arr):
"""
Dequantize a quint4 NumPy ndarray into a float one.

:param arr: Input ndarray.
"""
assert isinstance(arr, np.ndarray)
assert (
"mgb_dtype" in arr.dtype.metadata
and arr.dtype.metadata["mgb_dtype"]["name"] == "Quantized4Asymm"
), "arr should be a ndarray with quint4 dtype"
scale, zp = (
arr.dtype.metadata["mgb_dtype"]["scale"],
arr.dtype.metadata["mgb_dtype"]["zero_point"],
)
return (arr.astype(np.float32) - zp) * scale


def convert_to_qint4(arr, q):
"""
Quantize a float NumPy ndarray into a qint4 one with specified params.

:param arr: Input ndarray.
:type arr: :class:`np.ndarray`
:param q: Target data type, should be a qint4.
:type q: :class:`np.dtype`
"""
assert isinstance(arr, np.ndarray)
assert (
"mgb_dtype" in q.metadata and q.metadata["mgb_dtype"]["name"] == "QuantizedS4"
), "q should be a qint4 dtype"
scale = q.metadata["mgb_dtype"]["scale"]
return (np.round(arr / scale)).clip(-8, 7).astype(q)


def convert_from_qint4(arr):
"""
Dequantize a qint4 NumPy ndarray into a float one.

:param arr: Input ndarray.
"""
assert isinstance(arr, np.ndarray)
assert (
"mgb_dtype" in arr.dtype.metadata
and arr.dtype.metadata["mgb_dtype"]["name"] == "QuantizedS4"
), "arr should be a ndarray with qint4 dtype"
scale = arr.dtype.metadata["mgb_dtype"]["scale"]
return arr.astype(np.float32) * scale

+ 43
- 2
python_module/src/cpp/python_helper.cpp View File

@@ -452,6 +452,23 @@ std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr(
{{"scale", PyFloat_FromDouble(param.scale)}}); {{"scale", PyFloat_FromDouble(param.scale)}});
break; break;
} }
case DTypeEnum::Quantized4Asymm: {
auto& param = dtype.param<dtype::Quantized4Asymm>();
type_descr = PyArray_DescrNewFromType(NPY_UINT8);
type_descr->metadata = build_mgb_dtype_dict(
DTypeTrait<dtype::Quantized4Asymm>::name,
{{"scale", PyFloat_FromDouble(param.scale)},
{"zero_point", PyLong_FromLong(param.zero_point)}});
break;
}
case DTypeEnum::QuantizedS4: {
auto& param = dtype.param<dtype::QuantizedS4>();
type_descr = PyArray_DescrNewFromType(NPY_INT8);
type_descr->metadata = build_mgb_dtype_dict(
DTypeTrait<dtype::QuantizedS4>::name,
{{"scale", PyFloat_FromDouble(param.scale)}});
break;
}
case DTypeEnum::QuantizedS32: { case DTypeEnum::QuantizedS32: {
auto& param = dtype.param<dtype::QuantizedS32>(); auto& param = dtype.param<dtype::QuantizedS32>();
type_descr = PyArray_DescrNewFromType(NPY_INT32); type_descr = PyArray_DescrNewFromType(NPY_INT32);
@@ -529,7 +546,29 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) {
static_cast<float>(PyFloat_AS_DOUBLE(scale_py)), static_cast<float>(PyFloat_AS_DOUBLE(scale_py)),
static_cast<uint8_t>(zero_point)); static_cast<uint8_t>(zero_point));
} }
if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8") {
if (dtype_name == "Quantized4Asymm") {
PyObject* scale_py = PyDict_GetItemString(metadata, "scale");
PyObject* zero_point_py =
PyDict_GetItemString(metadata, "zero_point");
mgb_assert(scale_py && zero_point_py,
"Invalid Quantized4Asymm metadata: missing scale or "
"zero_point.");
mgb_assert(
PyFloat_Check(scale_py),
"Invalid Quantized4Asymm metadata: scale should be float");
mgb_assert(PyLong_Check(zero_point_py),
"Invalid Quantized4Asymm metadata: zero_point should be "
"integer");
auto zero_point = PyLong_AS_LONG(zero_point_py);
mgb_assert(zero_point >= 0 && zero_point < 15,
"Invalid Quantized4Asymm metadata: zero_point should be "
"in [0, 15)");
return dtype::Quantized4Asymm(
static_cast<float>(PyFloat_AS_DOUBLE(scale_py)),
static_cast<uint8_t>(zero_point));
}
if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8" ||
dtype_name == "QuantizedS4") {
PyObject* scale_py = PyDict_GetItemString(metadata, "scale"); PyObject* scale_py = PyDict_GetItemString(metadata, "scale");
mgb_assert(scale_py, "Invalid metadata: missing scale"); mgb_assert(scale_py, "Invalid metadata: missing scale");
mgb_assert(PyFloat_Check(scale_py), mgb_assert(PyFloat_Check(scale_py),
@@ -537,8 +576,10 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) {
float scale = static_cast<float>(PyFloat_AS_DOUBLE(scale_py)); float scale = static_cast<float>(PyFloat_AS_DOUBLE(scale_py));
if (dtype_name == "QuantizedS32") { if (dtype_name == "QuantizedS32") {
return dtype::QuantizedS32(scale); return dtype::QuantizedS32(scale);
} else {
} else if (dtype_name == "QuantizedS8"){
return dtype::QuantizedS8(scale); return dtype::QuantizedS8(scale);
} else {
return dtype::QuantizedS4(scale);
} }
} }
throw ConversionError( throw ConversionError(


+ 57
- 0
src/core/impl/dtype.cpp View File

@@ -14,6 +14,7 @@
#include "megbrain/exception.h" #include "megbrain/exception.h"
#include "megbrain/utils/metahelper.h" #include "megbrain/utils/metahelper.h"
#include "megbrain/utils/arith_helper.h" #include "megbrain/utils/arith_helper.h"
#include "megdnn/dtype.h"


#include <cmath> #include <cmath>
#include <cstring> #include <cstring>
@@ -357,6 +358,52 @@ struct LowbitMemcpy<bits, true> {
} }
} }
}; };

template<typename DT>
struct QuantizedLowbitTrait;

template<>
struct QuantizedLowbitTrait<dtype::Quantized4Asymm> {
static constexpr int8_t SHIFT = 0;
};

template<>
struct QuantizedLowbitTrait<dtype::QuantizedS4> {
static constexpr int8_t SHIFT = 8;
};

template <typename DT, bool div_byte = (DTypeTrait<DT>::category ==
DTypeCategory::QUANTIZED) &&
(8 % DTypeTrait<DT>::low_bit == 0)>
struct QuantizedLowbitMemcpy;

template <typename DT>
struct QuantizedLowbitMemcpy<DT, true> {
// cast with bits that 8 % bits == 0
static constexpr uint16_t bits = DTypeTrait<DT>::low_bit;
static constexpr uint8_t MASK = (1 << bits) - 1;
using Trait = QuantizedLowbitTrait<DT>;

static void byte2compact(void* dest_raw, const void* src_raw, size_t n) {
auto dest = static_cast<uint8_t*>(dest_raw);
auto src = static_cast<const int8_t*>(src_raw);
memset(dest, 0, divup<size_t>(n * bits, 8));
for (size_t i = 0; i < n; ++i) {
int8_t val = src[i] + Trait::SHIFT;
mgb_assert(val >= 0 && val < (1 << bits));
dest[i * bits / 8] |= val << (i * bits % 8);
}
}
static void compact2byte(void* dest_raw, const void* src_raw, size_t n) {
auto dest = static_cast<int8_t*>(dest_raw);
auto src = static_cast<const uint8_t*>(src_raw);
for (size_t i = 0; i < n; ++i) {
int8_t val = ((src[i * bits / 8] >> (i * bits % 8)) & MASK);
dest[i] = val - Trait::SHIFT;
}
}
};

} // anonymous namespace } // anonymous namespace


void mgb::lowbit_memcpy_byte2compact( void mgb::lowbit_memcpy_byte2compact(
@@ -366,6 +413,11 @@ void mgb::lowbit_memcpy_byte2compact(
return LowbitMemcpy<bits>::byte2compact(dest, src, n); return LowbitMemcpy<bits>::byte2compact(dest, src, n);
MEGDNN_FOREACH_LOWBIT_DTYPE(cb) MEGDNN_FOREACH_LOWBIT_DTYPE(cb)
#undef cb #undef cb
#define cb(dt) \
if (dtype.enumv() == DTypeTrait<dt>::enumv) \
return QuantizedLowbitMemcpy<dt>::byte2compact(dest, src, n);
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
#undef cb
mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name()); mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name());
} }


@@ -376,6 +428,11 @@ void mgb::lowbit_memcpy_compact2byte(
return LowbitMemcpy<bits>::compact2byte(dest, src, n); return LowbitMemcpy<bits>::compact2byte(dest, src, n);
MEGDNN_FOREACH_LOWBIT_DTYPE(cb) MEGDNN_FOREACH_LOWBIT_DTYPE(cb)
#undef cb #undef cb
#define cb(dt) \
if (dtype.enumv() == DTypeTrait<dt>::enumv) \
return QuantizedLowbitMemcpy<dt>::compact2byte(dest, src, n);
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
#undef cb
mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name()); mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name());
} }




Loading…
Cancel
Save