GitOrigin-RevId: df3416fe13
tags/v1.3.0
@@ -6,9 +6,6 @@ | |||||
# Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import warnings | |||||
from typing import Union | |||||
from ..._imperative_rt import OpDef, ops | from ..._imperative_rt import OpDef, ops | ||||
__all__ = ["OpDef"] | __all__ = ["OpDef"] | ||||
@@ -5,22 +5,24 @@ | |||||
# Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import collections | |||||
from collections import namedtuple | |||||
from typing import Union | from typing import Union | ||||
import numpy as np | import numpy as np | ||||
# normal dtype related | |||||
from .._imperative_rt import bfloat16, intb1, intb2, intb4 | |||||
from .._imperative_rt.common import ( | from .._imperative_rt.common import ( | ||||
bfloat16, | |||||
get_scale, | get_scale, | ||||
get_zero_point, | get_zero_point, | ||||
intb1, | |||||
intb2, | |||||
intb4, | |||||
is_dtype_equal, | is_dtype_equal, | ||||
is_quantize, | is_quantize, | ||||
) | ) | ||||
# normal dtype related | |||||
def is_lowbit(dtype): | def is_lowbit(dtype): | ||||
return (dtype is intb1) or (dtype is intb2) or (dtype is intb4) | return (dtype is intb1) or (dtype is intb2) or (dtype is intb4) | ||||
@@ -30,34 +32,80 @@ def is_bfloat16(dtype): | |||||
# quantization dtype related | # quantization dtype related | ||||
_QuantDtypeMetadata = collections.namedtuple( | |||||
"QuantDtypeMetadata", ["name", "np_dtype_str", "is_unsigned", "qmin", "qmax",] | |||||
) | |||||
_metadata_dict = { | |||||
"quint8": _QuantDtypeMetadata("Quantized8Asymm", "uint8", True, 0, 255), | |||||
"qint8": _QuantDtypeMetadata("QuantizedS8", "int8", False, -128, 127), | |||||
"quint4": _QuantDtypeMetadata("Quantized4Asymm", "uint8", True, 0, 15), | |||||
"qint4": _QuantDtypeMetadata("QuantizedS4", "int8", False, -8, 7), | |||||
"qint32": _QuantDtypeMetadata( | |||||
"QuantizedS32", "int32", False, -(2 ** 31), 2 ** 31 - 1, | |||||
# use namedtuple to make class immutable, comparable and easy to print | |||||
class QuantDtypeMeta( | |||||
namedtuple( | |||||
"QuantDtypeMeta", | |||||
["name", "cname", "np_dtype_str", "qmin", "qmax", "is_unsigned"], | |||||
) | |||||
): | |||||
r""" | |||||
Store metadata for quantize dtype. Could be used to create custom quant dtype | |||||
for QAT when the network don't need to be converted for inference, but only | |||||
to export network metadata for third-party platform inference. | |||||
:param name: a unique name string. | |||||
:param cname: used in :func:`~.create_quantized_dtype` for model dump and inference. | |||||
:param np_dtype_str: used in :func:`~.create_quantized_dtype` to generate ``np.dtype``. | |||||
:param qmin: a int number indicating quant dtype's lowerbound. | |||||
:param qmax: a int number indicating quant dtype's upperbound. | |||||
:param is_unsigned: a helper value that could be inference from np_dtype_str. | |||||
""" | |||||
def __new__( | |||||
cls, | |||||
name: str, | |||||
cname: str, | |||||
np_dtype_str: str, | |||||
qmin: int, | |||||
qmax: int, | |||||
is_unsigned: bool = None, | |||||
): | |||||
assert isinstance(np_dtype_str, str) | |||||
is_unsigned = np_dtype_str[0] == "u" if is_unsigned is None else is_unsigned | |||||
return super().__new__(cls, name, cname, np_dtype_str, qmin, qmax, is_unsigned) | |||||
def __copy__(self): | |||||
return self | |||||
def __deepcopy__(self, _): | |||||
""" | |||||
Ignore deepcopy so that a dtype meta can be treated as singleton, for more | |||||
strict check in :meth:`~.FakeQuantize.fake_quant_forward`. | |||||
""" | |||||
return self | |||||
_builtin_quant_dtypes = { | |||||
"quint8": QuantDtypeMeta("quint8", "Quantized8Asymm", "uint8", 0, 255), | |||||
"qint8": QuantDtypeMeta("qint8", "QuantizedS8", "int8", -128, 127), | |||||
"qint8_narrow": QuantDtypeMeta("qint8_narrow", "QuantizedS8", "int8", -127, 127), | |||||
"quint4": QuantDtypeMeta("quint4", "Quantized4Asymm", "uint8", 0, 15), | |||||
"qint4": QuantDtypeMeta("qint4", "QuantizedS4", "int8", -8, 7), | |||||
"qint32": QuantDtypeMeta( | |||||
"qint32", "QuantizedS32", "int32", -(2 ** 31), 2 ** 31 - 1, | |||||
), | ), | ||||
# NOTE: int2 is not supported for model dump yet | # NOTE: int2 is not supported for model dump yet | ||||
"quint2": _QuantDtypeMetadata(None, "uint8", True, 0, 3), | |||||
"qint2": _QuantDtypeMetadata(None, "int8", False, -2, 1), | |||||
"quint2": QuantDtypeMeta("quint2", None, "uint8", 0, 3), | |||||
"qint2": QuantDtypeMeta("qint2", None, "int8", -2, 1), | |||||
} | } | ||||
def _check_zero_point(zp: int, dtype_str: str): | |||||
qmin = _metadata_dict[dtype_str].qmin | |||||
qmax = _metadata_dict[dtype_str].qmax | |||||
def _check_zero_point(zp: int, dtype_meta: QuantDtypeMeta): | |||||
qmin = dtype_meta.qmin | |||||
qmax = dtype_meta.qmax | |||||
if zp < qmin or zp > qmax: | if zp < qmin or zp > qmax: | ||||
raise ValueError( | raise ValueError( | ||||
"zero_point should be within [{}, {}] for {}".format(qmin, qmax, dtype_str) | |||||
"zero_point should be within [{}, {}] for {}".format( | |||||
qmin, qmax, dtype_meta.name | |||||
) | |||||
) | ) | ||||
def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]): | |||||
def create_quantized_dtype( | |||||
dtype_meta: QuantDtypeMeta, scale: float, zp: Union[int, None] | |||||
): | |||||
r""" | r""" | ||||
Get quantized dtype with metadata attribute according to _metadata_dict. | Get quantized dtype with metadata attribute according to _metadata_dict. | ||||
@@ -65,32 +113,34 @@ def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]): | |||||
not have ``zero_point``, to be consitent with tensor generated by calling | not have ``zero_point``, to be consitent with tensor generated by calling | ||||
compiled function from `CompGraph.compile(inputs, outspec)`. | compiled function from `CompGraph.compile(inputs, outspec)`. | ||||
:param dtype: a string indicating which dtype to return | |||||
:param dtype_meta: a QuantDtypeMeta indicating which dtype to return. the | |||||
``cname`` attribute cannot be ``None``. | |||||
:param scale: a number for scale to store in dtype's metadata | :param scale: a number for scale to store in dtype's metadata | ||||
:param zp: a number for zero_point to store in dtype's metadata | :param zp: a number for zero_point to store in dtype's metadata | ||||
""" | """ | ||||
metadata = _metadata_dict[dtype_str] | |||||
np_dtype_str = metadata.np_dtype_str | |||||
is_unsigned = metadata.is_unsigned | |||||
if is_unsigned: | |||||
if dtype_meta.cname is None: | |||||
raise ValueError("dtype {} without cname attr is not supported.") | |||||
if dtype_meta.is_unsigned: | |||||
if zp is None or int(zp) != zp: | if zp is None or int(zp) != zp: | ||||
raise ValueError("zero_point should be an integer") | raise ValueError("zero_point should be an integer") | ||||
zp = int(zp) | zp = int(zp) | ||||
_check_zero_point(zp, dtype_str) | |||||
_check_zero_point(zp, dtype_meta) | |||||
return np.dtype( | return np.dtype( | ||||
np_dtype_str, | |||||
dtype_meta.np_dtype_str, | |||||
metadata={ | metadata={ | ||||
"mgb_dtype": { | "mgb_dtype": { | ||||
"name": metadata.name, | |||||
"name": dtype_meta.cname, | |||||
"scale": float(scale), | "scale": float(scale), | ||||
"zero_point": zp, | "zero_point": zp, | ||||
} | } | ||||
}, | }, | ||||
) | ) | ||||
else: | else: | ||||
# Don't trick to combine with is_unsigned. Metadata should not contain | |||||
# invalid field to keep consistent with c dtype. | |||||
return np.dtype( | return np.dtype( | ||||
np_dtype_str, | |||||
metadata={"mgb_dtype": {"name": metadata.name, "scale": float(scale)}}, | |||||
dtype_meta.np_dtype_str, | |||||
metadata={"mgb_dtype": {"name": dtype_meta.cname, "scale": float(scale)}}, | |||||
) | ) | ||||
@@ -100,7 +150,7 @@ def quint8(scale, zero_point): | |||||
``zero_point`` (uint8). The real value represented by a quint8 data type is | ``zero_point`` (uint8). The real value represented by a quint8 data type is | ||||
float_val = scale * (uint8_val - zero_point) | float_val = scale * (uint8_val - zero_point) | ||||
""" | """ | ||||
return get_quantized_dtype("quint8", scale, zero_point) | |||||
return create_quantized_dtype(_builtin_quant_dtypes["quint8"], scale, zero_point) | |||||
def qint8(scale): | def qint8(scale): | ||||
@@ -108,7 +158,7 @@ def qint8(scale): | |||||
Construct a quantized int8 data type with ``scale`` (float). The real value | Construct a quantized int8 data type with ``scale`` (float). The real value | ||||
represented by a qint8 data type is float_val = scale * int8_val | represented by a qint8 data type is float_val = scale * int8_val | ||||
""" | """ | ||||
return get_quantized_dtype("qint8", scale, None) | |||||
return create_quantized_dtype(_builtin_quant_dtypes["qint8"], scale, None) | |||||
def qint32(scale): | def qint32(scale): | ||||
@@ -116,7 +166,7 @@ def qint32(scale): | |||||
Construct a quantized int32 data type with ``scale`` (float). The real value | Construct a quantized int32 data type with ``scale`` (float). The real value | ||||
represented by a qint32 data type is float_val = scale * int32_val | represented by a qint32 data type is float_val = scale * int32_val | ||||
""" | """ | ||||
return get_quantized_dtype("qint32", scale, None) | |||||
return create_quantized_dtype(_builtin_quant_dtypes["qint32"], scale, None) | |||||
def quint4(scale, zero_point): | def quint4(scale, zero_point): | ||||
@@ -125,7 +175,7 @@ def quint4(scale, zero_point): | |||||
``zero_point`` (uint8). The real value represented by a quint4 data type is | ``zero_point`` (uint8). The real value represented by a quint4 data type is | ||||
float_val = scale * (uint4_val - zero_point) | float_val = scale * (uint4_val - zero_point) | ||||
""" | """ | ||||
return get_quantized_dtype("quint4", scale, zero_point) | |||||
return create_quantized_dtype(_builtin_quant_dtypes["quint4"], scale, zero_point) | |||||
def qint4(scale): | def qint4(scale): | ||||
@@ -133,42 +183,48 @@ def qint4(scale): | |||||
Construct a quantized int4 data type with ``scale`` (float). The real value | Construct a quantized int4 data type with ``scale`` (float). The real value | ||||
represented by a qint4 data type is float_val = scale * int4_val | represented by a qint4 data type is float_val = scale * int4_val | ||||
""" | """ | ||||
return get_quantized_dtype("qint4", scale, None) | |||||
return create_quantized_dtype(_builtin_quant_dtypes["qint4"], scale, None) | |||||
def _convert_to_quantized_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str): | |||||
metadata = _metadata_dict[dtype_str] | |||||
arr_metadata = dtype.metadata["mgb_dtype"] | |||||
def _convert_to_quantized_dtype( | |||||
arr: np.ndarray, dtype: np.dtype, dtype_meta: QuantDtypeMeta | |||||
): | |||||
if not isinstance(arr, np.ndarray): | if not isinstance(arr, np.ndarray): | ||||
raise ValueError("arr parameter should be instance of np.ndarray") | raise ValueError("arr parameter should be instance of np.ndarray") | ||||
if not is_quantize(dtype) or arr_metadata["name"] != metadata.name: | |||||
raise ValueError("dtype parameter should be a {} dtype".format(dtype_str)) | |||||
is_unsigned = metadata.is_unsigned | |||||
if is_unsigned: | |||||
if ( | |||||
not is_quantize(dtype) | |||||
or dtype.metadata["mgb_dtype"]["name"] != dtype_meta.cname | |||||
): | |||||
raise ValueError("dtype parameter should be a {} dtype".format(dtype_meta)) | |||||
arr_metadata = dtype.metadata["mgb_dtype"] | |||||
if dtype_meta.is_unsigned: | |||||
scale, zp = ( | scale, zp = ( | ||||
arr_metadata["scale"], | arr_metadata["scale"], | ||||
arr_metadata["zero_point"], | arr_metadata["zero_point"], | ||||
) | ) | ||||
return ( | return ( | ||||
(np.round(arr / scale) + zp) | (np.round(arr / scale) + zp) | ||||
.clip(metadata.qmin, metadata.qmax) | |||||
.clip(dtype_meta.qmin, dtype_meta.qmax) | |||||
.astype(dtype) | .astype(dtype) | ||||
) | ) | ||||
else: | else: | ||||
# don't trick to combine with is_unsigned, seeing ``get_quantized_dtype`` | # don't trick to combine with is_unsigned, seeing ``get_quantized_dtype`` | ||||
scale = arr_metadata["scale"] | scale = arr_metadata["scale"] | ||||
return np.round(arr / scale).clip(metadata.qmin, metadata.qmax).astype(dtype) | |||||
return ( | |||||
np.round(arr / scale).clip(dtype_meta.qmin, dtype_meta.qmax).astype(dtype) | |||||
) | |||||
def _convert_from_quantized_dtype(arr: np.ndarray, dtype_str: str): | |||||
metadata = _metadata_dict[dtype_str] | |||||
arr_metadata = arr.dtype.metadata["mgb_dtype"] | |||||
def _convert_from_quantized_dtype(arr: np.ndarray, dtype_meta: QuantDtypeMeta): | |||||
if not isinstance(arr, np.ndarray): | if not isinstance(arr, np.ndarray): | ||||
raise ValueError("arr parameter should be instance of np.ndarray") | raise ValueError("arr parameter should be instance of np.ndarray") | ||||
if not is_quantize(arr.dtype) or arr_metadata["name"] != metadata.name: | |||||
raise ValueError("arr's dtype should be a {} dtype".format(dtype_str)) | |||||
is_unsigned = metadata.is_unsigned | |||||
if is_unsigned: | |||||
if ( | |||||
not is_quantize(arr.dtype) | |||||
or arr.dtype.metadata["mgb_dtype"]["name"] != dtype_meta.cname | |||||
): | |||||
raise ValueError("arr's dtype should be a {} dtype".format(dtype_meta)) | |||||
arr_metadata = arr.dtype.metadata["mgb_dtype"] | |||||
if dtype_meta.is_unsigned: | |||||
scale, zp = ( | scale, zp = ( | ||||
arr_metadata["scale"], | arr_metadata["scale"], | ||||
arr_metadata["zero_point"], | arr_metadata["zero_point"], | ||||
@@ -187,7 +243,7 @@ def convert_to_quint8(arr: np.ndarray, q: np.dtype): | |||||
:param arr: Input ndarray. | :param arr: Input ndarray. | ||||
:param q: Target data type, should be a quint8. | :param q: Target data type, should be a quint8. | ||||
""" | """ | ||||
return _convert_to_quantized_dtype(arr, q, "quint8") | |||||
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["quint8"]) | |||||
def convert_from_quint8(arr: np.ndarray): | def convert_from_quint8(arr: np.ndarray): | ||||
@@ -196,7 +252,7 @@ def convert_from_quint8(arr: np.ndarray): | |||||
:param arr: Input ndarray. | :param arr: Input ndarray. | ||||
""" | """ | ||||
return _convert_from_quantized_dtype(arr, "quint8") | |||||
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["quint8"]) | |||||
def convert_to_qint8(arr: np.ndarray, q: np.dtype): | def convert_to_qint8(arr: np.ndarray, q: np.dtype): | ||||
@@ -206,7 +262,7 @@ def convert_to_qint8(arr: np.ndarray, q: np.dtype): | |||||
:param arr: Input ndarray. | :param arr: Input ndarray. | ||||
:param q: Target data type, should be a qint8. | :param q: Target data type, should be a qint8. | ||||
""" | """ | ||||
return _convert_to_quantized_dtype(arr, q, "qint8") | |||||
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint8"]) | |||||
def convert_from_qint8(arr: np.ndarray): | def convert_from_qint8(arr: np.ndarray): | ||||
@@ -215,7 +271,7 @@ def convert_from_qint8(arr: np.ndarray): | |||||
:param arr: Input ndarray. | :param arr: Input ndarray. | ||||
""" | """ | ||||
return _convert_from_quantized_dtype(arr, "qint8") | |||||
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint8"]) | |||||
def convert_to_qint32(arr: np.ndarray, q: np.dtype): | def convert_to_qint32(arr: np.ndarray, q: np.dtype): | ||||
@@ -225,7 +281,7 @@ def convert_to_qint32(arr: np.ndarray, q: np.dtype): | |||||
:param arr: Input ndarray. | :param arr: Input ndarray. | ||||
:param q: Target data type, should be a qint8. | :param q: Target data type, should be a qint8. | ||||
""" | """ | ||||
return _convert_to_quantized_dtype(arr, q, "qint32") | |||||
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint32"]) | |||||
def convert_from_qint32(arr): | def convert_from_qint32(arr): | ||||
@@ -234,7 +290,7 @@ def convert_from_qint32(arr): | |||||
:param arr: Input ndarray. | :param arr: Input ndarray. | ||||
""" | """ | ||||
return _convert_from_quantized_dtype(arr, "qint32") | |||||
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint32"]) | |||||
def convert_to_quint4(arr: np.ndarray, q: np.dtype): | def convert_to_quint4(arr: np.ndarray, q: np.dtype): | ||||
@@ -244,7 +300,7 @@ def convert_to_quint4(arr: np.ndarray, q: np.dtype): | |||||
:param arr: Input ndarray. | :param arr: Input ndarray. | ||||
:param q: Target data type, should be a quint4. | :param q: Target data type, should be a quint4. | ||||
""" | """ | ||||
return _convert_to_quantized_dtype(arr, q, "quint4") | |||||
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["quint4"]) | |||||
def convert_from_quint4(arr: np.ndarray): | def convert_from_quint4(arr: np.ndarray): | ||||
@@ -253,7 +309,7 @@ def convert_from_quint4(arr: np.ndarray): | |||||
:param arr: Input ndarray. | :param arr: Input ndarray. | ||||
""" | """ | ||||
return _convert_from_quantized_dtype(arr, "quint4") | |||||
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["quint4"]) | |||||
def convert_to_qint4(arr: np.ndarray, q: np.dtype): | def convert_to_qint4(arr: np.ndarray, q: np.dtype): | ||||
@@ -263,7 +319,7 @@ def convert_to_qint4(arr: np.ndarray, q: np.dtype): | |||||
:param arr: Input ndarray. | :param arr: Input ndarray. | ||||
:param q: Target data type, should be a qint4. | :param q: Target data type, should be a qint4. | ||||
""" | """ | ||||
return _convert_to_quantized_dtype(arr, q, "qint4") | |||||
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint4"]) | |||||
def convert_from_qint4(arr: np.ndarray): | def convert_from_qint4(arr: np.ndarray): | ||||
@@ -272,4 +328,4 @@ def convert_from_qint4(arr: np.ndarray): | |||||
:param arr: Input ndarray. | :param arr: Input ndarray. | ||||
""" | """ | ||||
return _convert_from_quantized_dtype(arr, "qint4") | |||||
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint4"]) |
@@ -203,7 +203,7 @@ def conv_transpose2d( | |||||
assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT" | assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT" | ||||
if groups != 1: | if groups != 1: | ||||
raise NotImplementedError("TODO") | |||||
raise NotImplementedError("group transposed conv2d is not supported yet.") | |||||
stride_h, stride_w = expand_hw(stride) | stride_h, stride_w = expand_hw(stride) | ||||
pad_h, pad_w = expand_hw(padding) | pad_h, pad_w = expand_hw(padding) | ||||
@@ -13,7 +13,6 @@ import itertools | |||||
import json | import json | ||||
import os | import os | ||||
import typing | import typing | ||||
import warnings | |||||
import weakref | import weakref | ||||
import numpy as np | import numpy as np | ||||
@@ -5,7 +5,6 @@ | |||||
# Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import warnings | |||||
from abc import ABCMeta, abstractmethod | from abc import ABCMeta, abstractmethod | ||||
from collections import OrderedDict | from collections import OrderedDict | ||||
from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union | from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union | ||||
@@ -204,10 +203,9 @@ class Module(metaclass=ABCMeta): | |||||
if "requires_grad" in kwargs: | if "requires_grad" in kwargs: | ||||
del kwargs["requires_grad"] | del kwargs["requires_grad"] | ||||
warnings.warn( | |||||
logger.warning( | |||||
"Tensor currently has no requires_grad attribute " | "Tensor currently has no requires_grad attribute " | ||||
"so requires_grad argument is ignored here", | |||||
DeprecationWarning, | |||||
"so requires_grad argument is ignored here" | |||||
) | ) | ||||
def predicate(obj) -> bool: | def predicate(obj) -> bool: | ||||
@@ -232,10 +230,9 @@ class Module(metaclass=ABCMeta): | |||||
if "requires_grad" in kwargs: | if "requires_grad" in kwargs: | ||||
del kwargs["requires_grad"] | del kwargs["requires_grad"] | ||||
warnings.warn( | |||||
logger.warning( | |||||
"Tensor currently has no requires_grad attribute " | "Tensor currently has no requires_grad attribute " | ||||
"so requires_grad argument is ignored here", | |||||
DeprecationWarning, | |||||
"so requires_grad argument is ignored here" | |||||
) | ) | ||||
def predicate(obj) -> bool: | def predicate(obj) -> bool: | ||||
@@ -7,7 +7,10 @@ | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
from abc import abstractmethod | from abc import abstractmethod | ||||
from ...quantization import FakeQuantize, Observer, QConfig | |||||
# avoid circular reference | |||||
from ...quantization.fake_quant import FakeQuantize | |||||
from ...quantization.observer import Observer | |||||
from ...quantization.qconfig import QConfig | |||||
from ...tensor import Tensor | from ...tensor import Tensor | ||||
from ..module import Module | from ..module import Module | ||||
@@ -73,19 +76,19 @@ class QATModule(Module): | |||||
# do observer | # do observer | ||||
if observer is None: | if observer is None: | ||||
oup = target | oup = target | ||||
q_dict = None | |||||
qparams = None | |||||
else: | else: | ||||
oup = observer(target) | oup = observer(target) | ||||
q_dict = observer.get_qparams() | |||||
qparams = observer.get_qparams() | |||||
# do fake quant | # do fake quant | ||||
if fake_quant is not None: | if fake_quant is not None: | ||||
oup = fake_quant(oup, q_dict) | |||||
oup = fake_quant(oup, qparams) | |||||
# use qparams of fake_quant if have. | # use qparams of fake_quant if have. | ||||
if hasattr(fake_quant, "get_qparams"): | if hasattr(fake_quant, "get_qparams"): | ||||
q_dict = fake_quant.get_qparams() | |||||
qparams = fake_quant.get_qparams() | |||||
# set to tensor qparams. | # set to tensor qparams. | ||||
if q_dict is not None: | |||||
oup.q_dict.update(q_dict) | |||||
if qparams is not None: | |||||
oup.qparams.update(qparams) | |||||
return oup | return oup | ||||
def apply_quant_weight(self, target: Tensor): | def apply_quant_weight(self, target: Tensor): | ||||
@@ -118,7 +121,7 @@ class QATModule(Module): | |||||
Get weight's quantization dtype as the method from ``qconfig``. | Get weight's quantization dtype as the method from ``qconfig``. | ||||
""" | """ | ||||
return self._get_method_result( | return self._get_method_result( | ||||
"get_dtype", self.weight_fake_quant, self.weight_observer | |||||
"get_quantized_dtype", self.weight_fake_quant, self.weight_observer | |||||
) | ) | ||||
def get_activation_dtype(self): | def get_activation_dtype(self): | ||||
@@ -126,7 +129,7 @@ class QATModule(Module): | |||||
Get activation's quantization dtype as the method from ``qconfig``. | Get activation's quantization dtype as the method from ``qconfig``. | ||||
""" | """ | ||||
return self._get_method_result( | return self._get_method_result( | ||||
"get_dtype", self.act_fake_quant, self.act_observer | |||||
"get_quantized_dtype", self.act_fake_quant, self.act_observer | |||||
) | ) | ||||
def get_weight_qparams(self): | def get_weight_qparams(self): | ||||
@@ -7,8 +7,7 @@ | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
from .fake_quant import FakeQuantize | from .fake_quant import FakeQuantize | ||||
from .internal_fake_quant import * | |||||
from .observer import HistogramObserver, Observer | |||||
from .observer import Observer | |||||
from .qconfig import ( | from .qconfig import ( | ||||
QConfig, | QConfig, | ||||
calibration_qconfig, | calibration_qconfig, | ||||
@@ -20,4 +19,15 @@ from .qconfig import ( | |||||
sync_ema_fakequant_qconfig, | sync_ema_fakequant_qconfig, | ||||
tqt_qconfig, | tqt_qconfig, | ||||
) | ) | ||||
from .utils import QuantMode | |||||
from .quantize import ( | |||||
apply_easy_quant, | |||||
disable_fake_quant, | |||||
disable_observer, | |||||
enable_fake_quant, | |||||
enable_observer, | |||||
propagate_qconfig, | |||||
quantize, | |||||
quantize_qat, | |||||
reset_qconfig, | |||||
) | |||||
from .utils import QParams, QuantMode, create_qparams |
@@ -6,40 +6,48 @@ | |||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import math | import math | ||||
from typing import Union | |||||
from .. import functional as F | from .. import functional as F | ||||
from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype | |||||
from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes | |||||
from ..logger import get_logger | |||||
from ..module import Module | from ..module import Module | ||||
from ..tensor import Parameter, Tensor | |||||
from .utils import QuantMode, fake_quant_tensor, get_qparam_dict, tqt_forward | |||||
from ..tensor import Parameter | |||||
from .utils import ( | |||||
QParams, | |||||
QParamsModuleMixin, | |||||
QuantMode, | |||||
create_qparams, | |||||
fake_quant_tensor, | |||||
tqt_forward, | |||||
) | |||||
logger = get_logger(__name__) | |||||
class _FakeQuantize(Module): | |||||
r""" | |||||
A Basic Fake Quant module. | |||||
:param dtype: a string indicating the target quantization type of input. | |||||
:param narrow_range: whether the absolute value of ``qmin`` is the same as ``qmax``, | |||||
instead of 1 greater. Usually True for weight and False for activation. | |||||
:param enable: whether do ``normal_forward`` or ``fake_quant_forward``. | |||||
""" | |||||
class _FakeQuantize(Module): | |||||
def __init__( | def __init__( | ||||
self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs | |||||
self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs | |||||
): | ): | ||||
super().__init__() | super().__init__() | ||||
if not dtype in _metadata_dict.keys(): | |||||
raise ValueError( | |||||
"unknown dtype: {}, only support {}".format( | |||||
dtype, _metadata_dict.keys() | |||||
if isinstance(dtype, str): | |||||
if not dtype in _builtin_quant_dtypes: | |||||
raise ValueError( | |||||
"unknown dtype: {}, only support {}".format( | |||||
dtype, _builtin_quant_dtypes.keys() | |||||
) | |||||
) | ) | ||||
dtype = _builtin_quant_dtypes[dtype] | |||||
if "narrow_range" in kwargs: | |||||
del kwargs["narrow_range"] | |||||
logger.warning( | |||||
"FakeQuantize currently has no narrow_range param " | |||||
"so it is ignored here", | |||||
exc_info=DeprecationWarning, | |||||
) | ) | ||||
self.dtype = dtype | self.dtype = dtype | ||||
self.narrow_range = narrow_range | |||||
self.qmin = ( | |||||
-_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin | |||||
) | |||||
self.qmax = _metadata_dict[dtype].qmax | |||||
self.qmin = dtype.qmin | |||||
self.qmax = dtype.qmax | |||||
self.enabled = enable | self.enabled = enable | ||||
def enable(self): | def enable(self): | ||||
@@ -48,61 +56,64 @@ class _FakeQuantize(Module): | |||||
def disable(self): | def disable(self): | ||||
self.enabled = False | self.enabled = False | ||||
def fake_quant_forward(self, inp, q_dict=None): | |||||
return inp | |||||
def fake_quant_forward(self, inp, qparams: QParams = None): | |||||
raise NotImplementedError | |||||
def normal_foward(self, inp, q_dict=None): | |||||
def normal_foward(self, inp, qparams: QParams = None): | |||||
return inp | return inp | ||||
def forward(self, inp, q_dict=None): | |||||
def forward(self, inp, qparams: QParams = None): | |||||
if self.enabled: | if self.enabled: | ||||
return self.fake_quant_forward(inp, q_dict=q_dict) | |||||
return self.fake_quant_forward(inp, qparams=qparams) | |||||
else: | else: | ||||
return self.normal_foward(inp, q_dict=q_dict) | |||||
return self.normal_foward(inp, qparams=qparams) | |||||
class TQT(_FakeQuantize): | |||||
class TQT(_FakeQuantize, QParamsModuleMixin): | |||||
r""" | r""" | ||||
TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds | TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds | ||||
for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks. | for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks. | ||||
:param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target | |||||
quantization dtype of input. | |||||
:param enable: whether do ``normal_forward`` or ``fake_quant_forward``. | |||||
""" | """ | ||||
def __init__( | def __init__( | ||||
self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs | |||||
self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs | |||||
): | ): | ||||
super().__init__(dtype, narrow_range, enable, **kwargs) | |||||
super().__init__(dtype, enable, **kwargs) | |||||
self.scale = Parameter(0.0, dtype="float32") | self.scale = Parameter(0.0, dtype="float32") | ||||
def fake_quant_forward(self, inp, q_dict=None): | |||||
def fake_quant_forward(self, inp, qparams: QParams = None): | |||||
# when enable, TQT will do fakequant forward, finetune the scale | # when enable, TQT will do fakequant forward, finetune the scale | ||||
return tqt_forward(self.qmin, self.qmax, inp, self.scale) | return tqt_forward(self.qmin, self.qmax, inp, self.scale) | ||||
def get_qparams(self): | |||||
q_dict = get_qparam_dict(QuantMode.SYMMERTIC) | |||||
q_dict["scale"] = 2 ** self.scale.detach() | |||||
return q_dict | |||||
def set_qparams(self, q_dict): | |||||
def set_qparams(self, qparams: QParams): | |||||
assert ( | assert ( | ||||
q_dict["mode"] == QuantMode.SYMMERTIC | |||||
qparams.mode == QuantMode.SYMMERTIC | |||||
), "only symmetric quantization is supported by TQT" | ), "only symmetric quantization is supported by TQT" | ||||
if "scale" not in q_dict or q_dict["scale"] is None: | |||||
if qparams.scale is None: | |||||
raise AssertionError("Can not get an initialized scale") | raise AssertionError("Can not get an initialized scale") | ||||
self.scale._reset(F.log(q_dict["scale"]) / math.log(2)) | |||||
self.scale[...] = F.log(qparams.scale) / math.log(2) | |||||
def get_dtype(self): | |||||
q_dict = self.get_qparams() | |||||
scale = None if "scale" not in q_dict else q_dict["scale"].numpy() | |||||
zero_point = ( | |||||
None if "zero_point" not in q_dict else q_dict["zero_point"].numpy() | |||||
) | |||||
return get_quantized_dtype(self.dtype, scale, zero_point) | |||||
def get_qparams(self): | |||||
return create_qparams(QuantMode.SYMMERTIC, self.dtype, scale=2 ** self.scale) | |||||
class FakeQuantize(_FakeQuantize): | class FakeQuantize(_FakeQuantize): | ||||
r""" | r""" | ||||
A module to do quant and dequant according to observer's scale and zero_point. | A module to do quant and dequant according to observer's scale and zero_point. | ||||
:param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target | |||||
quantization dtype of input. | |||||
:param enable: whether do ``normal_forward`` or ``fake_quant_forward``. | |||||
""" | """ | ||||
def fake_quant_forward(self, inp, q_dict=None): | |||||
return fake_quant_tensor(inp, self.qmin, self.qmax, q_dict) | |||||
def fake_quant_forward(self, inp, qparams: QParams = None): | |||||
assert ( | |||||
qparams.dtype_meta is self.dtype | |||||
), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format( | |||||
qparams.dtype_meta, self.dtype | |||||
) | |||||
return fake_quant_tensor(inp, qparams) |
@@ -16,4 +16,6 @@ from ..autodiff import Function | |||||
from .fake_quant import _FakeQuantize | from .fake_quant import _FakeQuantize | ||||
from .observer import MinMaxObserver | from .observer import MinMaxObserver | ||||
from .qconfig import QConfig | from .qconfig import QConfig | ||||
from .utils import QParams | |||||
@@ -8,51 +8,51 @@ | |||||
import math | import math | ||||
from abc import abstractmethod | from abc import abstractmethod | ||||
from copy import deepcopy | from copy import deepcopy | ||||
from typing import Union | |||||
import numpy as np | import numpy as np | ||||
from .. import functional as F | from .. import functional as F | ||||
from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype | |||||
from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes | |||||
from ..distributed import WORLD, get_rank, is_distributed | from ..distributed import WORLD, get_rank, is_distributed | ||||
from ..functional.distributed import all_reduce_max, all_reduce_min | from ..functional.distributed import all_reduce_max, all_reduce_min | ||||
from ..logger import get_logger | |||||
from ..module import Module | from ..module import Module | ||||
from ..tensor import Tensor | from ..tensor import Tensor | ||||
from .utils import QuantMode, get_qparam_dict | |||||
from .utils import QParams, QParamsModuleMixin, QuantMode, create_qparams | |||||
logger = get_logger(__name__) | |||||
class Observer(Module): | |||||
class Observer(Module, QParamsModuleMixin): | |||||
r""" | r""" | ||||
A base class for Observer Module. | A base class for Observer Module. | ||||
:param dtype: a string indicating to collect scale and zero_point of which dtype. | :param dtype: a string indicating to collect scale and zero_point of which dtype. | ||||
:param narrow_range: whether the absolute value of ``qmin`` is the same as ``qmax``, | |||||
instead of 1 greater. Usually True for weight and False for activation. | |||||
""" | """ | ||||
def __init__(self, dtype: str, narrow_range: bool = False, **kwargs): | |||||
def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs): | |||||
super().__init__() | super().__init__() | ||||
if dtype not in _metadata_dict.keys(): | |||||
raise ValueError( | |||||
"unknown dtype: {}, only support {}".format( | |||||
dtype, _metadata_dict.keys() | |||||
if isinstance(dtype, str): | |||||
if not dtype in _builtin_quant_dtypes: | |||||
raise ValueError( | |||||
"unknown dtype: {}, only support {}".format( | |||||
dtype, _builtin_quant_dtypes.keys() | |||||
) | |||||
) | ) | ||||
dtype = _builtin_quant_dtypes[dtype] | |||||
if "narrow_range" in kwargs: | |||||
del kwargs["narrow_range"] | |||||
logger.warning( | |||||
"FakeQuantize currently has no narrow_range param " | |||||
"so it is ignored here", | |||||
exc_info=DeprecationWarning, | |||||
) | ) | ||||
self.dtype = dtype | self.dtype = dtype | ||||
self.narrow_range = narrow_range | |||||
self.qmin = ( | |||||
-_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin | |||||
) | |||||
self.qmax = _metadata_dict[dtype].qmax | |||||
self.qmin = dtype.qmin | |||||
self.qmax = dtype.qmax | |||||
self.enabled = True | self.enabled = True | ||||
def get_dtype(self): | |||||
q_dict = self.get_qparams() | |||||
numpy_scale = None if "scale" not in q_dict else q_dict["scale"].numpy() | |||||
numpy_zero_point = ( | |||||
None if "zero_point" not in q_dict else q_dict["zero_point"].numpy() | |||||
) | |||||
return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point) | |||||
def enable(self): | def enable(self): | ||||
self.enabled = True | self.enabled = True | ||||
@@ -70,21 +70,16 @@ class Observer(Module): | |||||
def forward(self, x): | def forward(self, x): | ||||
pass | pass | ||||
@abstractmethod | |||||
def get_qparams(self, **kwargs): | |||||
pass | |||||
class MinMaxObserver(Observer): | class MinMaxObserver(Observer): | ||||
def __init__( | def __init__( | ||||
self, | self, | ||||
mode=QuantMode.SYMMERTIC, | |||||
eps=0.00001, | |||||
dtype="qint8", | |||||
narrow_range: bool = False, | |||||
mode: QuantMode = QuantMode.SYMMERTIC, | |||||
eps: float = 0.00001, | |||||
dtype: Union[str, QuantDtypeMeta] = "qint8", | |||||
**kwargs | **kwargs | ||||
): | ): | ||||
super().__init__(dtype, narrow_range, **kwargs) | |||||
super().__init__(dtype, **kwargs) | |||||
self.mode = mode | self.mode = mode | ||||
self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32) | self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32) | ||||
self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32) | self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32) | ||||
@@ -93,26 +88,22 @@ class MinMaxObserver(Observer): | |||||
def _calculate_qparams(self, inp_min_val, inp_max_val): | def _calculate_qparams(self, inp_min_val, inp_max_val): | ||||
min_val = F.minimum(0.0, inp_min_val) | min_val = F.minimum(0.0, inp_min_val) | ||||
max_val = F.maximum(0.0, inp_max_val) | max_val = F.maximum(0.0, inp_max_val) | ||||
q_dict = get_qparam_dict(self.mode) | |||||
q_dict["min_val"] = inp_min_val | |||||
q_dict["max_val"] = inp_max_val | |||||
q_dict["enable_observer"] = self.enable | |||||
if self.mode == QuantMode.SYMMERTIC: | if self.mode == QuantMode.SYMMERTIC: | ||||
symmetric_max_vals = F.maximum(-min_val, max_val) | symmetric_max_vals = F.maximum(-min_val, max_val) | ||||
# use maximun to avoid scale too small at the begin | # use maximun to avoid scale too small at the begin | ||||
q_dict["scale"] = F.maximum( | |||||
scale = F.maximum( | |||||
symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit | symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit | ||||
) | ) | ||||
# zero_point = self.zero_point | |||||
zero_point = None | |||||
else: | else: | ||||
# use maximun to avoid scale too small at the begin | # use maximun to avoid scale too small at the begin | ||||
q_dict["scale"] = F.maximum( | |||||
scale = F.maximum( | |||||
(max_val - min_val) / (self.qmax - self.qmin), self.scale_limit | (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit | ||||
) | ) | ||||
# caculate zero_point | # caculate zero_point | ||||
q_dict["zero_point"] = self.qmin - F.round(min_val / q_dict["scale"]) | |||||
zero_point = self.qmin - F.round((min_val / scale)) | |||||
return q_dict | |||||
return create_qparams(self.mode, self.dtype, scale=scale, zero_point=zero_point) | |||||
def get_qparams(self): | def get_qparams(self): | ||||
return self._calculate_qparams(self.min_val, self.max_val) | return self._calculate_qparams(self.min_val, self.max_val) | ||||
@@ -122,8 +113,8 @@ class MinMaxObserver(Observer): | |||||
# stop gradient | # stop gradient | ||||
x = x_orig.detach() | x = x_orig.detach() | ||||
# find max and min | # find max and min | ||||
self.min_val._reset(F.minimum(self.min_val, x.min())) | |||||
self.max_val._reset(F.maximum(self.max_val, x.max())) | |||||
self.min_val[...] = F.minimum(self.min_val, x.min()) | |||||
self.max_val[...] = F.maximum(self.max_val, x.max()) | |||||
return x_orig | return x_orig | ||||
@@ -137,42 +128,43 @@ class SyncMinMaxObserver(MinMaxObserver): | |||||
else: | else: | ||||
min_x = x.min() | min_x = x.min() | ||||
max_x = x.max() | max_x = x.max() | ||||
self.min_val._reset(F.minimum(self.min_val, min_x)) | |||||
self.max_val._reset(F.maximum(self.max_val, max_x)) | |||||
self.min_val[...] = F.minimum(self.min_val, min_x) | |||||
self.max_val[...] = F.maximum(self.max_val, max_x) | |||||
return x_orig | return x_orig | ||||
class ExponentialMovingAverageObserver(MinMaxObserver): | class ExponentialMovingAverageObserver(MinMaxObserver): | ||||
def __init__( | def __init__( | ||||
self, | self, | ||||
momentum=0.9, | |||||
mode=QuantMode.SYMMERTIC, | |||||
eps=0.00001, | |||||
dtype="qint8", | |||||
narrow_range: bool = False, | |||||
momentum: float = 0.9, | |||||
mode: QuantMode = QuantMode.SYMMERTIC, | |||||
eps: float = 0.00001, | |||||
dtype: Union[str, QuantDtypeMeta] = "qint8", | |||||
**kwargs | **kwargs | ||||
): | ): | ||||
super().__init__(mode, eps, dtype, narrow_range, **kwargs) | |||||
super().__init__(mode, eps, dtype, **kwargs) | |||||
self.momentum = Tensor(momentum, dtype="float32") | self.momentum = Tensor(momentum, dtype="float32") | ||||
# used to avoid if-clauses in the first forward which is not supported | |||||
# in trace mode. | |||||
self.runtime_momentum = Tensor(0.0) | self.runtime_momentum = Tensor(0.0) | ||||
def set_momentum(self, momentum): | def set_momentum(self, momentum): | ||||
self.momentum = Tenosr(momentum, dtype="float32") | |||||
self.momentum = Tensor(momentum, dtype="float32") | |||||
def forward(self, x_orig): | def forward(self, x_orig): | ||||
if self.enabled: | if self.enabled: | ||||
# stop gradient | # stop gradient | ||||
x = x_orig.detach() | x = x_orig.detach() | ||||
# Exponential Moving Average | # Exponential Moving Average | ||||
self.min_val._reset( | |||||
self.min_val[...] = ( | |||||
self.min_val * self.runtime_momentum | self.min_val * self.runtime_momentum | ||||
+ (1 - self.runtime_momentum) * x.min() | + (1 - self.runtime_momentum) * x.min() | ||||
) | ) | ||||
self.max_val._reset( | |||||
self.max_val[...] = ( | |||||
self.max_val * self.runtime_momentum | self.max_val * self.runtime_momentum | ||||
+ (1 - self.runtime_momentum) * x.max() | + (1 - self.runtime_momentum) * x.max() | ||||
) | ) | ||||
self.runtime_momentum = self.momentum | |||||
self.runtime_momentum[...] = self.momentum | |||||
return x_orig | return x_orig | ||||
@@ -187,33 +179,34 @@ class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver): | |||||
else: | else: | ||||
min_x = x.min() | min_x = x.min() | ||||
max_x = x.max() | max_x = x.max() | ||||
self.min_val._reset( | |||||
self.min_val[...] = ( | |||||
self.min_val * self.runtime_momentum | self.min_val * self.runtime_momentum | ||||
+ (1 - self.runtime_momentum) * min_x | + (1 - self.runtime_momentum) * min_x | ||||
) | ) | ||||
self.max_val._reset( | |||||
self.max_val[...] = ( | |||||
self.max_val * self.runtime_momentum | self.max_val * self.runtime_momentum | ||||
+ (1 - self.runtime_momentum) * max_x | + (1 - self.runtime_momentum) * max_x | ||||
) | ) | ||||
self.runtime_momentum = self.momentum | |||||
self.runtime_momentum[...] = self.momentum | |||||
return x_orig | return x_orig | ||||
class HistogramObserver(MinMaxObserver): | class HistogramObserver(MinMaxObserver): | ||||
def __init__( | def __init__( | ||||
self, | self, | ||||
bins=2048, | |||||
upsample_rate=128, | |||||
mode=QuantMode.SYMMERTIC, | |||||
eps=0.00001, | |||||
dtype="qint8", | |||||
narrow_range: bool = False, | |||||
bins: int = 2048, | |||||
upsample_rate: int = 128, | |||||
mode: QuantMode = QuantMode.SYMMERTIC, | |||||
eps: float = 0.00001, | |||||
dtype: Union[str, QuantDtypeMeta] = "qint8", | |||||
**kwargs | **kwargs | ||||
): | ): | ||||
super().__init__(mode, eps, dtype, narrow_range, **kwargs) | |||||
super().__init__(mode, eps, dtype, **kwargs) | |||||
self.bins = bins | self.bins = bins | ||||
self.upsample_rate = upsample_rate | self.upsample_rate = upsample_rate | ||||
self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 | |||||
self.dst_nbins = ( | |||||
_builtin_quant_dtypes[dtype].qmax - _builtin_quant_dtypes[dtype].qmin + 1 | |||||
) | |||||
self.histogram = Tensor([-1] + [0.0] * (bins - 1), dtype="float32") | self.histogram = Tensor([-1] + [0.0] * (bins - 1), dtype="float32") | ||||
def _non_linear_param_search(self): | def _non_linear_param_search(self): | ||||
@@ -450,34 +443,45 @@ class HistogramObserver(MinMaxObserver): | |||||
class PassiveObserver(Observer): | class PassiveObserver(Observer): | ||||
r""" | r""" | ||||
This class can be set :attr:`scale` derectly. | |||||
An Observer that supports setting :attr:`scale` directly. | |||||
""" | """ | ||||
def __init__(self, dtype: str, narrow_range: bool = False, **kwargs): | |||||
super().__init__(dtype, narrow_range, **kwargs) | |||||
self.q_dict = None | |||||
def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs): | |||||
super().__init__(dtype, **kwargs) | |||||
self.qparams = None | |||||
self.orig_scale = None | self.orig_scale = None | ||||
@property | @property | ||||
def scale(self): | def scale(self): | ||||
return self.q_dict["scale"] | |||||
return self.qparams.scale | |||||
@scale.setter | @scale.setter | ||||
def scale(self, value): | |||||
assert value > 0 | |||||
self.q_dict["scale"][...] = Tensor(value) | |||||
def scale(self, value: np.ndarray): | |||||
assert np.all(value > 0) | |||||
self.qparams.scale[...] = Tensor(value) | |||||
def get_qparams(self): | def get_qparams(self): | ||||
return self.q_dict | |||||
return self.qparams | |||||
def set_qparams(self, q_dict): | |||||
self.q_dict = deepcopy(q_dict) | |||||
if "scale" not in q_dict or q_dict["scale"] is None: | |||||
def set_qparams(self, qparams: QParams): | |||||
""" | |||||
:param qparams: used to set initial scale. | |||||
""" | |||||
self.qparams = deepcopy(qparams) | |||||
if qparams.scale is None: | |||||
raise AssertionError("Can not get an initialized scale") | raise AssertionError("Can not get an initialized scale") | ||||
self.orig_scale = q_dict["scale"].numpy() | |||||
if qparams.dtype_meta is None: | |||||
qparams.dtype_meta = self.dtype | |||||
else: | |||||
assert ( | |||||
qparams.dtype_meta is self.dtype | |||||
), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format( | |||||
qparams.dtype_meta, self.dtype | |||||
) | |||||
self.orig_scale = qparams.scale.numpy() | |||||
def forward(self, x): | def forward(self, x): | ||||
r""" | r""" | ||||
Just return input because :attr:`q_dict` is set by :func:`~.apply_easy_quant`. | |||||
Just return input because :attr:`qparams` is set by :func:`~.apply_easy_quant`. | |||||
""" | """ | ||||
return x | return x |
@@ -5,6 +5,7 @@ | |||||
# Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
from collections import namedtuple | |||||
from functools import partial | from functools import partial | ||||
from ..module import Module | from ..module import Module | ||||
@@ -19,7 +20,13 @@ from .observer import ( | |||||
) | ) | ||||
class QConfig: | |||||
# use namedtuple to make class immutable, comparable and easy to print | |||||
class QConfig( | |||||
namedtuple( | |||||
"QConfig", | |||||
["weight_observer", "act_observer", "weight_fake_quant", "act_fake_quant"], | |||||
) | |||||
): | |||||
r""" | r""" | ||||
A config class indicating how to do quantize toward :class:`~.QATModule`'s | A config class indicating how to do quantize toward :class:`~.QATModule`'s | ||||
``activation`` and ``weight``. See :meth:`~.QATModule.set_qconfig` for detail usage. | ``activation`` and ``weight``. See :meth:`~.QATModule.set_qconfig` for detail usage. | ||||
@@ -37,90 +44,66 @@ class QConfig: | |||||
# Default EMA QConfig for QAT. | # Default EMA QConfig for QAT. | ||||
ema_fakequant_qconfig = QConfig( | ema_fakequant_qconfig = QConfig( | ||||
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||||
act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False), | |||||
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||||
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||||
weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"), | |||||
act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8"), | |||||
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"), | |||||
act_fake_quant=partial(FakeQuantize, dtype="qint8"), | |||||
) | ) | ||||
Each parameter is a ``class`` rather than an instance. And we recommand using ``functools.partial`` | Each parameter is a ``class`` rather than an instance. And we recommand using ``functools.partial`` | ||||
to add initialization parameters of the ``class``, so that don't need to provide parameters in | to add initialization parameters of the ``class``, so that don't need to provide parameters in | ||||
:meth:`~.QATModule.set_qconfig`. | :meth:`~.QATModule.set_qconfig`. | ||||
Usually we set ``narrow_range`` of weight related paramters to ``True`` and of activation related | |||||
parameters to ``False``. For the result of multiplication and addition as ``a * b + c * d``, if | |||||
four variables are all -128 of dtype ``qint8``, then the result will be ``2^15`` and cause overflow. | |||||
Weights are commonly calculated in this way, so needed to narrow the range. | |||||
Usually we choose narrow version dtype (like ``qint8_narrow``) for weight related | |||||
paramters and normal version for activation related ones. For the result of | |||||
multiplication and addition as ``a * b + c * d``, if four variables are all -128 of | |||||
dtype ``qint8``, then the result will be ``2^15`` and cause overflow. | |||||
Weights are commonly calculated in this way, so need to narrow qmin to -127. | |||||
""" | """ | ||||
def __init__( | |||||
self, weight_observer, act_observer, weight_fake_quant, act_fake_quant | |||||
): | |||||
def __new__(cls, weight_observer, act_observer, weight_fake_quant, act_fake_quant): | |||||
if isinstance(act_observer, Module) or isinstance(weight_observer, Module): | if isinstance(act_observer, Module) or isinstance(weight_observer, Module): | ||||
raise ValueError( | raise ValueError( | ||||
"QConfig must not receive observer instance, please pass observer" | "QConfig must not receive observer instance, please pass observer" | ||||
" class generator using `partial(Observer, ...)` instead. Use" | " class generator using `partial(Observer, ...)` instead. Use" | ||||
" partial(MyObserver, x=1) to override arguments to constructor if needed" | " partial(MyObserver, x=1) to override arguments to constructor if needed" | ||||
) | ) | ||||
self.weight_observer = weight_observer | |||||
self.act_observer = act_observer | |||||
self.weight_fake_quant = weight_fake_quant | |||||
self.act_fake_quant = act_fake_quant | |||||
def __eq__(self, other): | |||||
def eq(a, b): | |||||
if isinstance(a, partial) and isinstance(b, partial): | |||||
return all( | |||||
[a.func == b.func, a.args == b.args, a.keywords == b.keywords] | |||||
) | |||||
else: | |||||
return a == b | |||||
return ( | |||||
eq(self.weight_observer, other.weight_observer) | |||||
and eq(self.act_observer, other.act_observer) | |||||
and eq(self.weight_fake_quant, other.weight_fake_quant) | |||||
and eq(self.act_fake_quant, other.act_fake_quant) | |||||
return super().__new__( | |||||
cls, weight_observer, act_observer, weight_fake_quant, act_fake_quant | |||||
) | ) | ||||
min_max_fakequant_qconfig = QConfig( | min_max_fakequant_qconfig = QConfig( | ||||
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||||
act_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=False), | |||||
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||||
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||||
weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"), | |||||
act_observer=partial(MinMaxObserver, dtype="qint8"), | |||||
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"), | |||||
act_fake_quant=partial(FakeQuantize, dtype="qint8"), | |||||
) | ) | ||||
ema_fakequant_qconfig = QConfig( | ema_fakequant_qconfig = QConfig( | ||||
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||||
act_observer=partial( | |||||
ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False | |||||
), | |||||
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||||
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||||
weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"), | |||||
act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8"), | |||||
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"), | |||||
act_fake_quant=partial(FakeQuantize, dtype="qint8"), | |||||
) | ) | ||||
sync_ema_fakequant_qconfig = QConfig( | sync_ema_fakequant_qconfig = QConfig( | ||||
weight_observer=partial(SyncMinMaxObserver, dtype="qint8", narrow_range=True), | |||||
act_observer=partial( | |||||
SyncExponentialMovingAverageObserver, dtype="qint8", narrow_range=False | |||||
), | |||||
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||||
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||||
weight_observer=partial(SyncMinMaxObserver, dtype="qint8_narrow"), | |||||
act_observer=partial(SyncExponentialMovingAverageObserver, dtype="qint8"), | |||||
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"), | |||||
act_fake_quant=partial(FakeQuantize, dtype="qint8"), | |||||
) | ) | ||||
ema_lowbit_fakequant_qconfig = QConfig( | ema_lowbit_fakequant_qconfig = QConfig( | ||||
weight_observer=partial(MinMaxObserver, dtype="qint4", narrow_range=False), | |||||
act_observer=partial( | |||||
ExponentialMovingAverageObserver, dtype="qint4", narrow_range=False | |||||
), | |||||
weight_fake_quant=partial(FakeQuantize, dtype="qint4", narrow_range=False), | |||||
act_fake_quant=partial(FakeQuantize, dtype="qint4", narrow_range=False), | |||||
weight_observer=partial(MinMaxObserver, dtype="qint4"), | |||||
act_observer=partial(ExponentialMovingAverageObserver, dtype="qint4"), | |||||
weight_fake_quant=partial(FakeQuantize, dtype="qint4"), | |||||
act_fake_quant=partial(FakeQuantize, dtype="qint4"), | |||||
) | ) | ||||
calibration_qconfig = QConfig( | calibration_qconfig = QConfig( | ||||
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||||
act_observer=partial(HistogramObserver, dtype="qint8", narrow_range=False), | |||||
weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"), | |||||
act_observer=partial(HistogramObserver, dtype="qint8"), | |||||
weight_fake_quant=None, | weight_fake_quant=None, | ||||
act_fake_quant=None, | act_fake_quant=None, | ||||
) | ) | ||||
@@ -128,15 +111,15 @@ calibration_qconfig = QConfig( | |||||
tqt_qconfig = QConfig( | tqt_qconfig = QConfig( | ||||
weight_observer=None, | weight_observer=None, | ||||
act_observer=None, | act_observer=None, | ||||
weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True), | |||||
act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False), | |||||
weight_fake_quant=partial(TQT, dtype="qint8_narrow"), | |||||
act_fake_quant=partial(TQT, dtype="qint8"), | |||||
) | ) | ||||
passive_qconfig = QConfig( | passive_qconfig = QConfig( | ||||
weight_observer=partial(PassiveObserver, dtype="qint8", narrow_range=True), | |||||
act_observer=partial(PassiveObserver, dtype="qint8", narrow_range=False), | |||||
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||||
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||||
weight_observer=partial(PassiveObserver, dtype="qint8_narrow"), | |||||
act_observer=partial(PassiveObserver, dtype="qint8"), | |||||
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"), | |||||
act_fake_quant=partial(FakeQuantize, dtype="qint8"), | |||||
) | ) | ||||
easyquant_qconfig = passive_qconfig | easyquant_qconfig = passive_qconfig |
@@ -18,6 +18,7 @@ from ..module import qat as QAT | |||||
from ..module import quantized as Quantized | from ..module import quantized as Quantized | ||||
from ..module.qat import QATModule | from ..module.qat import QATModule | ||||
from ..module.quantized import QuantizedModule | from ..module.quantized import QuantizedModule | ||||
from ..tensor import Tensor | |||||
from .qconfig import QConfig, ema_fakequant_qconfig | from .qconfig import QConfig, ema_fakequant_qconfig | ||||
@@ -147,10 +148,10 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True): | |||||
if not inplace: | if not inplace: | ||||
module = deepcopy(module) | module = deepcopy(module) | ||||
def safe_call(func, q_dict): | |||||
def safe_call(func, qparams): | |||||
inst = func() if func is not None else None | inst = func() if func is not None else None | ||||
if inst is not None and getattr(inst, "set_qparams", None) is not None: | if inst is not None and getattr(inst, "set_qparams", None) is not None: | ||||
inst.set_qparams(q_dict) | |||||
inst.set_qparams(qparams) | |||||
return inst | return inst | ||||
def is_qat(mod: Module): | def is_qat(mod: Module): | ||||
@@ -158,13 +159,13 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True): | |||||
for m in list(module._flatten(predicate=is_qat)): | for m in list(module._flatten(predicate=is_qat)): | ||||
if m.with_weight: | if m.with_weight: | ||||
weight_q_dict = m.get_weight_qparams() | |||||
m.weight_observer = safe_call(qconfig.weight_observer, weight_q_dict) | |||||
m.weight_fake_quant = safe_call(qconfig.weight_fake_quant, weight_q_dict) | |||||
weight_params = m.get_weight_qparams() | |||||
m.weight_observer = safe_call(qconfig.weight_observer, weight_params) | |||||
m.weight_fake_quant = safe_call(qconfig.weight_fake_quant, weight_params) | |||||
if m.with_act: | if m.with_act: | ||||
act_q_dict = m.get_activation_qparams() | |||||
m.act_observer = safe_call(qconfig.act_observer, act_q_dict) | |||||
m.act_fake_quant = safe_call(qconfig.act_fake_quant, act_q_dict) | |||||
act_params = m.get_activation_qparams() | |||||
m.act_observer = safe_call(qconfig.act_observer, act_params) | |||||
m.act_fake_quant = safe_call(qconfig.act_fake_quant, act_params) | |||||
return module | return module | ||||
@@ -202,7 +203,9 @@ def hook_qat_module(module: Module, func: Callable): | |||||
return hooks | return hooks | ||||
def apply_easy_quant(module, data, start=0.8, stop=1.2, num=40): | |||||
def apply_easy_quant( | |||||
module: Module, data: Tensor, start: float = 0.8, stop: float = 1.2, num: int = 40 | |||||
): | |||||
r""" | r""" | ||||
Implementation of ``EasyQuant``: https://arxiv.org/pdf/2006.16669. | Implementation of ``EasyQuant``: https://arxiv.org/pdf/2006.16669. | ||||
Search for optimal scales. | Search for optimal scales. | ||||
@@ -5,9 +5,10 @@ | |||||
# Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import abc | |||||
from enum import Enum | from enum import Enum | ||||
from functools import partial, update_wrapper, wraps | from functools import partial, update_wrapper, wraps | ||||
from typing import Dict | |||||
from typing import Union | |||||
import numpy as np | import numpy as np | ||||
@@ -15,7 +16,11 @@ from .. import functional as F | |||||
from ..autodiff import Function | from ..autodiff import Function | ||||
from ..core._imperative_rt.core2 import apply | from ..core._imperative_rt.core2 import apply | ||||
from ..core.ops import builtin | from ..core.ops import builtin | ||||
from ..core.tensor.dtype import _metadata_dict | |||||
from ..core.tensor.dtype import ( | |||||
QuantDtypeMeta, | |||||
_builtin_quant_dtypes, | |||||
create_quantized_dtype, | |||||
) | |||||
from ..tensor import Tensor | from ..tensor import Tensor | ||||
@@ -61,37 +66,100 @@ class QuantMode(Enum): | |||||
ASYMMERTIC = 2 | ASYMMERTIC = 2 | ||||
qparam_dict = { | |||||
QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None}, | |||||
QuantMode.ASYMMERTIC: { | |||||
"mode": QuantMode.ASYMMERTIC, | |||||
"scale": None, | |||||
"zero_point": None, | |||||
}, | |||||
class QParams: | |||||
""" | |||||
To standardize FakeQuant, Observer and Tensor's qparams format. If custom | |||||
qparams is needed, inherit this class and add custom ``__slots__``. | |||||
""" | |||||
__slots__ = "mode", "dtype_meta", "scale", "zero_point" | |||||
def __init__( | |||||
self, | |||||
mode: QuantMode, | |||||
dtype_meta: QuantDtypeMeta, | |||||
scale: Tensor, | |||||
zero_point: Tensor, | |||||
): | |||||
self.mode = mode | |||||
self.dtype_meta = dtype_meta | |||||
self.scale = scale | |||||
self.zero_point = zero_point | |||||
def update(self, qparams: "QParams"): | |||||
for key in self.__slots__: | |||||
setattr(self, key, getattr(qparams, key)) | |||||
def __eq__(self, other): | |||||
if len(self.__slots__) != len(other.__slots__): | |||||
return False | |||||
for key in self.__slots__: | |||||
if not hasattr(other, key) or getattr(self, key) != getattr(other, key): | |||||
return False | |||||
return True | |||||
def __repr__(self): | |||||
content = ", ".join( | |||||
["{}={}".format(key, getattr(self, key)) for key in self.__slots__] | |||||
) | |||||
return "QParams({})".format(content) | |||||
class QParamsModuleMixin(abc.ABC): | |||||
def get_quantized_dtype(self): | |||||
qparams = self.get_qparams() | |||||
dtype = qparams.dtype_meta | |||||
scale = float(qparams.scale.numpy()) if qparams.scale is not None else None | |||||
zero_point = ( | |||||
int(qparams.zero_point.numpy()) if qparams.zero_point is not None else None | |||||
) | |||||
return create_quantized_dtype(dtype, scale, zero_point) | |||||
@abc.abstractmethod | |||||
def get_qparams(self) -> QParams: | |||||
pass | |||||
_builtin_qparams = { | |||||
QuantMode.SYMMERTIC: partial(QParams, mode=QuantMode.SYMMERTIC), | |||||
QuantMode.ASYMMERTIC: partial(QParams, mode=QuantMode.ASYMMERTIC), | |||||
} | } | ||||
def get_qparam_dict(mode: QuantMode): | |||||
def create_qparams( | |||||
mode: QuantMode = QuantMode.SYMMERTIC, | |||||
dtype_meta: Union[str, QuantDtypeMeta] = None, | |||||
scale: Tensor = None, | |||||
zero_point: Tensor = None, | |||||
): | |||||
""" | """ | ||||
Return the quantization parameters dictionary according to the mode. | |||||
Return :class:`~.QParams` according to the mode. | |||||
""" | """ | ||||
return qparam_dict.get(mode, None) | |||||
if isinstance(dtype_meta, str): | |||||
dtype_meta = _builtin_quant_dtypes[dtype_meta] | |||||
if mode is None: | |||||
return QParams(mode, dtype_meta, scale, zero_point) | |||||
assert isinstance(mode, QuantMode) | |||||
return _builtin_qparams[mode]( | |||||
dtype_meta=dtype_meta, scale=scale, zero_point=zero_point | |||||
) | |||||
def fake_quant_tensor(inp: Tensor, qmin: int, qmax: int, q_dict: Dict) -> Tensor: | |||||
def fake_quant_tensor(inp: Tensor, qparams: QParams) -> Tensor: | |||||
""" | """ | ||||
Apply fake quantization to the inp tensor. | Apply fake quantization to the inp tensor. | ||||
:param inp: the input tensor which need to be faked. | :param inp: the input tensor which need to be faked. | ||||
:param qmin: the minimum value which the integer limit to. | |||||
:param qmax: the maximum value which the integer limit to. | |||||
:param q_dict: the quantization parameter dict. | |||||
:param qparams: to get mode, qmin, qmax, scale and zero_point from. | |||||
""" | """ | ||||
scale = q_dict["scale"] | |||||
zero_point = Tensor([0.0], dtype=np.float32) | |||||
if q_dict["mode"] == QuantMode.ASYMMERTIC: | |||||
zero_point = q_dict["zero_point"] | |||||
scale = qparams.scale | |||||
if qparams.mode == QuantMode.ASYMMERTIC: | |||||
zero_point = qparams.zero_point | |||||
else: | |||||
zero_point = Tensor([0.0], dtype=np.float32) | |||||
qmin = qparams.dtype_meta.qmin | |||||
qmax = qparams.dtype_meta.qmax | |||||
op = builtin.FakeQuant(qmin=qmin, qmax=qmax) | op = builtin.FakeQuant(qmin=qmin, qmax=qmax) | ||||
return apply(op, inp, scale, zero_point)[0] | return apply(op, inp, scale, zero_point)[0] | ||||
@@ -104,22 +172,34 @@ def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor: | |||||
:param bias: the bias tensor which need to be faked. | :param bias: the bias tensor which need to be faked. | ||||
:param inp: the input tensor which contain the quantization parameters. | :param inp: the input tensor which contain the quantization parameters. | ||||
:param qmax: the weight tensor which contain the quantization parameters. | |||||
:param w_qat: the weight tensor which contain the quantization parameters. | |||||
.. warning:: | .. warning:: | ||||
Only work for symmetric quantization method now. | Only work for symmetric quantization method now. | ||||
""" | """ | ||||
b_qat = bias | b_qat = bias | ||||
if hasattr(inp, "q_dict") and b_qat is not None: | |||||
if inp.q_dict["scale"] is not None and w_qat.q_dict["scale"] is not None: | |||||
# use the same mode with weight. | |||||
b_dict = get_qparam_dict(w_qat.q_dict["mode"]) | |||||
b_dict["scale"] = inp.q_dict["scale"] * w_qat.q_dict["scale"] | |||||
# TODO: add zero_point for ASYMMERTIC mode. | |||||
qmax = _metadata_dict["qint32"].qmax | |||||
qmin = _metadata_dict["qint32"].qmin | |||||
b_qat = fake_quant_tensor(b_qat, qmin, qmax, b_dict) | |||||
b_qat.q_dict.update(b_dict) | |||||
if ( | |||||
getattr(inp, "qparams", None) is not None | |||||
and getattr(w_qat, "qparams", None) is not None | |||||
and bias is not None | |||||
): | |||||
inp_params = inp.qparams | |||||
w_params = w_qat.qparams | |||||
if inp_params.scale is not None and w_params.scale is not None: | |||||
assert inp_params.mode == w_params.mode, "incompatible QuantMode" | |||||
# TODO: support quint8 dtype. | |||||
assert ( | |||||
inp_params.dtype_meta.np_dtype_str == "int8" | |||||
and w_params.dtype_meta.np_dtype_str == "int8" | |||||
), "fake_quant_bias only support int8 like dtype now" | |||||
# use the same mode with weight. | |||||
# TODO: avoid hardcode | |||||
b_dtype = _builtin_quant_dtypes["qint32"] | |||||
b_param = create_qparams( | |||||
w_params.mode, b_dtype, scale=inp_params.scale * w_params.scale | |||||
) | |||||
b_qat = fake_quant_tensor(bias, b_param) | |||||
b_qat.qparams.update(b_param) | |||||
return b_qat | return b_qat |
@@ -22,6 +22,8 @@ from .logger import get_logger | |||||
from .utils.deprecation import deprecated | from .utils.deprecation import deprecated | ||||
from .utils.naming import auto_naming | from .utils.naming import auto_naming | ||||
logger = get_logger(__name__) | |||||
class Tensor(_Tensor, ArrayMethodMixin): | class Tensor(_Tensor, ArrayMethodMixin): | ||||
r""" | r""" | ||||
@@ -30,7 +32,7 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||||
grad = None | grad = None | ||||
dmap_callback = None | dmap_callback = None | ||||
_q_dict = None | |||||
_qparams = None | |||||
def __new__( | def __new__( | ||||
cls, data, dtype=None, device=None, is_const=False, no_cache=False, name=None | cls, data, dtype=None, device=None, is_const=False, no_cache=False, name=None | ||||
@@ -50,7 +52,7 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||||
if isinstance(data, _Tensor): | if isinstance(data, _Tensor): | ||||
if dtype is not None: | if dtype is not None: | ||||
get_logger().warning( | |||||
logger.warning( | |||||
"dtype does not work when creating a new Tensor with another Tensor" | "dtype does not work when creating a new Tensor with another Tensor" | ||||
) | ) | ||||
obj = _Tensor.__new__(cls, data) | obj = _Tensor.__new__(cls, data) | ||||
@@ -101,10 +103,12 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||||
return super().dtype | return super().dtype | ||||
@property | @property | ||||
def q_dict(self): | |||||
if self._q_dict is None: | |||||
self._q_dict = {"mode": None, "scale": None, "zero_point": None} | |||||
return self._q_dict | |||||
def qparams(self): | |||||
from .quantization.utils import create_qparams # pylint: disable=all | |||||
if self._qparams is None: | |||||
self._qparams = create_qparams() | |||||
return self._qparams | |||||
def numpy(self) -> np.ndarray: | def numpy(self) -> np.ndarray: | ||||
r""" | r""" | ||||
@@ -185,14 +189,29 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||||
def __getstate__(self): | def __getstate__(self): | ||||
r""" __getstate__ will be called for pickle serialization or deep copy | r""" __getstate__ will be called for pickle serialization or deep copy | ||||
""" | """ | ||||
state = { | state = { | ||||
"qdict": self.q_dict, | |||||
"numpy": self.numpy(), | |||||
"dtype": self.dtype, | |||||
"device": self.device.logical_name, | |||||
} | } | ||||
if self._qparams is not None: | |||||
state["qparams"] = self._qparams | |||||
return state | return state | ||||
def __setstate__(self, state): | def __setstate__(self, state): | ||||
self._q_dict = state.pop("qdict") | |||||
from .quantization.utils import create_qparams # pylint: disable=all | |||||
if "qdict" in state: | |||||
qparams = state.pop("qdict") | |||||
logger.warning( | |||||
"Tensor's 'qdict' state is depreciated. Use 'qparams' instead" | |||||
) | |||||
elif "qparams" in state: | |||||
qparams = state.pop("qparams") | |||||
else: | |||||
qparams = None | |||||
self._reset(Tensor(state.pop("numpy"), state.pop("dtype"), state.pop("device"))) | |||||
self._qparams = qparams | |||||
tensor = Tensor | tensor = Tensor | ||||
@@ -14,7 +14,7 @@ import pytest | |||||
import megengine.core.tensor.megbrain_graph as G | import megengine.core.tensor.megbrain_graph as G | ||||
from megengine.core.ops import builtin as ops | from megengine.core.ops import builtin as ops | ||||
from megengine.core.tensor.dtype import ( | from megengine.core.tensor.dtype import ( | ||||
_metadata_dict, | |||||
_builtin_quant_dtypes, | |||||
convert_from_qint4, | convert_from_qint4, | ||||
convert_from_qint8, | convert_from_qint8, | ||||
convert_from_quint4, | convert_from_quint4, | ||||
@@ -76,10 +76,10 @@ def _get_compiled_result(inp, dtype, shape, device, calc_func=None): | |||||
def _check_result_attr(oup, dtype, dtype_str, is_unsigned=True): | def _check_result_attr(oup, dtype, dtype_str, is_unsigned=True): | ||||
metadata = _metadata_dict[dtype_str] | |||||
metadata = _builtin_quant_dtypes[dtype_str] | |||||
assert "mgb_dtype" in oup.dtype.metadata | assert "mgb_dtype" in oup.dtype.metadata | ||||
assert is_quantize(oup.dtype) | assert is_quantize(oup.dtype) | ||||
np.testing.assert_equal(oup.dtype.metadata["mgb_dtype"]["name"], metadata.name) | |||||
np.testing.assert_equal(oup.dtype.metadata["mgb_dtype"]["name"], metadata.cname) | |||||
np.testing.assert_allclose(get_scale(oup.dtype), get_scale(dtype)) | np.testing.assert_allclose(get_scale(oup.dtype), get_scale(dtype)) | ||||
if is_unsigned: | if is_unsigned: | ||||
np.testing.assert_equal(get_zero_point(oup.dtype), get_zero_point(dtype)) | np.testing.assert_equal(get_zero_point(oup.dtype), get_zero_point(dtype)) | ||||
@@ -65,9 +65,9 @@ def test_tensor_serialization(): | |||||
with TemporaryFile() as f: | with TemporaryFile() as f: | ||||
a = Tensor(0) | a = Tensor(0) | ||||
a.q_dict["scale"] = Tensor(1.0) | |||||
a.qparams.scale = Tensor(1.0) | |||||
pickle.dump(a, f) | pickle.dump(a, f) | ||||
f.seek(0) | f.seek(0) | ||||
b = pickle.load(f) | b = pickle.load(f) | ||||
assert isinstance(b.q_dict["scale"], Tensor) | |||||
np.testing.assert_equal(b.q_dict["scale"].numpy(), 1.0) | |||||
assert isinstance(b.qparams.scale, Tensor) | |||||
np.testing.assert_equal(b.qparams.scale.numpy(), 1.0) |
@@ -6,6 +6,8 @@ | |||||
# Unless required by applicable law or agreed to in writing, | # Unless required by applicable law or agreed to in writing, | ||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
import copy | |||||
import numpy as np | import numpy as np | ||||
from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 | from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 | ||||
@@ -86,3 +88,23 @@ def test_as_type(): | |||||
b = a.astype(quint8(0.3, 128)) | b = a.astype(quint8(0.3, 128)) | ||||
np.testing.assert_almost_equal(get_scale(b.dtype), 0.3) | np.testing.assert_almost_equal(get_scale(b.dtype), 0.3) | ||||
np.testing.assert_equal(get_zero_point(b.dtype), 128) | np.testing.assert_equal(get_zero_point(b.dtype), 128) | ||||
def test_qparams(): | |||||
x = Tensor(1) | |||||
assert x.qparams.scale is None | |||||
x.qparams.scale = Tensor(1.0) | |||||
assert x.qparams.scale.numpy() == 1.0 | |||||
x2 = copy.copy(x) | |||||
assert x.qparams is x2.qparams and x2.qparams.scale.numpy() == 1.0 | |||||
x3 = copy.deepcopy(x) | |||||
assert x.qparams is not x3.qparams and x3.qparams.scale.numpy() == 1.0 | |||||
def test_name(): | |||||
x = Tensor(0) | |||||
assert x.name == "" | |||||
x.name = "x" | |||||
assert x.name == "x" | |||||
x = Tensor(0, name="x") | |||||
assert x.name == "x" |
@@ -406,28 +406,3 @@ def test_copy_d2h(): | |||||
def test_copy_d2d(): | def test_copy_d2d(): | ||||
copy_test("gpu0", "gpu1") | copy_test("gpu0", "gpu1") | ||||
copy_test("gpu0:0", "gpu0:1") | copy_test("gpu0:0", "gpu0:1") | ||||
def test_name(): | |||||
x = tensor(0) | |||||
assert x.name == "" | |||||
x.name = "x" | |||||
assert x.name == "x" | |||||
x = tensor(0, name="x") | |||||
assert x.name == "x" | |||||
def test_q_dict(): | |||||
x = tensor(1) | |||||
assert x.q_dict["scale"] is None | |||||
x.q_dict["scale"] = tensor(1.0) | |||||
y = tensor(1) | |||||
assert y.q_dict["scale"] is None | |||||
y.q_dict["scale"] = tensor(2.0) | |||||
assert x.q_dict["scale"].numpy() == 1.0 | |||||
assert y.q_dict["scale"].numpy() == 2.0 | |||||
z = x + y | |||||
assert z.q_dict["scale"] is None |
@@ -12,9 +12,15 @@ import pytest | |||||
import megengine as mge | import megengine as mge | ||||
from megengine import tensor | from megengine import tensor | ||||
from megengine.core.autodiff.grad import Function, Grad | from megengine.core.autodiff.grad import Function, Grad | ||||
from megengine.core.tensor.dtype import QuantDtypeMeta | |||||
from megengine.core.tensor.utils import make_shape_tuple | from megengine.core.tensor.utils import make_shape_tuple | ||||
from megengine.quantization.internal_fake_quant import * | from megengine.quantization.internal_fake_quant import * | ||||
from megengine.quantization.utils import QuantMode, fake_quant_tensor, tqt_forward | |||||
from megengine.quantization.utils import ( | |||||
QuantMode, | |||||
create_qparams, | |||||
fake_quant_tensor, | |||||
tqt_forward, | |||||
) | |||||
class TQT_numpy: | class TQT_numpy: | ||||
@@ -111,16 +117,14 @@ def fake_quant_tensor_gt(inp, scale, zero_point, qmin, qmax): | |||||
def test_fakequant(): | def test_fakequant(): | ||||
qmin = -126 | qmin = -126 | ||||
qmax = 129 | qmax = 129 | ||||
test_dtype = QuantDtypeMeta("test_qint8", None, "int8", qmin, qmax) | |||||
def run(zero_point, scale): | def run(zero_point, scale): | ||||
q_dict = {} | |||||
q_dict["mode"] = QuantMode.ASYMMERTIC | |||||
q_dict["scale"] = scale | |||||
q_dict["zero_point"] = zero_point | |||||
qparams = create_qparams(QuantMode.ASYMMERTIC, test_dtype, scale, zero_point) | |||||
inp_data = np.random.uniform(low=-512.0, high=512.0, size=(1, 32, 32, 32)) | inp_data = np.random.uniform(low=-512.0, high=512.0, size=(1, 32, 32, 32)) | ||||
inp = tensor(inp_data, dtype=np.float32) | inp = tensor(inp_data, dtype=np.float32) | ||||
# test forward | # test forward | ||||
oup = fake_quant_tensor(inp, qmin, qmax, q_dict).numpy() | |||||
oup = fake_quant_tensor(inp, qparams).numpy() | |||||
oup_gt = fake_quant_tensor_gt(inp, scale, zero_point, qmin, qmax).numpy() | oup_gt = fake_quant_tensor_gt(inp, scale, zero_point, qmin, qmax).numpy() | ||||
assert np.allclose(oup, oup_gt) | assert np.allclose(oup, oup_gt) | ||||
assert oup.shape == oup_gt.shape | assert oup.shape == oup_gt.shape | ||||
@@ -128,7 +132,7 @@ def test_fakequant(): | |||||
# test backward | # test backward | ||||
x = tensor(inp_data, dtype=np.float32) | x = tensor(inp_data, dtype=np.float32) | ||||
grad = Grad().wrt(x, callback=_save_to(x)) | grad = Grad().wrt(x, callback=_save_to(x)) | ||||
y = fake_quant_tensor(x, qmin, qmax, q_dict) | |||||
y = fake_quant_tensor(x, qparams) | |||||
grad(y, tensor(F.ones_like(x))) | grad(y, tensor(F.ones_like(x))) | ||||
x1 = tensor(inp_data, dtype=np.float32) | x1 = tensor(inp_data, dtype=np.float32) | ||||
@@ -10,7 +10,13 @@ import megengine.module.qat as QAT | |||||
import megengine.module.quantized as Q | import megengine.module.quantized as Q | ||||
from megengine import Parameter, Tensor | from megengine import Parameter, Tensor | ||||
from megengine.core.tensor import dtype | from megengine.core.tensor import dtype | ||||
from megengine.quantization import FakeQuantize, MinMaxObserver, QConfig | |||||
from megengine.quantization import ( | |||||
FakeQuantize, | |||||
MinMaxObserver, | |||||
QConfig, | |||||
QuantMode, | |||||
create_qparams, | |||||
) | |||||
from megengine.quantization.quantize import ( | from megengine.quantization.quantize import ( | ||||
disable_fake_quant, | disable_fake_quant, | ||||
disable_observer, | disable_observer, | ||||
@@ -18,10 +24,10 @@ from megengine.quantization.quantize import ( | |||||
) | ) | ||||
min_max_fakequant_qconfig = QConfig( | min_max_fakequant_qconfig = QConfig( | ||||
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||||
act_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=False), | |||||
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||||
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||||
weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"), | |||||
act_observer=partial(MinMaxObserver, dtype="qint8"), | |||||
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"), | |||||
act_fake_quant=partial(FakeQuantize, dtype="qint8"), | |||||
) | ) | ||||
inp_scale = np.float32(np.random.rand() + 1) | inp_scale = np.float32(np.random.rand() + 1) | ||||
@@ -111,7 +117,7 @@ def test_dequant_stub(): | |||||
x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | ||||
x = fake_quant_act(x, inp_scale) | x = fake_quant_act(x, inp_scale) | ||||
x.q_dict["scale"] = inp_scale | |||||
x.qparams.scale = inp_scale | |||||
normal = normal_net(x) | normal = normal_net(x) | ||||
qat_without_fakequant = qat_from_float(x) | qat_without_fakequant = qat_from_float(x) | ||||
@@ -146,12 +152,12 @@ def test_elemwise(kind): | |||||
x1_scale = np.float32(np.random.rand() + 1) | x1_scale = np.float32(np.random.rand() + 1) | ||||
x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | ||||
x1 = fake_quant_act(x1, x1_scale) | x1 = fake_quant_act(x1, x1_scale) | ||||
x1.q_dict["scale"] = x1_scale | |||||
x1.qparams.scale = x1_scale | |||||
x2_scale = np.float32(np.random.rand() + 1) | x2_scale = np.float32(np.random.rand() + 1) | ||||
x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | ||||
x2 = fake_quant_act(x2, x2_scale) | x2 = fake_quant_act(x2, x2_scale) | ||||
x2.q_dict["scale"] = x2_scale | |||||
x2.qparams.scale = x2_scale | |||||
x1_int8 = quant(x1, x1_scale) | x1_int8 = quant(x1, x1_scale) | ||||
x2_int8 = quant(x2, x2_scale) | x2_int8 = quant(x2, x2_scale) | ||||
@@ -187,7 +193,7 @@ def test_linear(): | |||||
x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | ||||
x = fake_quant_act(x, inp_scale) | x = fake_quant_act(x, inp_scale) | ||||
x.q_dict["scale"] = inp_scale | |||||
x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale)) | |||||
x_int8 = quant(x, inp_scale) | x_int8 = quant(x, inp_scale) | ||||
@@ -230,7 +236,7 @@ def test_conv(module): | |||||
x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32")) | x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32")) | ||||
x = fake_quant_act(x, inp_scale) | x = fake_quant_act(x, inp_scale) | ||||
x.q_dict["scale"] = inp_scale | |||||
x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale)) | |||||
x_int8 = quant(x, inp_scale) | x_int8 = quant(x, inp_scale) | ||||
@@ -6,6 +6,7 @@ import pytest | |||||
import megengine as mge | import megengine as mge | ||||
import megengine.distributed as dist | import megengine.distributed as dist | ||||
from megengine.distributed.helper import get_device_count_by_fork | from megengine.distributed.helper import get_device_count_by_fork | ||||
from megengine.quantization import QuantMode, create_qparams | |||||
from megengine.quantization.observer import ( | from megengine.quantization.observer import ( | ||||
ExponentialMovingAverageObserver, | ExponentialMovingAverageObserver, | ||||
HistogramObserver, | HistogramObserver, | ||||
@@ -56,14 +57,14 @@ def test_histogram_observer(): | |||||
def test_passive_observer(): | def test_passive_observer(): | ||||
q_dict = {"scale": mge.tensor(1.0)} | |||||
qparams = create_qparams(QuantMode.SYMMERTIC, "qint8", mge.tensor(1.0)) | |||||
m = PassiveObserver("qint8") | m = PassiveObserver("qint8") | ||||
m.set_qparams(q_dict) | |||||
m.set_qparams(qparams) | |||||
assert m.orig_scale == 1.0 | assert m.orig_scale == 1.0 | ||||
assert m.scale == 1.0 | |||||
m.scale = 2.0 | |||||
assert m.scale == 2.0 | |||||
assert m.get_qparams() == {"scale": mge.tensor(2.0)} | |||||
assert m.scale.numpy() == 1.0 | |||||
assert m.get_qparams().dtype_meta == qparams.dtype_meta | |||||
assert m.get_qparams().scale == qparams.scale | |||||
assert m.get_qparams() == qparams | |||||
@pytest.mark.require_ngpu(2) | @pytest.mark.require_ngpu(2) | ||||
@@ -6,6 +6,7 @@ import megengine.functional as F | |||||
from megengine.core.tensor import dtype | from megengine.core.tensor import dtype | ||||
from megengine.distributed.helper import get_device_count_by_fork | from megengine.distributed.helper import get_device_count_by_fork | ||||
from megengine.functional.elemwise import _elemwise_multi_type, _elwise | from megengine.functional.elemwise import _elemwise_multi_type, _elwise | ||||
from megengine.quantization import QuantMode, create_qparams | |||||
def quant(x, scale): | def quant(x, scale): | ||||
@@ -26,13 +27,13 @@ def test_elemwise(kind): | |||||
x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | ||||
x1_scale = np.float32(np.random.rand() + 1) | x1_scale = np.float32(np.random.rand() + 1) | ||||
x1 = fake_quant(x1, x1_scale) | x1 = fake_quant(x1, x1_scale) | ||||
x1.q_dict["scale"] = x1_scale | |||||
x1.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", x1_scale)) | |||||
x1_int8 = quant(x1, x1_scale) | x1_int8 = quant(x1, x1_scale) | ||||
x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | ||||
x2_scale = np.float32(np.random.rand() + 1) | x2_scale = np.float32(np.random.rand() + 1) | ||||
x2 = fake_quant(x2, x2_scale) | x2 = fake_quant(x2, x2_scale) | ||||
x2.q_dict["scale"] = x2_scale | |||||
x2.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", x2_scale)) | |||||
x2_int8 = quant(x2, x2_scale) | x2_int8 = quant(x2, x2_scale) | ||||
output_scale = np.float32(np.random.rand() + 1) | output_scale = np.float32(np.random.rand() + 1) | ||||
@@ -1,14 +0,0 @@ | |||||
from functools import partial | |||||
from megengine.quantization import QConfig, tqt_qconfig | |||||
from megengine.quantization.fake_quant import TQT | |||||
def test_equal(): | |||||
qconfig = QConfig( | |||||
weight_observer=None, | |||||
act_observer=None, | |||||
weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True), | |||||
act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False), | |||||
) | |||||
assert qconfig == tqt_qconfig |
@@ -33,7 +33,7 @@ from megengine.quantization.quantize import ( | |||||
) | ) | ||||
class Net(Float.Module): | |||||
class FloatNet(Float.Module): | |||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
self.quant = Float.QuantStub() | self.quant = Float.QuantStub() | ||||
@@ -113,25 +113,25 @@ def test_reset_qconfig(): | |||||
def test_enable_and_disable_observer(): | def test_enable_and_disable_observer(): | ||||
net = init_qat_net() | net = init_qat_net() | ||||
enable_observer(net) | enable_observer(net) | ||||
assert net.quant.act_observer.enabled == True | |||||
assert net.linear.weight_observer.enabled == True | |||||
assert net.linear.act_observer.enabled == True | |||||
assert net.quant.act_observer.enabled is True | |||||
assert net.linear.weight_observer.enabled is True | |||||
assert net.linear.act_observer.enabled is True | |||||
disable_observer(net) | disable_observer(net) | ||||
assert net.quant.act_observer.enabled == False | |||||
assert net.linear.weight_observer.enabled == False | |||||
assert net.linear.act_observer.enabled == False | |||||
assert net.quant.act_observer.enabled is False | |||||
assert net.linear.weight_observer.enabled is False | |||||
assert net.linear.act_observer.enabled is False | |||||
def test_enable_and_disable_fake_quant(): | def test_enable_and_disable_fake_quant(): | ||||
net = init_qat_net() | net = init_qat_net() | ||||
disable_fake_quant(net) | disable_fake_quant(net) | ||||
assert net.quant.act_fake_quant.enabled == False | |||||
assert net.linear.weight_fake_quant.enabled == False | |||||
assert net.linear.act_fake_quant.enabled == False | |||||
assert net.quant.act_fake_quant.enabled is False | |||||
assert net.linear.weight_fake_quant.enabled is False | |||||
assert net.linear.act_fake_quant.enabled is False | |||||
enable_fake_quant(net) | enable_fake_quant(net) | ||||
assert net.quant.act_fake_quant.enabled == True | |||||
assert net.linear.weight_fake_quant.enabled == True | |||||
assert net.linear.act_fake_quant.enabled == True | |||||
assert net.quant.act_fake_quant.enabled is True | |||||
assert net.linear.weight_fake_quant.enabled is True | |||||
assert net.linear.act_fake_quant.enabled is True | |||||
def init_observer(module, data): | def init_observer(module, data): | ||||
@@ -144,7 +144,7 @@ def init_observer(module, data): | |||||
def test_enable_and_disable_all(): | def test_enable_and_disable_all(): | ||||
x = Tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32)) | x = Tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32)) | ||||
net = Net() | |||||
net = FloatNet() | |||||
y1 = net(x).numpy() | y1 = net(x).numpy() | ||||
net = quantize_qat(net, min_max_fakequant_qconfig) | net = quantize_qat(net, min_max_fakequant_qconfig) | ||||
@@ -162,7 +162,7 @@ def test_enable_and_disable_all(): | |||||
def test_quantize_qat(): | def test_quantize_qat(): | ||||
net = Net() | |||||
net = FloatNet() | |||||
qat_net = quantize_qat(net, inplace=False, qconfig=min_max_fakequant_qconfig) | qat_net = quantize_qat(net, inplace=False, qconfig=min_max_fakequant_qconfig) | ||||
assert isinstance(qat_net.quant, QAT.QuantStub) | assert isinstance(qat_net.quant, QAT.QuantStub) | ||||
assert isinstance(qat_net.linear, QAT.Linear) | assert isinstance(qat_net.linear, QAT.Linear) | ||||