@@ -451,7 +451,12 @@ namespace fallback { | |||
void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | |||
check_exec(src.layout, dst.layout); | |||
if (src.layout.is_contiguous() && dst.layout.is_contiguous()) { | |||
auto is_quantize_lowbit = [](const DType& dt) { | |||
return dt.category() == DTypeCategory::QUANTIZED && dt.is_low_bit(); | |||
}; | |||
if (src.layout.is_contiguous() && dst.layout.is_contiguous() && | |||
!is_quantize_lowbit(src.layout.dtype) && | |||
!is_quantize_lowbit(dst.layout.dtype)) { | |||
MEGDNN_DISPATCH_CPU_KERN_OPR(run_contiguous(src, dst)); | |||
} else { | |||
naive::TypeCvtImpl::exec(src, dst); | |||
@@ -32,7 +32,7 @@ def get_scale(dtype): | |||
def get_zero_point(dtype):
    """
    Return the ``zero_point`` stored in the ``mgb_dtype`` metadata of an
    asymmetric quantized data type.

    Both ``Quantized8Asymm`` and ``Quantized4Asymm`` are accepted; any other
    quantized dtype name fails the assertion.
    """
    assert is_quantize(dtype)
    metadata = dtype.metadata["mgb_dtype"]
    # Only the asymmetric quantized dtypes carry a zero point.
    assert metadata["name"] in ("Quantized8Asymm", "Quantized4Asymm")
    return metadata["zero_point"]
@@ -79,6 +79,38 @@ def qint32(scale): | |||
) | |||
def quint4(scale, zero_point):
    """
    Construct a quantized unsigned int4 data type with ``scale`` (float) and
    ``zero_point`` (uint8). The real value represented by a quint4 data type is
    float_val = scale * (uint4_val - zero_point)

    Raises ``ValueError`` when ``zero_point`` is outside ``[0, 15]``.
    """
    int_zp = int(zero_point)
    assert int_zp == zero_point, "zero_point should be an integer"
    if int_zp < 0 or int_zp > 15:
        raise ValueError("zero_point should be within [0, 15] for quint4")
    return np.dtype(
        np.uint8,
        metadata={
            "mgb_dtype": {
                "name": "Quantized4Asymm",
                "scale": float(scale),
                # Store the value that was just validated rather than
                # converting ``zero_point`` a second time.
                "zero_point": int_zp,
            }
        },
    )
def qint4(scale):
    """
    Build a quantized signed int4 data type from ``scale`` (float). A stored
    int4 value maps to the real value float_val = scale * int4_val
    """
    mgb_dtype = {"name": "QuantizedS4", "scale": float(scale)}
    return np.dtype(np.int8, metadata={"mgb_dtype": mgb_dtype})
def convert_to_quint8(arr, q): | |||
""" | |||
Quantize a float NumPy ndarray into a quint8 one with specified params. | |||
@@ -177,3 +209,71 @@ def convert_from_qint32(arr): | |||
), "arr should be a ndarray with qint8 dtype" | |||
scale = arr.dtype.metadata["mgb_dtype"]["scale"] | |||
return arr.astype(np.float32) * scale | |||
def convert_to_quint4(arr, q):
    """
    Quantize a float NumPy ndarray into a quint4 one with specified params.

    Each element is computed as ``round(arr / scale) + zero_point`` and then
    clipped to ``[0, 15]`` before being cast to ``q``.

    :param arr: Input ndarray.
    :type arr: :class:`np.ndarray`
    :param q: Target data type, should be a quint4.
    :type q: :class:`np.dtype`
    """
    assert isinstance(arr, np.ndarray)
    # ``metadata`` is None on a plain dtype; read it defensively so the
    # assertion below reports the misuse instead of raising a TypeError.
    mgb_dtype = (getattr(q, "metadata", None) or {}).get("mgb_dtype")
    assert (
        mgb_dtype is not None and mgb_dtype["name"] == "Quantized4Asymm"
    ), "q should be a quint4 dtype"
    scale, zp = mgb_dtype["scale"], mgb_dtype["zero_point"]
    return (np.round(arr / scale) + zp).clip(0, 15).astype(q)
def convert_from_quint4(arr):
    """
    Dequantize a quint4 NumPy ndarray into a float one.

    Each element is computed as ``(value - zero_point) * scale`` using the
    parameters stored in the array's dtype metadata.

    :param arr: Input ndarray.
    """
    assert isinstance(arr, np.ndarray)
    # ``metadata`` is None on a plain dtype; read it defensively so the
    # assertion below reports the misuse instead of raising a TypeError.
    mgb_dtype = (arr.dtype.metadata or {}).get("mgb_dtype")
    assert (
        mgb_dtype is not None and mgb_dtype["name"] == "Quantized4Asymm"
    ), "arr should be a ndarray with quint4 dtype"
    scale, zp = mgb_dtype["scale"], mgb_dtype["zero_point"]
    return (arr.astype(np.float32) - zp) * scale
def convert_to_qint4(arr, q):
    """
    Quantize a float NumPy ndarray into a qint4 one with specified params.

    Each element is computed as ``round(arr / scale)`` and then clipped to
    ``[-8, 7]`` before being cast to ``q``.

    :param arr: Input ndarray.
    :type arr: :class:`np.ndarray`
    :param q: Target data type, should be a qint4.
    :type q: :class:`np.dtype`
    """
    assert isinstance(arr, np.ndarray)
    # ``metadata`` is None on a plain dtype; read it defensively so the
    # assertion below reports the misuse instead of raising a TypeError.
    mgb_dtype = (getattr(q, "metadata", None) or {}).get("mgb_dtype")
    assert (
        mgb_dtype is not None and mgb_dtype["name"] == "QuantizedS4"
    ), "q should be a qint4 dtype"
    scale = mgb_dtype["scale"]
    return (np.round(arr / scale)).clip(-8, 7).astype(q)
def convert_from_qint4(arr):
    """
    Dequantize a qint4 NumPy ndarray into a float one.

    Each element is multiplied by the ``scale`` stored in the array's dtype
    metadata.

    :param arr: Input ndarray.
    """
    assert isinstance(arr, np.ndarray)
    # ``metadata`` is None on a plain dtype; read it defensively so the
    # assertion below reports the misuse instead of raising a TypeError.
    mgb_dtype = (arr.dtype.metadata or {}).get("mgb_dtype")
    assert (
        mgb_dtype is not None and mgb_dtype["name"] == "QuantizedS4"
    ), "arr should be a ndarray with qint4 dtype"
    scale = mgb_dtype["scale"]
    return arr.astype(np.float32) * scale
@@ -452,6 +452,23 @@ std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr( | |||
{{"scale", PyFloat_FromDouble(param.scale)}}); | |||
break; | |||
} | |||
case DTypeEnum::Quantized4Asymm: { | |||
auto& param = dtype.param<dtype::Quantized4Asymm>(); | |||
type_descr = PyArray_DescrNewFromType(NPY_UINT8); | |||
type_descr->metadata = build_mgb_dtype_dict( | |||
DTypeTrait<dtype::Quantized4Asymm>::name, | |||
{{"scale", PyFloat_FromDouble(param.scale)}, | |||
{"zero_point", PyLong_FromLong(param.zero_point)}}); | |||
break; | |||
} | |||
case DTypeEnum::QuantizedS4: { | |||
auto& param = dtype.param<dtype::QuantizedS4>(); | |||
type_descr = PyArray_DescrNewFromType(NPY_INT8); | |||
type_descr->metadata = build_mgb_dtype_dict( | |||
DTypeTrait<dtype::QuantizedS4>::name, | |||
{{"scale", PyFloat_FromDouble(param.scale)}}); | |||
break; | |||
} | |||
case DTypeEnum::QuantizedS32: { | |||
auto& param = dtype.param<dtype::QuantizedS32>(); | |||
type_descr = PyArray_DescrNewFromType(NPY_INT32); | |||
@@ -529,7 +546,29 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) { | |||
static_cast<float>(PyFloat_AS_DOUBLE(scale_py)), | |||
static_cast<uint8_t>(zero_point)); | |||
} | |||
if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8") { | |||
if (dtype_name == "Quantized4Asymm") {
    // Rebuild a Quantized4Asymm dtype from the "scale" and "zero_point"
    // entries of the numpy descr metadata dict.
    PyObject* scale_py = PyDict_GetItemString(metadata, "scale");
    PyObject* zero_point_py =
            PyDict_GetItemString(metadata, "zero_point");
    mgb_assert(scale_py && zero_point_py,
               "Invalid Quantized4Asymm metadata: missing scale or "
               "zero_point.");
    mgb_assert(
            PyFloat_Check(scale_py),
            "Invalid Quantized4Asymm metadata: scale should be float");
    mgb_assert(PyLong_Check(zero_point_py),
               "Invalid Quantized4Asymm metadata: zero_point should be "
               "integer");
    auto zero_point = PyLong_AS_LONG(zero_point_py);
    // The bound is inclusive: the Python-side quint4() constructor accepts
    // zero_point == 15, so it must be accepted here as well.
    mgb_assert(zero_point >= 0 && zero_point <= 15,
               "Invalid Quantized4Asymm metadata: zero_point should be "
               "in [0, 15]");
    return dtype::Quantized4Asymm(
            static_cast<float>(PyFloat_AS_DOUBLE(scale_py)),
            static_cast<uint8_t>(zero_point));
}
if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8" || | |||
dtype_name == "QuantizedS4") { | |||
PyObject* scale_py = PyDict_GetItemString(metadata, "scale"); | |||
mgb_assert(scale_py, "Invalid metadata: missing scale"); | |||
mgb_assert(PyFloat_Check(scale_py), | |||
@@ -537,8 +576,10 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) { | |||
float scale = static_cast<float>(PyFloat_AS_DOUBLE(scale_py)); | |||
if (dtype_name == "QuantizedS32") { | |||
return dtype::QuantizedS32(scale); | |||
} else { | |||
} else if (dtype_name == "QuantizedS8"){ | |||
return dtype::QuantizedS8(scale); | |||
} else { | |||
return dtype::QuantizedS4(scale); | |||
} | |||
} | |||
throw ConversionError( | |||
@@ -14,6 +14,7 @@ | |||
#include "megbrain/exception.h" | |||
#include "megbrain/utils/metahelper.h" | |||
#include "megbrain/utils/arith_helper.h" | |||
#include "megdnn/dtype.h" | |||
#include <cmath> | |||
#include <cstring> | |||
@@ -357,6 +358,52 @@ struct LowbitMemcpy<bits, true> { | |||
} | |||
} | |||
}; | |||
template<typename DT> | |||
struct QuantizedLowbitTrait; | |||
template<> | |||
struct QuantizedLowbitTrait<dtype::Quantized4Asymm> { | |||
static constexpr int8_t SHIFT = 0; | |||
}; | |||
template<> | |||
struct QuantizedLowbitTrait<dtype::QuantizedS4> { | |||
static constexpr int8_t SHIFT = 8; | |||
}; | |||
template <typename DT, bool div_byte = (DTypeTrait<DT>::category == | |||
DTypeCategory::QUANTIZED) && | |||
(8 % DTypeTrait<DT>::low_bit == 0)> | |||
struct QuantizedLowbitMemcpy; | |||
template <typename DT> | |||
struct QuantizedLowbitMemcpy<DT, true> { | |||
// cast with bits that 8 % bits == 0 | |||
static constexpr uint16_t bits = DTypeTrait<DT>::low_bit; | |||
static constexpr uint8_t MASK = (1 << bits) - 1; | |||
using Trait = QuantizedLowbitTrait<DT>; | |||
static void byte2compact(void* dest_raw, const void* src_raw, size_t n) { | |||
auto dest = static_cast<uint8_t*>(dest_raw); | |||
auto src = static_cast<const int8_t*>(src_raw); | |||
memset(dest, 0, divup<size_t>(n * bits, 8)); | |||
for (size_t i = 0; i < n; ++i) { | |||
int8_t val = src[i] + Trait::SHIFT; | |||
mgb_assert(val >= 0 && val < (1 << bits)); | |||
dest[i * bits / 8] |= val << (i * bits % 8); | |||
} | |||
} | |||
static void compact2byte(void* dest_raw, const void* src_raw, size_t n) { | |||
auto dest = static_cast<int8_t*>(dest_raw); | |||
auto src = static_cast<const uint8_t*>(src_raw); | |||
for (size_t i = 0; i < n; ++i) { | |||
int8_t val = ((src[i * bits / 8] >> (i * bits % 8)) & MASK); | |||
dest[i] = val - Trait::SHIFT; | |||
} | |||
} | |||
}; | |||
} // anonymous namespace | |||
void mgb::lowbit_memcpy_byte2compact( | |||
@@ -366,6 +413,11 @@ void mgb::lowbit_memcpy_byte2compact( | |||
return LowbitMemcpy<bits>::byte2compact(dest, src, n); | |||
MEGDNN_FOREACH_LOWBIT_DTYPE(cb) | |||
#undef cb | |||
#define cb(dt) \ | |||
if (dtype.enumv() == DTypeTrait<dt>::enumv) \ | |||
return QuantizedLowbitMemcpy<dt>::byte2compact(dest, src, n); | |||
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||
#undef cb | |||
mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name()); | |||
} | |||
@@ -376,6 +428,11 @@ void mgb::lowbit_memcpy_compact2byte( | |||
return LowbitMemcpy<bits>::compact2byte(dest, src, n); | |||
MEGDNN_FOREACH_LOWBIT_DTYPE(cb) | |||
#undef cb | |||
#define cb(dt) \ | |||
if (dtype.enumv() == DTypeTrait<dt>::enumv) \ | |||
return QuantizedLowbitMemcpy<dt>::compact2byte(dest, src, n); | |||
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||
#undef cb | |||
mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name()); | |||
} | |||