|
|
@@ -1,4 +1,3 @@ |
|
|
|
# -*- coding: utf-8 -*- |
|
|
|
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") |
|
|
|
# |
|
|
|
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. |
|
|
@@ -11,6 +10,34 @@ import numpy as np |
|
|
|
|
|
|
|
from .mgb import intb1, intb2, intb4 |
|
|
|
|
|
|
|
_metadata_dict = { |
|
|
|
"quint8": { |
|
|
|
"is_unsigned": True, |
|
|
|
"np_dtype_str": "uint8", |
|
|
|
"mgb_dtype": {"name": "Quantized8Asymm", "qmin": 0, "qmax": 255,}, |
|
|
|
}, |
|
|
|
"qint8": { |
|
|
|
"is_unsigned": False, |
|
|
|
"np_dtype_str": "int8", |
|
|
|
"mgb_dtype": {"name": "QuantizedS8", "qmin": -128, "qmax": 127,}, |
|
|
|
}, |
|
|
|
"quint4": { |
|
|
|
"is_unsigned": True, |
|
|
|
"np_dtype_str": "uint8", |
|
|
|
"mgb_dtype": {"name": "Quantized4Asymm", "qmin": 0, "qmax": 15,}, |
|
|
|
}, |
|
|
|
"qint4": { |
|
|
|
"is_unsigned": False, |
|
|
|
"np_dtype_str": "int8", |
|
|
|
"mgb_dtype": {"name": "QuantizedS4", "qmin": -8, "qmax": 7,}, |
|
|
|
}, |
|
|
|
"qint32": { |
|
|
|
"is_unsigned": False, |
|
|
|
"np_dtype_str": "int32", |
|
|
|
"mgb_dtype": {"name": "QuantizedS32", "qmin": -(2 ** 31), "qmax": 2 ** 31 - 1,}, |
|
|
|
}, |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
def is_quantize(dtype): |
|
|
|
return ( |
|
|
@@ -36,26 +63,36 @@ def get_zero_point(dtype): |
|
|
|
return metadata["zero_point"] |
|
|
|
|
|
|
|
|
|
|
|
def _check_zero_point(zp: int, dtype_str: str): |
|
|
|
qmin = _metadata_dict[dtype_str]["mgb_dtype"]["qmin"] |
|
|
|
qmax = _metadata_dict[dtype_str]["mgb_dtype"]["qmax"] |
|
|
|
if zp < qmin or zp > qmax: |
|
|
|
raise ValueError( |
|
|
|
"zero_point should be within [{}, {}] for {}".format(qmin, qmax, dtype_str) |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def _get_dtype(dtype_str: str, scale, zp): |
|
|
|
if zp is not None: |
|
|
|
if int(zp) != zp: |
|
|
|
raise ValueError("zero_point should be an integer") |
|
|
|
zp = int(zp) |
|
|
|
_check_zero_point(zp, dtype_str) |
|
|
|
metadata = _metadata_dict[dtype_str]["mgb_dtype"] |
|
|
|
np_dtype_str = _metadata_dict[dtype_str]["np_dtype_str"] |
|
|
|
return np.dtype( |
|
|
|
np_dtype_str, |
|
|
|
metadata={"mgb_dtype": {**metadata, "scale": float(scale), "zero_point": zp,}}, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def quint8(scale, zero_point): |
|
|
|
""" |
|
|
|
Consturct a quantized unsigned int8 data type with ``scale`` (float) and |
|
|
|
``zero_point`` (uint8). The real value represented by a quint8 data type is |
|
|
|
float_val = scale * (uint8_val - zero_point) |
|
|
|
""" |
|
|
|
int_zp = int(zero_point) |
|
|
|
assert int_zp == zero_point, "zero_point should be an integer" |
|
|
|
if int_zp < 0 or int_zp > 255: |
|
|
|
raise ValueError("zero_point should be within [0, 255] for quint8") |
|
|
|
return np.dtype( |
|
|
|
np.uint8, |
|
|
|
metadata={ |
|
|
|
"mgb_dtype": { |
|
|
|
"name": "Quantized8Asymm", |
|
|
|
"scale": float(scale), |
|
|
|
"zero_point": int(zero_point), |
|
|
|
} |
|
|
|
}, |
|
|
|
) |
|
|
|
return _get_dtype("quint8", scale, zero_point) |
|
|
|
|
|
|
|
|
|
|
|
def qint8(scale): |
|
|
@@ -63,9 +100,7 @@ def qint8(scale): |
|
|
|
Construct a quantized int8 data type with ``scale`` (float). The real value |
|
|
|
represented by a qint8 data type is float_val = scale * int8_val |
|
|
|
""" |
|
|
|
return np.dtype( |
|
|
|
np.int8, metadata={"mgb_dtype": {"name": "QuantizedS8", "scale": float(scale)}} |
|
|
|
) |
|
|
|
return _get_dtype("qint8", scale, None) |
|
|
|
|
|
|
|
|
|
|
|
def qint32(scale): |
|
|
@@ -73,10 +108,7 @@ def qint32(scale): |
|
|
|
Construct a quantized int32 data type with ``scale`` (float). The real value |
|
|
|
represented by a qint32 data type is float_val = scale * int32_val |
|
|
|
""" |
|
|
|
return np.dtype( |
|
|
|
np.int32, |
|
|
|
metadata={"mgb_dtype": {"name": "QuantizedS32", "scale": float(scale)}}, |
|
|
|
) |
|
|
|
return _get_dtype("qint32", scale, None) |
|
|
|
|
|
|
|
|
|
|
|
def quint4(scale, zero_point): |
|
|
@@ -85,20 +117,7 @@ def quint4(scale, zero_point): |
|
|
|
``zero_point`` (uint8). The real value represented by a quint4 data type is |
|
|
|
float_val = scale * (uint4_val - zero_point) |
|
|
|
""" |
|
|
|
int_zp = int(zero_point) |
|
|
|
assert int_zp == zero_point, "zero_point should be an integer" |
|
|
|
if int_zp < 0 or int_zp > 15: |
|
|
|
raise ValueError("zero_point should be within [0, 15] for quint4") |
|
|
|
return np.dtype( |
|
|
|
np.uint8, |
|
|
|
metadata={ |
|
|
|
"mgb_dtype": { |
|
|
|
"name": "Quantized4Asymm", |
|
|
|
"scale": float(scale), |
|
|
|
"zero_point": int(zero_point), |
|
|
|
} |
|
|
|
}, |
|
|
|
) |
|
|
|
return _get_dtype("quint4", scale, zero_point) |
|
|
|
|
|
|
|
|
|
|
|
def qint4(scale): |
|
|
@@ -106,94 +125,101 @@ def qint4(scale): |
|
|
|
Construct a quantized int4 data type with ``scale`` (float). The real value |
|
|
|
represented by a qint4 data type is float_val = scale * int4_val |
|
|
|
""" |
|
|
|
return np.dtype( |
|
|
|
np.int8, metadata={"mgb_dtype": {"name": "QuantizedS4", "scale": float(scale)}} |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def convert_to_quint8(arr, q): |
|
|
|
return _get_dtype("qint4", scale, None) |
|
|
|
|
|
|
|
|
|
|
|
def _convert_to_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str): |
|
|
|
metadata = _metadata_dict[dtype_str]["mgb_dtype"] |
|
|
|
arr_metadata = dtype.metadata["mgb_dtype"] |
|
|
|
if not isinstance(arr, np.ndarray): |
|
|
|
raise ValueError("arr parameter should be instance of np.ndarray") |
|
|
|
if not is_quantize(dtype) or arr_metadata["name"] != metadata["name"]: |
|
|
|
raise ValueError("dtype parameter should be a {} dtype".format(dtype_str)) |
|
|
|
is_unsigned = _metadata_dict[dtype_str]["is_unsigned"] |
|
|
|
if is_unsigned: |
|
|
|
scale, zp = ( |
|
|
|
arr_metadata["scale"], |
|
|
|
arr_metadata["zero_point"], |
|
|
|
) |
|
|
|
return ( |
|
|
|
(np.round(arr / scale) + zp) |
|
|
|
.clip(metadata["qmin"], metadata["qmax"]) |
|
|
|
.astype(dtype) |
|
|
|
) |
|
|
|
else: |
|
|
|
# don't trick to combine with is_unsigned for consistency with cpp interface |
|
|
|
scale = arr_metadata["scale"] |
|
|
|
return ( |
|
|
|
np.round(arr / scale).clip(metadata["qmin"], metadata["qmax"]).astype(dtype) |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def _convert_from_dtype(arr: np.ndarray, dtype_str: str): |
|
|
|
metadata = _metadata_dict[dtype_str]["mgb_dtype"] |
|
|
|
arr_metadata = arr.dtype.metadata["mgb_dtype"] |
|
|
|
if not isinstance(arr, np.ndarray): |
|
|
|
raise ValueError("arr parameter should be instance of np.ndarray") |
|
|
|
if not is_quantize(arr.dtype) or arr_metadata["name"] != metadata["name"]: |
|
|
|
raise ValueError("arr's dtype should be a {} dtype".format(dtype_str)) |
|
|
|
is_unsigned = _metadata_dict[dtype_str]["is_unsigned"] |
|
|
|
if is_unsigned: |
|
|
|
scale, zp = ( |
|
|
|
arr_metadata["scale"], |
|
|
|
arr_metadata["zero_point"], |
|
|
|
) |
|
|
|
return (arr.astype(np.float32) - zp) * scale |
|
|
|
else: |
|
|
|
# don't trick to combine with is_unsigned for consistency with cpp interface |
|
|
|
scale = arr_metadata["scale"] |
|
|
|
return (arr.astype(np.float32)) * scale |
|
|
|
|
|
|
|
|
|
|
|
def convert_to_quint8(arr: np.ndarray, q: np.dtype): |
|
|
|
""" |
|
|
|
Quantize a float NumPy ndarray into a quint8 one with specified params. |
|
|
|
|
|
|
|
:param arr: Input ndarray. |
|
|
|
:type arr: :class:`np.ndarray` |
|
|
|
:param q: Target data type, should be a quint8. |
|
|
|
:type q: :class:`np.dtype` |
|
|
|
""" |
|
|
|
assert isinstance(arr, np.ndarray) |
|
|
|
assert ( |
|
|
|
"mgb_dtype" in q.metadata |
|
|
|
and q.metadata["mgb_dtype"]["name"] == "Quantized8Asymm" |
|
|
|
), "q should be a quint8 dtype" |
|
|
|
scale, zp = q.metadata["mgb_dtype"]["scale"], q.metadata["mgb_dtype"]["zero_point"] |
|
|
|
return (np.round(arr / scale) + zp).clip(0, 255).astype(q) |
|
|
|
return _convert_to_dtype(arr, q, "quint8") |
|
|
|
|
|
|
|
|
|
|
|
def convert_from_quint8(arr): |
|
|
|
def convert_from_quint8(arr: np.ndarray): |
|
|
|
""" |
|
|
|
Dequantize a quint8 NumPy ndarray into a float one. |
|
|
|
|
|
|
|
:param arr: Input ndarray. |
|
|
|
""" |
|
|
|
assert isinstance(arr, np.ndarray) |
|
|
|
assert ( |
|
|
|
"mgb_dtype" in arr.dtype.metadata |
|
|
|
and arr.dtype.metadata["mgb_dtype"]["name"] == "Quantized8Asymm" |
|
|
|
), "arr should be a ndarray with quint8 dtype" |
|
|
|
scale, zp = ( |
|
|
|
arr.dtype.metadata["mgb_dtype"]["scale"], |
|
|
|
arr.dtype.metadata["mgb_dtype"]["zero_point"], |
|
|
|
) |
|
|
|
return (arr.astype(np.float32) - zp) * scale |
|
|
|
return _convert_from_dtype(arr, "quint8") |
|
|
|
|
|
|
|
|
|
|
|
def convert_to_qint8(arr, q): |
|
|
|
def convert_to_qint8(arr: np.ndarray, q: np.dtype): |
|
|
|
""" |
|
|
|
Quantize a float NumPy ndarray into a qint8 one with specified params. |
|
|
|
|
|
|
|
:param arr: Input ndarray. |
|
|
|
:type arr: :class:`np.ndarray` |
|
|
|
:param q: Target data type, should be a qint8. |
|
|
|
:type q: :class:`np.dtype` |
|
|
|
""" |
|
|
|
assert isinstance(arr, np.ndarray) |
|
|
|
assert ( |
|
|
|
"mgb_dtype" in q.metadata and q.metadata["mgb_dtype"]["name"] == "QuantizedS8" |
|
|
|
), "q should be a qint8 dtype" |
|
|
|
scale = q.metadata["mgb_dtype"]["scale"] |
|
|
|
return (np.round(arr / scale)).clip(-128, 127).astype(q) |
|
|
|
return _convert_to_dtype(arr, q, "qint8") |
|
|
|
|
|
|
|
|
|
|
|
def convert_from_qint8(arr): |
|
|
|
def convert_from_qint8(arr: np.ndarray): |
|
|
|
""" |
|
|
|
Dequantize a qint8 NumPy ndarray into a float one. |
|
|
|
|
|
|
|
:param arr: Input ndarray. |
|
|
|
""" |
|
|
|
assert isinstance(arr, np.ndarray) |
|
|
|
assert ( |
|
|
|
"mgb_dtype" in arr.dtype.metadata |
|
|
|
and arr.dtype.metadata["mgb_dtype"]["name"] == "QuantizedS8" |
|
|
|
), "arr should be a ndarray with qint8 dtype" |
|
|
|
scale = arr.dtype.metadata["mgb_dtype"]["scale"] |
|
|
|
return arr.astype(np.float32) * scale |
|
|
|
return _convert_from_dtype(arr, "qint8") |
|
|
|
|
|
|
|
|
|
|
|
def convert_to_qint32(arr, q): |
|
|
|
def convert_to_qint32(arr: np.ndarray, q: np.dtype): |
|
|
|
""" |
|
|
|
Quantize a float NumPy ndarray into a qint32 one with specified params. |
|
|
|
|
|
|
|
:param arr: Input ndarray. |
|
|
|
:type arr: :class:`np.ndarray` |
|
|
|
:param q: Target data type, should be a qint8. |
|
|
|
:type q: :class:`np.dtype` |
|
|
|
""" |
|
|
|
assert isinstance(arr, np.ndarray) |
|
|
|
assert ( |
|
|
|
"mgb_dtype" in q.metadata and q.metadata["mgb_dtype"]["name"] == "QuantizedS32" |
|
|
|
), "q should be a qint32 dtype" |
|
|
|
scale = q.metadata["mgb_dtype"]["scale"] |
|
|
|
return (np.round(arr / scale)).clip(-(2 ** 31), 2 ** 31).astype(q) |
|
|
|
return _convert_to_dtype(arr, q, "qint32") |
|
|
|
|
|
|
|
|
|
|
|
def convert_from_qint32(arr): |
|
|
@@ -202,78 +228,42 @@ def convert_from_qint32(arr): |
|
|
|
|
|
|
|
:param arr: Input ndarray. |
|
|
|
""" |
|
|
|
assert isinstance(arr, np.ndarray) |
|
|
|
assert ( |
|
|
|
"mgb_dtype" in arr.dtype.metadata |
|
|
|
and arr.dtype.metadata["mgb_dtype"]["name"] == "QuantizedS32" |
|
|
|
), "arr should be a ndarray with qint8 dtype" |
|
|
|
scale = arr.dtype.metadata["mgb_dtype"]["scale"] |
|
|
|
return arr.astype(np.float32) * scale |
|
|
|
return _convert_from_dtype(arr, "qint32") |
|
|
|
|
|
|
|
|
|
|
|
def convert_to_quint4(arr, q): |
|
|
|
def convert_to_quint4(arr: np.ndarray, q: np.dtype): |
|
|
|
""" |
|
|
|
Quantize a float NumPy ndarray into a quint4 one with specified params. |
|
|
|
|
|
|
|
:param arr: Input ndarray. |
|
|
|
:type arr: :class:`np.ndarray` |
|
|
|
:param q: Target data type, should be a quint4. |
|
|
|
:type q: :class:`np.dtype` |
|
|
|
""" |
|
|
|
assert isinstance(arr, np.ndarray) |
|
|
|
assert ( |
|
|
|
"mgb_dtype" in q.metadata |
|
|
|
and q.metadata["mgb_dtype"]["name"] == "Quantized4Asymm" |
|
|
|
), "q should be a quint4 dtype" |
|
|
|
scale, zp = q.metadata["mgb_dtype"]["scale"], q.metadata["mgb_dtype"]["zero_point"] |
|
|
|
return (np.round(arr / scale) + zp).clip(0, 15).astype(q) |
|
|
|
return _convert_to_dtype(arr, q, "quint4") |
|
|
|
|
|
|
|
|
|
|
|
def convert_from_quint4(arr): |
|
|
|
def convert_from_quint4(arr: np.ndarray): |
|
|
|
""" |
|
|
|
Dequantize a quint4 NumPy ndarray into a float one. |
|
|
|
|
|
|
|
:param arr: Input ndarray. |
|
|
|
""" |
|
|
|
assert isinstance(arr, np.ndarray) |
|
|
|
assert ( |
|
|
|
"mgb_dtype" in arr.dtype.metadata |
|
|
|
and arr.dtype.metadata["mgb_dtype"]["name"] == "Quantized4Asymm" |
|
|
|
), "arr should be a ndarray with quint4 dtype" |
|
|
|
scale, zp = ( |
|
|
|
arr.dtype.metadata["mgb_dtype"]["scale"], |
|
|
|
arr.dtype.metadata["mgb_dtype"]["zero_point"], |
|
|
|
) |
|
|
|
return (arr.astype(np.float32) - zp) * scale |
|
|
|
return _convert_from_dtype(arr, "quint4") |
|
|
|
|
|
|
|
|
|
|
|
def convert_to_qint4(arr, q): |
|
|
|
def convert_to_qint4(arr: np.ndarray, q: np.dtype): |
|
|
|
""" |
|
|
|
Quantize a float NumPy ndarray into a qint4 one with specified params. |
|
|
|
|
|
|
|
:param arr: Input ndarray. |
|
|
|
:type arr: :class:`np.ndarray` |
|
|
|
:param q: Target data type, should be a qint4. |
|
|
|
:type q: :class:`np.dtype` |
|
|
|
""" |
|
|
|
assert isinstance(arr, np.ndarray) |
|
|
|
assert ( |
|
|
|
"mgb_dtype" in q.metadata and q.metadata["mgb_dtype"]["name"] == "QuantizedS4" |
|
|
|
), "q should be a qint4 dtype" |
|
|
|
scale = q.metadata["mgb_dtype"]["scale"] |
|
|
|
return (np.round(arr / scale)).clip(-8, 7).astype(q) |
|
|
|
return _convert_to_dtype(arr, q, "qint4") |
|
|
|
|
|
|
|
|
|
|
|
def convert_from_qint4(arr): |
|
|
|
def convert_from_qint4(arr: np.ndarray): |
|
|
|
""" |
|
|
|
Dequantize a qint4 NumPy ndarray into a float one. |
|
|
|
|
|
|
|
:param arr: Input ndarray. |
|
|
|
""" |
|
|
|
assert isinstance(arr, np.ndarray) |
|
|
|
assert ( |
|
|
|
"mgb_dtype" in arr.dtype.metadata |
|
|
|
and arr.dtype.metadata["mgb_dtype"]["name"] == "QuantizedS4" |
|
|
|
), "arr should be a ndarray with qint4 dtype" |
|
|
|
scale = arr.dtype.metadata["mgb_dtype"]["scale"] |
|
|
|
return arr.astype(np.float32) * scale |
|
|
|
return _convert_from_dtype(arr, "qint4") |