@@ -451,7 +451,12 @@ namespace fallback { | |||||
void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | void TypeCvtImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst) { | ||||
check_exec(src.layout, dst.layout); | check_exec(src.layout, dst.layout); | ||||
if (src.layout.is_contiguous() && dst.layout.is_contiguous()) { | |||||
auto is_quantize_lowbit = [](const DType& dt) { | |||||
return dt.category() == DTypeCategory::QUANTIZED && dt.is_low_bit(); | |||||
}; | |||||
if (src.layout.is_contiguous() && dst.layout.is_contiguous() && | |||||
!is_quantize_lowbit(src.layout.dtype) && | |||||
!is_quantize_lowbit(dst.layout.dtype)) { | |||||
MEGDNN_DISPATCH_CPU_KERN_OPR(run_contiguous(src, dst)); | MEGDNN_DISPATCH_CPU_KERN_OPR(run_contiguous(src, dst)); | ||||
} else { | } else { | ||||
naive::TypeCvtImpl::exec(src, dst); | naive::TypeCvtImpl::exec(src, dst); | ||||
@@ -32,7 +32,7 @@ def get_scale(dtype): | |||||
def get_zero_point(dtype):
    """
    Return the ``zero_point`` stored in the ``mgb_dtype`` metadata of a
    quantized dtype.

    The dtype must pass ``is_quantize`` and its metadata name must be one of
    ``Quantized8Asymm`` or ``Quantized4Asymm``; otherwise an assertion fails.

    :param dtype: Quantized data type to inspect.
    :type dtype: :class:`np.dtype`
    :return: The ``zero_point`` entry of the ``mgb_dtype`` metadata.
    """
    assert is_quantize(dtype)
    metadata = dtype.metadata["mgb_dtype"]
    # Accept both asymmetric names: checking for "Quantized8Asymm" alone would
    # reject quint4 dtypes before their zero point could be read.
    assert metadata["name"] in ("Quantized8Asymm", "Quantized4Asymm")
    return metadata["zero_point"]
@@ -79,6 +79,38 @@ def qint32(scale): | |||||
) | ) | ||||
def quint4(scale, zero_point):
    """
    Construct a quantized unsigned int4 data type with ``scale`` (float) and
    ``zero_point`` (integer in [0, 15]). The real value represented by a
    quint4 data type is float_val = scale * (uint4_val - zero_point)

    The returned dtype is a ``np.uint8`` dtype whose ``"mgb_dtype"`` metadata
    entry records the name ``"Quantized4Asymm"``, the scale and the zero point.

    :param scale: Quantization scale; stored as ``float(scale)``.
    :param zero_point: Integer-valued zero point within [0, 15].
    :raises ValueError: If ``zero_point`` is outside [0, 15].
    """
    int_zp = int(zero_point)
    assert int_zp == zero_point, "zero_point should be an integer"
    if int_zp < 0 or int_zp > 15:
        raise ValueError("zero_point should be within [0, 15] for quint4")
    return np.dtype(
        np.uint8,
        metadata={
            "mgb_dtype": {
                "name": "Quantized4Asymm",
                "scale": float(scale),
                # Reuse the already validated integer instead of converting again.
                "zero_point": int_zp,
            }
        },
    )
def qint4(scale):
    """
    Build the quantized signed int4 data type for a given ``scale`` (float).
    A stored value ``int4_val`` stands for the real number
    float_val = scale * int4_val

    The result is a ``np.int8`` dtype carrying the name ``"QuantizedS4"`` and
    the scale in its ``"mgb_dtype"`` metadata entry.
    """
    mgb_dtype = {"name": "QuantizedS4", "scale": float(scale)}
    return np.dtype(np.int8, metadata={"mgb_dtype": mgb_dtype})
def convert_to_quint8(arr, q): | def convert_to_quint8(arr, q): | ||||
""" | """ | ||||
Quantize a float NumPy ndarray into a quint8 one with specified params. | Quantize a float NumPy ndarray into a quint8 one with specified params. | ||||
@@ -177,3 +209,71 @@ def convert_from_qint32(arr): | |||||
), "arr should be a ndarray with qint8 dtype" | ), "arr should be a ndarray with qint8 dtype" | ||||
scale = arr.dtype.metadata["mgb_dtype"]["scale"] | scale = arr.dtype.metadata["mgb_dtype"]["scale"] | ||||
return arr.astype(np.float32) * scale | return arr.astype(np.float32) * scale | ||||
def convert_to_quint4(arr, q):
    """
    Quantize a float NumPy ndarray into a quint4 one with specified params.

    Each element is computed as ``round(arr / scale) + zero_point`` and then
    clipped to [0, 15] before being cast to ``q``.

    :param arr: Input ndarray.
    :type arr: :class:`np.ndarray`
    :param q: Target data type, should be a quint4.
    :type q: :class:`np.dtype`
    """
    assert isinstance(arr, np.ndarray)
    # ``np.dtype.metadata`` is None for dtypes created without metadata; treat
    # that as "not a quint4 dtype" so the assertion message is reported
    # instead of a TypeError from the membership test.
    mgb_dtype = (q.metadata or {}).get("mgb_dtype")
    assert (
        mgb_dtype is not None and mgb_dtype["name"] == "Quantized4Asymm"
    ), "q should be a quint4 dtype"
    scale, zp = mgb_dtype["scale"], mgb_dtype["zero_point"]
    return (np.round(arr / scale) + zp).clip(0, 15).astype(q)
def convert_from_quint4(arr):
    """
    Dequantize a quint4 NumPy ndarray into a float one.

    Each element is computed as ``(value - zero_point) * scale`` after casting
    the input to ``np.float32``.

    :param arr: Input ndarray.
    """
    assert isinstance(arr, np.ndarray)
    # ``arr.dtype.metadata`` is None for plain arrays; treat that as "not a
    # quint4 array" so the assertion message is reported instead of a
    # TypeError from the membership test.
    mgb_dtype = (arr.dtype.metadata or {}).get("mgb_dtype")
    assert (
        mgb_dtype is not None and mgb_dtype["name"] == "Quantized4Asymm"
    ), "arr should be a ndarray with quint4 dtype"
    scale, zp = mgb_dtype["scale"], mgb_dtype["zero_point"]
    return (arr.astype(np.float32) - zp) * scale
def convert_to_qint4(arr, q):
    """
    Quantize a float NumPy ndarray into a qint4 one with specified params.

    Each element is computed as ``round(arr / scale)`` and then clipped to
    [-8, 7] before being cast to ``q``.

    :param arr: Input ndarray.
    :type arr: :class:`np.ndarray`
    :param q: Target data type, should be a qint4.
    :type q: :class:`np.dtype`
    """
    assert isinstance(arr, np.ndarray)
    # ``np.dtype.metadata`` is None for dtypes created without metadata; treat
    # that as "not a qint4 dtype" so the assertion message is reported
    # instead of a TypeError from the membership test.
    mgb_dtype = (q.metadata or {}).get("mgb_dtype")
    assert (
        mgb_dtype is not None and mgb_dtype["name"] == "QuantizedS4"
    ), "q should be a qint4 dtype"
    scale = mgb_dtype["scale"]
    return (np.round(arr / scale)).clip(-8, 7).astype(q)
def convert_from_qint4(arr):
    """
    Dequantize a qint4 NumPy ndarray into a float one.

    Each element is computed as ``value * scale`` after casting the input to
    ``np.float32``.

    :param arr: Input ndarray.
    """
    assert isinstance(arr, np.ndarray)
    # ``arr.dtype.metadata`` is None for plain arrays; treat that as "not a
    # qint4 array" so the assertion message is reported instead of a
    # TypeError from the membership test.
    mgb_dtype = (arr.dtype.metadata or {}).get("mgb_dtype")
    assert (
        mgb_dtype is not None and mgb_dtype["name"] == "QuantizedS4"
    ), "arr should be a ndarray with qint4 dtype"
    scale = mgb_dtype["scale"]
    return arr.astype(np.float32) * scale
@@ -452,6 +452,23 @@ std::unique_ptr<PyArray_Descr, PyArrayDescrDeleter> dtype_mgb2np_descr( | |||||
{{"scale", PyFloat_FromDouble(param.scale)}}); | {{"scale", PyFloat_FromDouble(param.scale)}}); | ||||
break; | break; | ||||
} | } | ||||
case DTypeEnum::Quantized4Asymm: { | |||||
auto& param = dtype.param<dtype::Quantized4Asymm>(); | |||||
type_descr = PyArray_DescrNewFromType(NPY_UINT8); | |||||
type_descr->metadata = build_mgb_dtype_dict( | |||||
DTypeTrait<dtype::Quantized4Asymm>::name, | |||||
{{"scale", PyFloat_FromDouble(param.scale)}, | |||||
{"zero_point", PyLong_FromLong(param.zero_point)}}); | |||||
break; | |||||
} | |||||
case DTypeEnum::QuantizedS4: { | |||||
auto& param = dtype.param<dtype::QuantizedS4>(); | |||||
type_descr = PyArray_DescrNewFromType(NPY_INT8); | |||||
type_descr->metadata = build_mgb_dtype_dict( | |||||
DTypeTrait<dtype::QuantizedS4>::name, | |||||
{{"scale", PyFloat_FromDouble(param.scale)}}); | |||||
break; | |||||
} | |||||
case DTypeEnum::QuantizedS32: { | case DTypeEnum::QuantizedS32: { | ||||
auto& param = dtype.param<dtype::QuantizedS32>(); | auto& param = dtype.param<dtype::QuantizedS32>(); | ||||
type_descr = PyArray_DescrNewFromType(NPY_INT32); | type_descr = PyArray_DescrNewFromType(NPY_INT32); | ||||
@@ -529,7 +546,29 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) { | |||||
static_cast<float>(PyFloat_AS_DOUBLE(scale_py)), | static_cast<float>(PyFloat_AS_DOUBLE(scale_py)), | ||||
static_cast<uint8_t>(zero_point)); | static_cast<uint8_t>(zero_point)); | ||||
} | } | ||||
if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8") { | |||||
if (dtype_name == "Quantized4Asymm") {
    // Rebuild a Quantized4Asymm dtype from the "scale" and "zero_point"
    // entries of the metadata dict.
    PyObject* scale_py = PyDict_GetItemString(metadata, "scale");
    PyObject* zero_point_py = PyDict_GetItemString(metadata, "zero_point");
    mgb_assert(scale_py && zero_point_py,
               "Invalid Quantized4Asymm metadata: missing scale or "
               "zero_point.");
    mgb_assert(PyFloat_Check(scale_py),
               "Invalid Quantized4Asymm metadata: scale should be float");
    mgb_assert(PyLong_Check(zero_point_py),
               "Invalid Quantized4Asymm metadata: zero_point should be "
               "integer");
    auto zero_point = PyLong_AS_LONG(zero_point_py);
    // A 4-bit unsigned zero point spans 0..15 inclusive. The Python-side
    // quint4 helper accepts 15, so an exclusive upper bound here would reject
    // dtypes that helper legitimately produces.
    mgb_assert(zero_point >= 0 && zero_point <= 15,
               "Invalid Quantized4Asymm metadata: zero_point should be "
               "in [0, 15]");
    return dtype::Quantized4Asymm(
            static_cast<float>(PyFloat_AS_DOUBLE(scale_py)),
            static_cast<uint8_t>(zero_point));
}
if (dtype_name == "QuantizedS32" || dtype_name == "QuantizedS8" || | |||||
dtype_name == "QuantizedS4") { | |||||
PyObject* scale_py = PyDict_GetItemString(metadata, "scale"); | PyObject* scale_py = PyDict_GetItemString(metadata, "scale"); | ||||
mgb_assert(scale_py, "Invalid metadata: missing scale"); | mgb_assert(scale_py, "Invalid metadata: missing scale"); | ||||
mgb_assert(PyFloat_Check(scale_py), | mgb_assert(PyFloat_Check(scale_py), | ||||
@@ -537,8 +576,10 @@ DType dtype_np2mgb_descr(PyArray_Descr* descr) { | |||||
float scale = static_cast<float>(PyFloat_AS_DOUBLE(scale_py)); | float scale = static_cast<float>(PyFloat_AS_DOUBLE(scale_py)); | ||||
if (dtype_name == "QuantizedS32") { | if (dtype_name == "QuantizedS32") { | ||||
return dtype::QuantizedS32(scale); | return dtype::QuantizedS32(scale); | ||||
} else { | |||||
} else if (dtype_name == "QuantizedS8"){ | |||||
return dtype::QuantizedS8(scale); | return dtype::QuantizedS8(scale); | ||||
} else { | |||||
return dtype::QuantizedS4(scale); | |||||
} | } | ||||
} | } | ||||
throw ConversionError( | throw ConversionError( | ||||
@@ -14,6 +14,7 @@ | |||||
#include "megbrain/exception.h" | #include "megbrain/exception.h" | ||||
#include "megbrain/utils/metahelper.h" | #include "megbrain/utils/metahelper.h" | ||||
#include "megbrain/utils/arith_helper.h" | #include "megbrain/utils/arith_helper.h" | ||||
#include "megdnn/dtype.h" | |||||
#include <cmath> | #include <cmath> | ||||
#include <cstring> | #include <cstring> | ||||
@@ -357,6 +358,52 @@ struct LowbitMemcpy<bits, true> { | |||||
} | } | ||||
} | } | ||||
}; | }; | ||||
template<typename DT> | |||||
struct QuantizedLowbitTrait; | |||||
template<> | |||||
struct QuantizedLowbitTrait<dtype::Quantized4Asymm> { | |||||
static constexpr int8_t SHIFT = 0; | |||||
}; | |||||
template<> | |||||
struct QuantizedLowbitTrait<dtype::QuantizedS4> { | |||||
static constexpr int8_t SHIFT = 8; | |||||
}; | |||||
template <typename DT, bool div_byte = (DTypeTrait<DT>::category == | |||||
DTypeCategory::QUANTIZED) && | |||||
(8 % DTypeTrait<DT>::low_bit == 0)> | |||||
struct QuantizedLowbitMemcpy; | |||||
template <typename DT> | |||||
struct QuantizedLowbitMemcpy<DT, true> { | |||||
// cast with bits that 8 % bits == 0 | |||||
static constexpr uint16_t bits = DTypeTrait<DT>::low_bit; | |||||
static constexpr uint8_t MASK = (1 << bits) - 1; | |||||
using Trait = QuantizedLowbitTrait<DT>; | |||||
static void byte2compact(void* dest_raw, const void* src_raw, size_t n) { | |||||
auto dest = static_cast<uint8_t*>(dest_raw); | |||||
auto src = static_cast<const int8_t*>(src_raw); | |||||
memset(dest, 0, divup<size_t>(n * bits, 8)); | |||||
for (size_t i = 0; i < n; ++i) { | |||||
int8_t val = src[i] + Trait::SHIFT; | |||||
mgb_assert(val >= 0 && val < (1 << bits)); | |||||
dest[i * bits / 8] |= val << (i * bits % 8); | |||||
} | |||||
} | |||||
static void compact2byte(void* dest_raw, const void* src_raw, size_t n) { | |||||
auto dest = static_cast<int8_t*>(dest_raw); | |||||
auto src = static_cast<const uint8_t*>(src_raw); | |||||
for (size_t i = 0; i < n; ++i) { | |||||
int8_t val = ((src[i * bits / 8] >> (i * bits % 8)) & MASK); | |||||
dest[i] = val - Trait::SHIFT; | |||||
} | |||||
} | |||||
}; | |||||
} // anonymous namespace | } // anonymous namespace | ||||
void mgb::lowbit_memcpy_byte2compact( | void mgb::lowbit_memcpy_byte2compact( | ||||
@@ -366,6 +413,11 @@ void mgb::lowbit_memcpy_byte2compact( | |||||
return LowbitMemcpy<bits>::byte2compact(dest, src, n); | return LowbitMemcpy<bits>::byte2compact(dest, src, n); | ||||
MEGDNN_FOREACH_LOWBIT_DTYPE(cb) | MEGDNN_FOREACH_LOWBIT_DTYPE(cb) | ||||
#undef cb | #undef cb | ||||
#define cb(dt) \ | |||||
if (dtype.enumv() == DTypeTrait<dt>::enumv) \ | |||||
return QuantizedLowbitMemcpy<dt>::byte2compact(dest, src, n); | |||||
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||||
#undef cb | |||||
mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name()); | mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name()); | ||||
} | } | ||||
@@ -376,6 +428,11 @@ void mgb::lowbit_memcpy_compact2byte( | |||||
return LowbitMemcpy<bits>::compact2byte(dest, src, n); | return LowbitMemcpy<bits>::compact2byte(dest, src, n); | ||||
MEGDNN_FOREACH_LOWBIT_DTYPE(cb) | MEGDNN_FOREACH_LOWBIT_DTYPE(cb) | ||||
#undef cb | #undef cb | ||||
#define cb(dt) \ | |||||
if (dtype.enumv() == DTypeTrait<dt>::enumv) \ | |||||
return QuantizedLowbitMemcpy<dt>::compact2byte(dest, src, n); | |||||
MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb) | |||||
#undef cb | |||||
mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name()); | mgb_throw(MegBrainError, "bad dtype for lowbit: %s", dtype.name()); | ||||
} | } | ||||