Browse Source

refactor(mge/dtype): modify some interface name and enrich comments

GitOrigin-RevId: f9217f6d27
tags/v0.4.0
Megvii Engine Team Xinran Xu 5 years ago
parent
commit
a6f456415f
1 changed files with 75 additions and 65 deletions
  1. +75
    -65
      python_module/megengine/_internal/dtype.py

+ 75
- 65
python_module/megengine/_internal/dtype.py View File

@@ -6,36 +6,25 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import collections
from typing import Union

import numpy as np

from .mgb import intb1, intb2, intb4

_QuantDtypeMetadata = collections.namedtuple(
"QuantDtypeMetadata", ["name", "np_dtype_str", "is_unsigned", "qmin", "qmax",]
)

_metadata_dict = {
"quint8": {
"is_unsigned": True,
"np_dtype_str": "uint8",
"mgb_dtype": {"name": "Quantized8Asymm", "qmin": 0, "qmax": 255,},
},
"qint8": {
"is_unsigned": False,
"np_dtype_str": "int8",
"mgb_dtype": {"name": "QuantizedS8", "qmin": -128, "qmax": 127,},
},
"quint4": {
"is_unsigned": True,
"np_dtype_str": "uint8",
"mgb_dtype": {"name": "Quantized4Asymm", "qmin": 0, "qmax": 15,},
},
"qint4": {
"is_unsigned": False,
"np_dtype_str": "int8",
"mgb_dtype": {"name": "QuantizedS4", "qmin": -8, "qmax": 7,},
},
"qint32": {
"is_unsigned": False,
"np_dtype_str": "int32",
"mgb_dtype": {"name": "QuantizedS32", "qmin": -(2 ** 31), "qmax": 2 ** 31 - 1,},
},
"quint8": _QuantDtypeMetadata("Quantized8Asymm", "uint8", True, 0, 255),
"qint8": _QuantDtypeMetadata("QuantizedS8", "int8", False, -128, 127),
"quint4": _QuantDtypeMetadata("Quantized4Asymm", "uint8", True, 0, 15),
"qint4": _QuantDtypeMetadata("QuantizedS4", "int8", False, -8, 7),
"qint32": _QuantDtypeMetadata(
"QuantizedS32", "int32", False, -(2 ** 31), 2 ** 31 - 1,
),
}


@@ -64,26 +53,49 @@ def get_zero_point(dtype):


def _check_zero_point(zp: int, dtype_str: str):
qmin = _metadata_dict[dtype_str]["mgb_dtype"]["qmin"]
qmax = _metadata_dict[dtype_str]["mgb_dtype"]["qmax"]
qmin = _metadata_dict[dtype_str].qmin
qmax = _metadata_dict[dtype_str].qmax
if zp < qmin or zp > qmax:
raise ValueError(
"zero_point should be within [{}, {}] for {}".format(qmin, qmax, dtype_str)
)


def _get_dtype(dtype_str: str, scale, zp):
if zp is not None:
if int(zp) != zp:
def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]):
r"""
Get quantized dtype with metadata attribute according to _metadata_dict.

Note that unsigned dtype must have ``zero_point`` and signed dtype must
not have ``zero_point``, to be consitent with tensor generated by calling
compiled function from `CompGraph.compile(inputs, outspec)`.

:param dtype: a string indicating which dtype to return
:param scale: a number for scale to store in dtype's metadata
:param zp: a number for zero_point to store in dtype's metadata
"""
metadata = _metadata_dict[dtype_str]
np_dtype_str = metadata.np_dtype_str
is_unsigned = metadata.is_unsigned
if is_unsigned:
if zp is None or int(zp) != zp:
raise ValueError("zero_point should be an integer")
zp = int(zp)
_check_zero_point(zp, dtype_str)
metadata = _metadata_dict[dtype_str]["mgb_dtype"]
np_dtype_str = _metadata_dict[dtype_str]["np_dtype_str"]
return np.dtype(
np_dtype_str,
metadata={"mgb_dtype": {**metadata, "scale": float(scale), "zero_point": zp,}},
)
return np.dtype(
np_dtype_str,
metadata={
"mgb_dtype": {
"name": metadata.name,
"scale": float(scale),
"zero_point": zp,
}
},
)
else:
return np.dtype(
np_dtype_str,
metadata={"mgb_dtype": {"name": metadata.name, "scale": float(scale)}},
)


def quint8(scale, zero_point):
@@ -92,7 +104,7 @@ def quint8(scale, zero_point):
``zero_point`` (uint8). The real value represented by a quint8 data type is
float_val = scale * (uint8_val - zero_point)
"""
return _get_dtype("quint8", scale, zero_point)
return get_quantized_dtype("quint8", scale, zero_point)


def qint8(scale):
@@ -100,7 +112,7 @@ def qint8(scale):
Construct a quantized int8 data type with ``scale`` (float). The real value
represented by a qint8 data type is float_val = scale * int8_val
"""
return _get_dtype("qint8", scale, None)
return get_quantized_dtype("qint8", scale, None)


def qint32(scale):
@@ -108,7 +120,7 @@ def qint32(scale):
Construct a quantized int32 data type with ``scale`` (float). The real value
represented by a qint32 data type is float_val = scale * int32_val
"""
return _get_dtype("qint32", scale, None)
return get_quantized_dtype("qint32", scale, None)


def quint4(scale, zero_point):
@@ -117,7 +129,7 @@ def quint4(scale, zero_point):
``zero_point`` (uint8). The real value represented by a quint4 data type is
float_val = scale * (uint4_val - zero_point)
"""
return _get_dtype("quint4", scale, zero_point)
return get_quantized_dtype("quint4", scale, zero_point)


def qint4(scale):
@@ -125,17 +137,17 @@ def qint4(scale):
Construct a quantized int4 data type with ``scale`` (float). The real value
represented by a qint4 data type is float_val = scale * int4_val
"""
return _get_dtype("qint4", scale, None)
return get_quantized_dtype("qint4", scale, None)


def _convert_to_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str):
metadata = _metadata_dict[dtype_str]["mgb_dtype"]
def _convert_to_quantized_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str):
metadata = _metadata_dict[dtype_str]
arr_metadata = dtype.metadata["mgb_dtype"]
if not isinstance(arr, np.ndarray):
raise ValueError("arr parameter should be instance of np.ndarray")
if not is_quantize(dtype) or arr_metadata["name"] != metadata["name"]:
if not is_quantize(dtype) or arr_metadata["name"] != metadata.name:
raise ValueError("dtype parameter should be a {} dtype".format(dtype_str))
is_unsigned = _metadata_dict[dtype_str]["is_unsigned"]
is_unsigned = metadata.is_unsigned
if is_unsigned:
scale, zp = (
arr_metadata["scale"],
@@ -143,25 +155,23 @@ def _convert_to_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str):
)
return (
(np.round(arr / scale) + zp)
.clip(metadata["qmin"], metadata["qmax"])
.clip(metadata.qmin, metadata.qmax)
.astype(dtype)
)
else:
# don't trick to combine with is_unsigned for consistency with cpp interface
# don't trick to combine with is_unsigned, seeing ``get_quantized_dtype``
scale = arr_metadata["scale"]
return (
np.round(arr / scale).clip(metadata["qmin"], metadata["qmax"]).astype(dtype)
)
return np.round(arr / scale).clip(metadata.qmin, metadata.qmax).astype(dtype)


def _convert_from_dtype(arr: np.ndarray, dtype_str: str):
metadata = _metadata_dict[dtype_str]["mgb_dtype"]
def _convert_from_quantized_dtype(arr: np.ndarray, dtype_str: str):
metadata = _metadata_dict[dtype_str]
arr_metadata = arr.dtype.metadata["mgb_dtype"]
if not isinstance(arr, np.ndarray):
raise ValueError("arr parameter should be instance of np.ndarray")
if not is_quantize(arr.dtype) or arr_metadata["name"] != metadata["name"]:
if not is_quantize(arr.dtype) or arr_metadata["name"] != metadata.name:
raise ValueError("arr's dtype should be a {} dtype".format(dtype_str))
is_unsigned = _metadata_dict[dtype_str]["is_unsigned"]
is_unsigned = metadata.is_unsigned
if is_unsigned:
scale, zp = (
arr_metadata["scale"],
@@ -169,7 +179,7 @@ def _convert_from_dtype(arr: np.ndarray, dtype_str: str):
)
return (arr.astype(np.float32) - zp) * scale
else:
# don't trick to combine with is_unsigned for consistency with cpp interface
# don't trick to combine with is_unsigned, seeing ``get_quantized_dtype``
scale = arr_metadata["scale"]
return (arr.astype(np.float32)) * scale

@@ -181,7 +191,7 @@ def convert_to_quint8(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a quint8.
"""
return _convert_to_dtype(arr, q, "quint8")
return _convert_to_quantized_dtype(arr, q, "quint8")


def convert_from_quint8(arr: np.ndarray):
@@ -190,7 +200,7 @@ def convert_from_quint8(arr: np.ndarray):

:param arr: Input ndarray.
"""
return _convert_from_dtype(arr, "quint8")
return _convert_from_quantized_dtype(arr, "quint8")


def convert_to_qint8(arr: np.ndarray, q: np.dtype):
@@ -200,7 +210,7 @@ def convert_to_qint8(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a qint8.
"""
return _convert_to_dtype(arr, q, "qint8")
return _convert_to_quantized_dtype(arr, q, "qint8")


def convert_from_qint8(arr: np.ndarray):
@@ -209,7 +219,7 @@ def convert_from_qint8(arr: np.ndarray):

:param arr: Input ndarray.
"""
return _convert_from_dtype(arr, "qint8")
return _convert_from_quantized_dtype(arr, "qint8")


def convert_to_qint32(arr: np.ndarray, q: np.dtype):
@@ -219,7 +229,7 @@ def convert_to_qint32(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a qint8.
"""
return _convert_to_dtype(arr, q, "qint32")
return _convert_to_quantized_dtype(arr, q, "qint32")


def convert_from_qint32(arr):
@@ -228,7 +238,7 @@ def convert_from_qint32(arr):

:param arr: Input ndarray.
"""
return _convert_from_dtype(arr, "qint32")
return _convert_from_quantized_dtype(arr, "qint32")


def convert_to_quint4(arr: np.ndarray, q: np.dtype):
@@ -238,7 +248,7 @@ def convert_to_quint4(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a quint4.
"""
return _convert_to_dtype(arr, q, "quint4")
return _convert_to_quantized_dtype(arr, q, "quint4")


def convert_from_quint4(arr: np.ndarray):
@@ -247,7 +257,7 @@ def convert_from_quint4(arr: np.ndarray):

:param arr: Input ndarray.
"""
return _convert_from_dtype(arr, "quint4")
return _convert_from_quantized_dtype(arr, "quint4")


def convert_to_qint4(arr: np.ndarray, q: np.dtype):
@@ -257,7 +267,7 @@ def convert_to_qint4(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a qint4.
"""
return _convert_to_dtype(arr, q, "qint4")
return _convert_to_quantized_dtype(arr, q, "qint4")


def convert_from_qint4(arr: np.ndarray):
@@ -266,4 +276,4 @@ def convert_from_qint4(arr: np.ndarray):

:param arr: Input ndarray.
"""
return _convert_from_dtype(arr, "qint4")
return _convert_from_quantized_dtype(arr, "qint4")

Loading…
Cancel
Save