|
|
@@ -6,36 +6,25 @@ |
|
|
|
# software distributed under the License is distributed on an |
|
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
|
|
|
|
import collections |
|
|
|
from typing import Union |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from .mgb import intb1, intb2, intb4 |
|
|
|
|
|
|
|
_QuantDtypeMetadata = collections.namedtuple( |
|
|
|
"QuantDtypeMetadata", ["name", "np_dtype_str", "is_unsigned", "qmin", "qmax",] |
|
|
|
) |
|
|
|
|
|
|
|
_metadata_dict = { |
|
|
|
"quint8": { |
|
|
|
"is_unsigned": True, |
|
|
|
"np_dtype_str": "uint8", |
|
|
|
"mgb_dtype": {"name": "Quantized8Asymm", "qmin": 0, "qmax": 255,}, |
|
|
|
}, |
|
|
|
"qint8": { |
|
|
|
"is_unsigned": False, |
|
|
|
"np_dtype_str": "int8", |
|
|
|
"mgb_dtype": {"name": "QuantizedS8", "qmin": -128, "qmax": 127,}, |
|
|
|
}, |
|
|
|
"quint4": { |
|
|
|
"is_unsigned": True, |
|
|
|
"np_dtype_str": "uint8", |
|
|
|
"mgb_dtype": {"name": "Quantized4Asymm", "qmin": 0, "qmax": 15,}, |
|
|
|
}, |
|
|
|
"qint4": { |
|
|
|
"is_unsigned": False, |
|
|
|
"np_dtype_str": "int8", |
|
|
|
"mgb_dtype": {"name": "QuantizedS4", "qmin": -8, "qmax": 7,}, |
|
|
|
}, |
|
|
|
"qint32": { |
|
|
|
"is_unsigned": False, |
|
|
|
"np_dtype_str": "int32", |
|
|
|
"mgb_dtype": {"name": "QuantizedS32", "qmin": -(2 ** 31), "qmax": 2 ** 31 - 1,}, |
|
|
|
}, |
|
|
|
"quint8": _QuantDtypeMetadata("Quantized8Asymm", "uint8", True, 0, 255), |
|
|
|
"qint8": _QuantDtypeMetadata("QuantizedS8", "int8", False, -128, 127), |
|
|
|
"quint4": _QuantDtypeMetadata("Quantized4Asymm", "uint8", True, 0, 15), |
|
|
|
"qint4": _QuantDtypeMetadata("QuantizedS4", "int8", False, -8, 7), |
|
|
|
"qint32": _QuantDtypeMetadata( |
|
|
|
"QuantizedS32", "int32", False, -(2 ** 31), 2 ** 31 - 1, |
|
|
|
), |
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
@@ -64,26 +53,49 @@ def get_zero_point(dtype): |
|
|
|
|
|
|
|
|
|
|
|
def _check_zero_point(zp: int, dtype_str: str): |
|
|
|
qmin = _metadata_dict[dtype_str]["mgb_dtype"]["qmin"] |
|
|
|
qmax = _metadata_dict[dtype_str]["mgb_dtype"]["qmax"] |
|
|
|
qmin = _metadata_dict[dtype_str].qmin |
|
|
|
qmax = _metadata_dict[dtype_str].qmax |
|
|
|
if zp < qmin or zp > qmax: |
|
|
|
raise ValueError( |
|
|
|
"zero_point should be within [{}, {}] for {}".format(qmin, qmax, dtype_str) |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def _get_dtype(dtype_str: str, scale, zp): |
|
|
|
if zp is not None: |
|
|
|
if int(zp) != zp: |
|
|
|
def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]): |
|
|
|
r""" |
|
|
|
Get quantized dtype with metadata attribute according to _metadata_dict. |
|
|
|
|
|
|
|
Note that unsigned dtype must have ``zero_point`` and signed dtype must |
|
|
|
not have ``zero_point``, to be consistent with tensors generated by calling a |
|
|
|
compiled function from `CompGraph.compile(inputs, outspec)`. |
|
|
|
|
|
|
|
:param dtype_str: a string indicating which dtype to return |
|
|
|
:param scale: a number for scale to store in dtype's metadata |
|
|
|
:param zp: a number for zero_point to store in dtype's metadata |
|
|
|
""" |
|
|
|
metadata = _metadata_dict[dtype_str] |
|
|
|
np_dtype_str = metadata.np_dtype_str |
|
|
|
is_unsigned = metadata.is_unsigned |
|
|
|
if is_unsigned: |
|
|
|
if zp is None or int(zp) != zp: |
|
|
|
raise ValueError("zero_point should be an integer") |
|
|
|
zp = int(zp) |
|
|
|
_check_zero_point(zp, dtype_str) |
|
|
|
metadata = _metadata_dict[dtype_str]["mgb_dtype"] |
|
|
|
np_dtype_str = _metadata_dict[dtype_str]["np_dtype_str"] |
|
|
|
return np.dtype( |
|
|
|
np_dtype_str, |
|
|
|
metadata={"mgb_dtype": {**metadata, "scale": float(scale), "zero_point": zp,}}, |
|
|
|
) |
|
|
|
return np.dtype( |
|
|
|
np_dtype_str, |
|
|
|
metadata={ |
|
|
|
"mgb_dtype": { |
|
|
|
"name": metadata.name, |
|
|
|
"scale": float(scale), |
|
|
|
"zero_point": zp, |
|
|
|
} |
|
|
|
}, |
|
|
|
) |
|
|
|
else: |
|
|
|
return np.dtype( |
|
|
|
np_dtype_str, |
|
|
|
metadata={"mgb_dtype": {"name": metadata.name, "scale": float(scale)}}, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def quint8(scale, zero_point): |
|
|
@@ -92,7 +104,7 @@ def quint8(scale, zero_point): |
|
|
|
``zero_point`` (uint8). The real value represented by a quint8 data type is |
|
|
|
float_val = scale * (uint8_val - zero_point) |
|
|
|
""" |
|
|
|
return _get_dtype("quint8", scale, zero_point) |
|
|
|
return get_quantized_dtype("quint8", scale, zero_point) |
|
|
|
|
|
|
|
|
|
|
|
def qint8(scale): |
|
|
@@ -100,7 +112,7 @@ def qint8(scale): |
|
|
|
Construct a quantized int8 data type with ``scale`` (float). The real value |
|
|
|
represented by a qint8 data type is float_val = scale * int8_val |
|
|
|
""" |
|
|
|
return _get_dtype("qint8", scale, None) |
|
|
|
return get_quantized_dtype("qint8", scale, None) |
|
|
|
|
|
|
|
|
|
|
|
def qint32(scale): |
|
|
@@ -108,7 +120,7 @@ def qint32(scale): |
|
|
|
Construct a quantized int32 data type with ``scale`` (float). The real value |
|
|
|
represented by a qint32 data type is float_val = scale * int32_val |
|
|
|
""" |
|
|
|
return _get_dtype("qint32", scale, None) |
|
|
|
return get_quantized_dtype("qint32", scale, None) |
|
|
|
|
|
|
|
|
|
|
|
def quint4(scale, zero_point): |
|
|
@@ -117,7 +129,7 @@ def quint4(scale, zero_point): |
|
|
|
``zero_point`` (uint8). The real value represented by a quint4 data type is |
|
|
|
float_val = scale * (uint4_val - zero_point) |
|
|
|
""" |
|
|
|
return _get_dtype("quint4", scale, zero_point) |
|
|
|
return get_quantized_dtype("quint4", scale, zero_point) |
|
|
|
|
|
|
|
|
|
|
|
def qint4(scale): |
|
|
@@ -125,17 +137,17 @@ def qint4(scale): |
|
|
|
Construct a quantized int4 data type with ``scale`` (float). The real value |
|
|
|
represented by a qint4 data type is float_val = scale * int4_val |
|
|
|
""" |
|
|
|
return _get_dtype("qint4", scale, None) |
|
|
|
return get_quantized_dtype("qint4", scale, None) |
|
|
|
|
|
|
|
|
|
|
|
def _convert_to_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str): |
|
|
|
metadata = _metadata_dict[dtype_str]["mgb_dtype"] |
|
|
|
def _convert_to_quantized_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str): |
|
|
|
metadata = _metadata_dict[dtype_str] |
|
|
|
arr_metadata = dtype.metadata["mgb_dtype"] |
|
|
|
if not isinstance(arr, np.ndarray): |
|
|
|
raise ValueError("arr parameter should be instance of np.ndarray") |
|
|
|
if not is_quantize(dtype) or arr_metadata["name"] != metadata["name"]: |
|
|
|
if not is_quantize(dtype) or arr_metadata["name"] != metadata.name: |
|
|
|
raise ValueError("dtype parameter should be a {} dtype".format(dtype_str)) |
|
|
|
is_unsigned = _metadata_dict[dtype_str]["is_unsigned"] |
|
|
|
is_unsigned = metadata.is_unsigned |
|
|
|
if is_unsigned: |
|
|
|
scale, zp = ( |
|
|
|
arr_metadata["scale"], |
|
|
@@ -143,25 +155,23 @@ def _convert_to_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str): |
|
|
|
) |
|
|
|
return ( |
|
|
|
(np.round(arr / scale) + zp) |
|
|
|
.clip(metadata["qmin"], metadata["qmax"]) |
|
|
|
.clip(metadata.qmin, metadata.qmax) |
|
|
|
.astype(dtype) |
|
|
|
) |
|
|
|
else: |
|
|
|
# don't trick to combine with is_unsigned for consistency with cpp interface |
|
|
|
# kept separate from the is_unsigned branch: signed dtypes carry no zero_point (see ``get_quantized_dtype``) |
|
|
|
scale = arr_metadata["scale"] |
|
|
|
return ( |
|
|
|
np.round(arr / scale).clip(metadata["qmin"], metadata["qmax"]).astype(dtype) |
|
|
|
) |
|
|
|
return np.round(arr / scale).clip(metadata.qmin, metadata.qmax).astype(dtype) |
|
|
|
|
|
|
|
|
|
|
|
def _convert_from_dtype(arr: np.ndarray, dtype_str: str): |
|
|
|
metadata = _metadata_dict[dtype_str]["mgb_dtype"] |
|
|
|
def _convert_from_quantized_dtype(arr: np.ndarray, dtype_str: str): |
|
|
|
metadata = _metadata_dict[dtype_str] |
|
|
|
arr_metadata = arr.dtype.metadata["mgb_dtype"] |
|
|
|
if not isinstance(arr, np.ndarray): |
|
|
|
raise ValueError("arr parameter should be instance of np.ndarray") |
|
|
|
if not is_quantize(arr.dtype) or arr_metadata["name"] != metadata["name"]: |
|
|
|
if not is_quantize(arr.dtype) or arr_metadata["name"] != metadata.name: |
|
|
|
raise ValueError("arr's dtype should be a {} dtype".format(dtype_str)) |
|
|
|
is_unsigned = _metadata_dict[dtype_str]["is_unsigned"] |
|
|
|
is_unsigned = metadata.is_unsigned |
|
|
|
if is_unsigned: |
|
|
|
scale, zp = ( |
|
|
|
arr_metadata["scale"], |
|
|
@@ -169,7 +179,7 @@ def _convert_from_dtype(arr: np.ndarray, dtype_str: str): |
|
|
|
) |
|
|
|
return (arr.astype(np.float32) - zp) * scale |
|
|
|
else: |
|
|
|
# don't trick to combine with is_unsigned for consistency with cpp interface |
|
|
|
# kept separate from the is_unsigned branch: signed dtypes carry no zero_point (see ``get_quantized_dtype``) |
|
|
|
scale = arr_metadata["scale"] |
|
|
|
return (arr.astype(np.float32)) * scale |
|
|
|
|
|
|
@@ -181,7 +191,7 @@ def convert_to_quint8(arr: np.ndarray, q: np.dtype): |
|
|
|
:param arr: Input ndarray. |
|
|
|
:param q: Target data type, should be a quint8. |
|
|
|
""" |
|
|
|
return _convert_to_dtype(arr, q, "quint8") |
|
|
|
return _convert_to_quantized_dtype(arr, q, "quint8") |
|
|
|
|
|
|
|
|
|
|
|
def convert_from_quint8(arr: np.ndarray): |
|
|
@@ -190,7 +200,7 @@ def convert_from_quint8(arr: np.ndarray): |
|
|
|
|
|
|
|
:param arr: Input ndarray. |
|
|
|
""" |
|
|
|
return _convert_from_dtype(arr, "quint8") |
|
|
|
return _convert_from_quantized_dtype(arr, "quint8") |
|
|
|
|
|
|
|
|
|
|
|
def convert_to_qint8(arr: np.ndarray, q: np.dtype): |
|
|
@@ -200,7 +210,7 @@ def convert_to_qint8(arr: np.ndarray, q: np.dtype): |
|
|
|
:param arr: Input ndarray. |
|
|
|
:param q: Target data type, should be a qint8. |
|
|
|
""" |
|
|
|
return _convert_to_dtype(arr, q, "qint8") |
|
|
|
return _convert_to_quantized_dtype(arr, q, "qint8") |
|
|
|
|
|
|
|
|
|
|
|
def convert_from_qint8(arr: np.ndarray): |
|
|
@@ -209,7 +219,7 @@ def convert_from_qint8(arr: np.ndarray): |
|
|
|
|
|
|
|
:param arr: Input ndarray. |
|
|
|
""" |
|
|
|
return _convert_from_dtype(arr, "qint8") |
|
|
|
return _convert_from_quantized_dtype(arr, "qint8") |
|
|
|
|
|
|
|
|
|
|
|
def convert_to_qint32(arr: np.ndarray, q: np.dtype): |
|
|
@@ -219,7 +229,7 @@ def convert_to_qint32(arr: np.ndarray, q: np.dtype): |
|
|
|
:param arr: Input ndarray. |
|
|
|
:param q: Target data type, should be a qint32. |
|
|
|
""" |
|
|
|
return _convert_to_dtype(arr, q, "qint32") |
|
|
|
return _convert_to_quantized_dtype(arr, q, "qint32") |
|
|
|
|
|
|
|
|
|
|
|
def convert_from_qint32(arr): |
|
|
@@ -228,7 +238,7 @@ def convert_from_qint32(arr): |
|
|
|
|
|
|
|
:param arr: Input ndarray. |
|
|
|
""" |
|
|
|
return _convert_from_dtype(arr, "qint32") |
|
|
|
return _convert_from_quantized_dtype(arr, "qint32") |
|
|
|
|
|
|
|
|
|
|
|
def convert_to_quint4(arr: np.ndarray, q: np.dtype): |
|
|
@@ -238,7 +248,7 @@ def convert_to_quint4(arr: np.ndarray, q: np.dtype): |
|
|
|
:param arr: Input ndarray. |
|
|
|
:param q: Target data type, should be a quint4. |
|
|
|
""" |
|
|
|
return _convert_to_dtype(arr, q, "quint4") |
|
|
|
return _convert_to_quantized_dtype(arr, q, "quint4") |
|
|
|
|
|
|
|
|
|
|
|
def convert_from_quint4(arr: np.ndarray): |
|
|
@@ -247,7 +257,7 @@ def convert_from_quint4(arr: np.ndarray): |
|
|
|
|
|
|
|
:param arr: Input ndarray. |
|
|
|
""" |
|
|
|
return _convert_from_dtype(arr, "quint4") |
|
|
|
return _convert_from_quantized_dtype(arr, "quint4") |
|
|
|
|
|
|
|
|
|
|
|
def convert_to_qint4(arr: np.ndarray, q: np.dtype): |
|
|
@@ -257,7 +267,7 @@ def convert_to_qint4(arr: np.ndarray, q: np.dtype): |
|
|
|
:param arr: Input ndarray. |
|
|
|
:param q: Target data type, should be a qint4. |
|
|
|
""" |
|
|
|
return _convert_to_dtype(arr, q, "qint4") |
|
|
|
return _convert_to_quantized_dtype(arr, q, "qint4") |
|
|
|
|
|
|
|
|
|
|
|
def convert_from_qint4(arr: np.ndarray): |
|
|
@@ -266,4 +276,4 @@ def convert_from_qint4(arr: np.ndarray): |
|
|
|
|
|
|
|
:param arr: Input ndarray. |
|
|
|
""" |
|
|
|
return _convert_from_dtype(arr, "qint4") |
|
|
|
return _convert_from_quantized_dtype(arr, "qint4") |