GitOrigin-RevId: df3416fe13
tags/v1.3.0
@@ -6,9 +6,6 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import warnings | |||
from typing import Union | |||
from ..._imperative_rt import OpDef, ops | |||
__all__ = ["OpDef"] | |||
@@ -5,22 +5,24 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import collections | |||
from collections import namedtuple | |||
from typing import Union | |||
import numpy as np | |||
# normal dtype related | |||
from .._imperative_rt import bfloat16, intb1, intb2, intb4 | |||
from .._imperative_rt.common import ( | |||
bfloat16, | |||
get_scale, | |||
get_zero_point, | |||
intb1, | |||
intb2, | |||
intb4, | |||
is_dtype_equal, | |||
is_quantize, | |||
) | |||
# normal dtype related | |||
def is_lowbit(dtype): | |||
return (dtype is intb1) or (dtype is intb2) or (dtype is intb4) | |||
@@ -30,34 +32,80 @@ def is_bfloat16(dtype): | |||
# quantization dtype related | |||
_QuantDtypeMetadata = collections.namedtuple( | |||
"QuantDtypeMetadata", ["name", "np_dtype_str", "is_unsigned", "qmin", "qmax",] | |||
) | |||
_metadata_dict = { | |||
"quint8": _QuantDtypeMetadata("Quantized8Asymm", "uint8", True, 0, 255), | |||
"qint8": _QuantDtypeMetadata("QuantizedS8", "int8", False, -128, 127), | |||
"quint4": _QuantDtypeMetadata("Quantized4Asymm", "uint8", True, 0, 15), | |||
"qint4": _QuantDtypeMetadata("QuantizedS4", "int8", False, -8, 7), | |||
"qint32": _QuantDtypeMetadata( | |||
"QuantizedS32", "int32", False, -(2 ** 31), 2 ** 31 - 1, | |||
# use namedtuple to make class immutable, comparable and easy to print | |||
class QuantDtypeMeta( | |||
namedtuple( | |||
"QuantDtypeMeta", | |||
["name", "cname", "np_dtype_str", "qmin", "qmax", "is_unsigned"], | |||
) | |||
): | |||
r""" | |||
Store metadata for quantize dtype. Could be used to create custom quant dtype | |||
for QAT when the network don't need to be converted for inference, but only | |||
to export network metadata for third-party platform inference. | |||
:param name: a unique name string. | |||
:param cname: used in :func:`~.create_quantized_dtype` for model dump and inference. | |||
:param np_dtype_str: used in :func:`~.create_quantized_dtype` to generate ``np.dtype``. | |||
:param qmin: a int number indicating quant dtype's lowerbound. | |||
:param qmax: a int number indicating quant dtype's upperbound. | |||
:param is_unsigned: a helper value that could be inference from np_dtype_str. | |||
""" | |||
def __new__( | |||
cls, | |||
name: str, | |||
cname: str, | |||
np_dtype_str: str, | |||
qmin: int, | |||
qmax: int, | |||
is_unsigned: bool = None, | |||
): | |||
assert isinstance(np_dtype_str, str) | |||
is_unsigned = np_dtype_str[0] == "u" if is_unsigned is None else is_unsigned | |||
return super().__new__(cls, name, cname, np_dtype_str, qmin, qmax, is_unsigned) | |||
def __copy__(self): | |||
return self | |||
def __deepcopy__(self, _): | |||
""" | |||
Ignore deepcopy so that a dtype meta can be treated as singleton, for more | |||
strict check in :meth:`~.FakeQuantize.fake_quant_forward`. | |||
""" | |||
return self | |||
_builtin_quant_dtypes = { | |||
"quint8": QuantDtypeMeta("quint8", "Quantized8Asymm", "uint8", 0, 255), | |||
"qint8": QuantDtypeMeta("qint8", "QuantizedS8", "int8", -128, 127), | |||
"qint8_narrow": QuantDtypeMeta("qint8_narrow", "QuantizedS8", "int8", -127, 127), | |||
"quint4": QuantDtypeMeta("quint4", "Quantized4Asymm", "uint8", 0, 15), | |||
"qint4": QuantDtypeMeta("qint4", "QuantizedS4", "int8", -8, 7), | |||
"qint32": QuantDtypeMeta( | |||
"qint32", "QuantizedS32", "int32", -(2 ** 31), 2 ** 31 - 1, | |||
), | |||
# NOTE: int2 is not supported for model dump yet | |||
"quint2": _QuantDtypeMetadata(None, "uint8", True, 0, 3), | |||
"qint2": _QuantDtypeMetadata(None, "int8", False, -2, 1), | |||
"quint2": QuantDtypeMeta("quint2", None, "uint8", 0, 3), | |||
"qint2": QuantDtypeMeta("qint2", None, "int8", -2, 1), | |||
} | |||
def _check_zero_point(zp: int, dtype_str: str): | |||
qmin = _metadata_dict[dtype_str].qmin | |||
qmax = _metadata_dict[dtype_str].qmax | |||
def _check_zero_point(zp: int, dtype_meta: QuantDtypeMeta): | |||
qmin = dtype_meta.qmin | |||
qmax = dtype_meta.qmax | |||
if zp < qmin or zp > qmax: | |||
raise ValueError( | |||
"zero_point should be within [{}, {}] for {}".format(qmin, qmax, dtype_str) | |||
"zero_point should be within [{}, {}] for {}".format( | |||
qmin, qmax, dtype_meta.name | |||
) | |||
) | |||
def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]): | |||
def create_quantized_dtype( | |||
dtype_meta: QuantDtypeMeta, scale: float, zp: Union[int, None] | |||
): | |||
r""" | |||
Get quantized dtype with metadata attribute according to _metadata_dict. | |||
@@ -65,32 +113,34 @@ def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]): | |||
not have ``zero_point``, to be consitent with tensor generated by calling | |||
compiled function from `CompGraph.compile(inputs, outspec)`. | |||
:param dtype: a string indicating which dtype to return | |||
:param dtype_meta: a QuantDtypeMeta indicating which dtype to return. the | |||
``cname`` attribute cannot be ``None``. | |||
:param scale: a number for scale to store in dtype's metadata | |||
:param zp: a number for zero_point to store in dtype's metadata | |||
""" | |||
metadata = _metadata_dict[dtype_str] | |||
np_dtype_str = metadata.np_dtype_str | |||
is_unsigned = metadata.is_unsigned | |||
if is_unsigned: | |||
if dtype_meta.cname is None: | |||
raise ValueError("dtype {} without cname attr is not supported.") | |||
if dtype_meta.is_unsigned: | |||
if zp is None or int(zp) != zp: | |||
raise ValueError("zero_point should be an integer") | |||
zp = int(zp) | |||
_check_zero_point(zp, dtype_str) | |||
_check_zero_point(zp, dtype_meta) | |||
return np.dtype( | |||
np_dtype_str, | |||
dtype_meta.np_dtype_str, | |||
metadata={ | |||
"mgb_dtype": { | |||
"name": metadata.name, | |||
"name": dtype_meta.cname, | |||
"scale": float(scale), | |||
"zero_point": zp, | |||
} | |||
}, | |||
) | |||
else: | |||
# Don't trick to combine with is_unsigned. Metadata should not contain | |||
# invalid field to keep consistent with c dtype. | |||
return np.dtype( | |||
np_dtype_str, | |||
metadata={"mgb_dtype": {"name": metadata.name, "scale": float(scale)}}, | |||
dtype_meta.np_dtype_str, | |||
metadata={"mgb_dtype": {"name": dtype_meta.cname, "scale": float(scale)}}, | |||
) | |||
@@ -100,7 +150,7 @@ def quint8(scale, zero_point): | |||
``zero_point`` (uint8). The real value represented by a quint8 data type is | |||
float_val = scale * (uint8_val - zero_point) | |||
""" | |||
return get_quantized_dtype("quint8", scale, zero_point) | |||
return create_quantized_dtype(_builtin_quant_dtypes["quint8"], scale, zero_point) | |||
def qint8(scale): | |||
@@ -108,7 +158,7 @@ def qint8(scale): | |||
Construct a quantized int8 data type with ``scale`` (float). The real value | |||
represented by a qint8 data type is float_val = scale * int8_val | |||
""" | |||
return get_quantized_dtype("qint8", scale, None) | |||
return create_quantized_dtype(_builtin_quant_dtypes["qint8"], scale, None) | |||
def qint32(scale): | |||
@@ -116,7 +166,7 @@ def qint32(scale): | |||
Construct a quantized int32 data type with ``scale`` (float). The real value | |||
represented by a qint32 data type is float_val = scale * int32_val | |||
""" | |||
return get_quantized_dtype("qint32", scale, None) | |||
return create_quantized_dtype(_builtin_quant_dtypes["qint32"], scale, None) | |||
def quint4(scale, zero_point): | |||
@@ -125,7 +175,7 @@ def quint4(scale, zero_point): | |||
``zero_point`` (uint8). The real value represented by a quint4 data type is | |||
float_val = scale * (uint4_val - zero_point) | |||
""" | |||
return get_quantized_dtype("quint4", scale, zero_point) | |||
return create_quantized_dtype(_builtin_quant_dtypes["quint4"], scale, zero_point) | |||
def qint4(scale): | |||
@@ -133,42 +183,48 @@ def qint4(scale): | |||
Construct a quantized int4 data type with ``scale`` (float). The real value | |||
represented by a qint4 data type is float_val = scale * int4_val | |||
""" | |||
return get_quantized_dtype("qint4", scale, None) | |||
return create_quantized_dtype(_builtin_quant_dtypes["qint4"], scale, None) | |||
def _convert_to_quantized_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str): | |||
metadata = _metadata_dict[dtype_str] | |||
arr_metadata = dtype.metadata["mgb_dtype"] | |||
def _convert_to_quantized_dtype( | |||
arr: np.ndarray, dtype: np.dtype, dtype_meta: QuantDtypeMeta | |||
): | |||
if not isinstance(arr, np.ndarray): | |||
raise ValueError("arr parameter should be instance of np.ndarray") | |||
if not is_quantize(dtype) or arr_metadata["name"] != metadata.name: | |||
raise ValueError("dtype parameter should be a {} dtype".format(dtype_str)) | |||
is_unsigned = metadata.is_unsigned | |||
if is_unsigned: | |||
if ( | |||
not is_quantize(dtype) | |||
or dtype.metadata["mgb_dtype"]["name"] != dtype_meta.cname | |||
): | |||
raise ValueError("dtype parameter should be a {} dtype".format(dtype_meta)) | |||
arr_metadata = dtype.metadata["mgb_dtype"] | |||
if dtype_meta.is_unsigned: | |||
scale, zp = ( | |||
arr_metadata["scale"], | |||
arr_metadata["zero_point"], | |||
) | |||
return ( | |||
(np.round(arr / scale) + zp) | |||
.clip(metadata.qmin, metadata.qmax) | |||
.clip(dtype_meta.qmin, dtype_meta.qmax) | |||
.astype(dtype) | |||
) | |||
else: | |||
# don't trick to combine with is_unsigned, seeing ``get_quantized_dtype`` | |||
scale = arr_metadata["scale"] | |||
return np.round(arr / scale).clip(metadata.qmin, metadata.qmax).astype(dtype) | |||
return ( | |||
np.round(arr / scale).clip(dtype_meta.qmin, dtype_meta.qmax).astype(dtype) | |||
) | |||
def _convert_from_quantized_dtype(arr: np.ndarray, dtype_str: str): | |||
metadata = _metadata_dict[dtype_str] | |||
arr_metadata = arr.dtype.metadata["mgb_dtype"] | |||
def _convert_from_quantized_dtype(arr: np.ndarray, dtype_meta: QuantDtypeMeta): | |||
if not isinstance(arr, np.ndarray): | |||
raise ValueError("arr parameter should be instance of np.ndarray") | |||
if not is_quantize(arr.dtype) or arr_metadata["name"] != metadata.name: | |||
raise ValueError("arr's dtype should be a {} dtype".format(dtype_str)) | |||
is_unsigned = metadata.is_unsigned | |||
if is_unsigned: | |||
if ( | |||
not is_quantize(arr.dtype) | |||
or arr.dtype.metadata["mgb_dtype"]["name"] != dtype_meta.cname | |||
): | |||
raise ValueError("arr's dtype should be a {} dtype".format(dtype_meta)) | |||
arr_metadata = arr.dtype.metadata["mgb_dtype"] | |||
if dtype_meta.is_unsigned: | |||
scale, zp = ( | |||
arr_metadata["scale"], | |||
arr_metadata["zero_point"], | |||
@@ -187,7 +243,7 @@ def convert_to_quint8(arr: np.ndarray, q: np.dtype): | |||
:param arr: Input ndarray. | |||
:param q: Target data type, should be a quint8. | |||
""" | |||
return _convert_to_quantized_dtype(arr, q, "quint8") | |||
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["quint8"]) | |||
def convert_from_quint8(arr: np.ndarray): | |||
@@ -196,7 +252,7 @@ def convert_from_quint8(arr: np.ndarray): | |||
:param arr: Input ndarray. | |||
""" | |||
return _convert_from_quantized_dtype(arr, "quint8") | |||
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["quint8"]) | |||
def convert_to_qint8(arr: np.ndarray, q: np.dtype): | |||
@@ -206,7 +262,7 @@ def convert_to_qint8(arr: np.ndarray, q: np.dtype): | |||
:param arr: Input ndarray. | |||
:param q: Target data type, should be a qint8. | |||
""" | |||
return _convert_to_quantized_dtype(arr, q, "qint8") | |||
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint8"]) | |||
def convert_from_qint8(arr: np.ndarray): | |||
@@ -215,7 +271,7 @@ def convert_from_qint8(arr: np.ndarray): | |||
:param arr: Input ndarray. | |||
""" | |||
return _convert_from_quantized_dtype(arr, "qint8") | |||
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint8"]) | |||
def convert_to_qint32(arr: np.ndarray, q: np.dtype): | |||
@@ -225,7 +281,7 @@ def convert_to_qint32(arr: np.ndarray, q: np.dtype): | |||
:param arr: Input ndarray. | |||
:param q: Target data type, should be a qint8. | |||
""" | |||
return _convert_to_quantized_dtype(arr, q, "qint32") | |||
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint32"]) | |||
def convert_from_qint32(arr): | |||
@@ -234,7 +290,7 @@ def convert_from_qint32(arr): | |||
:param arr: Input ndarray. | |||
""" | |||
return _convert_from_quantized_dtype(arr, "qint32") | |||
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint32"]) | |||
def convert_to_quint4(arr: np.ndarray, q: np.dtype): | |||
@@ -244,7 +300,7 @@ def convert_to_quint4(arr: np.ndarray, q: np.dtype): | |||
:param arr: Input ndarray. | |||
:param q: Target data type, should be a quint4. | |||
""" | |||
return _convert_to_quantized_dtype(arr, q, "quint4") | |||
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["quint4"]) | |||
def convert_from_quint4(arr: np.ndarray): | |||
@@ -253,7 +309,7 @@ def convert_from_quint4(arr: np.ndarray): | |||
:param arr: Input ndarray. | |||
""" | |||
return _convert_from_quantized_dtype(arr, "quint4") | |||
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["quint4"]) | |||
def convert_to_qint4(arr: np.ndarray, q: np.dtype): | |||
@@ -263,7 +319,7 @@ def convert_to_qint4(arr: np.ndarray, q: np.dtype): | |||
:param arr: Input ndarray. | |||
:param q: Target data type, should be a qint4. | |||
""" | |||
return _convert_to_quantized_dtype(arr, q, "qint4") | |||
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint4"]) | |||
def convert_from_qint4(arr: np.ndarray): | |||
@@ -272,4 +328,4 @@ def convert_from_qint4(arr: np.ndarray): | |||
:param arr: Input ndarray. | |||
""" | |||
return _convert_from_quantized_dtype(arr, "qint4") | |||
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint4"]) |
@@ -203,7 +203,7 @@ def conv_transpose2d( | |||
assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT" | |||
if groups != 1: | |||
raise NotImplementedError("TODO") | |||
raise NotImplementedError("group transposed conv2d is not supported yet.") | |||
stride_h, stride_w = expand_hw(stride) | |||
pad_h, pad_w = expand_hw(padding) | |||
@@ -13,7 +13,6 @@ import itertools | |||
import json | |||
import os | |||
import typing | |||
import warnings | |||
import weakref | |||
import numpy as np | |||
@@ -5,7 +5,6 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import warnings | |||
from abc import ABCMeta, abstractmethod | |||
from collections import OrderedDict | |||
from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union | |||
@@ -204,10 +203,9 @@ class Module(metaclass=ABCMeta): | |||
if "requires_grad" in kwargs: | |||
del kwargs["requires_grad"] | |||
warnings.warn( | |||
logger.warning( | |||
"Tensor currently has no requires_grad attribute " | |||
"so requires_grad argument is ignored here", | |||
DeprecationWarning, | |||
"so requires_grad argument is ignored here" | |||
) | |||
def predicate(obj) -> bool: | |||
@@ -232,10 +230,9 @@ class Module(metaclass=ABCMeta): | |||
if "requires_grad" in kwargs: | |||
del kwargs["requires_grad"] | |||
warnings.warn( | |||
logger.warning( | |||
"Tensor currently has no requires_grad attribute " | |||
"so requires_grad argument is ignored here", | |||
DeprecationWarning, | |||
"so requires_grad argument is ignored here" | |||
) | |||
def predicate(obj) -> bool: | |||
@@ -7,7 +7,10 @@ | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from abc import abstractmethod | |||
from ...quantization import FakeQuantize, Observer, QConfig | |||
# avoid circular reference | |||
from ...quantization.fake_quant import FakeQuantize | |||
from ...quantization.observer import Observer | |||
from ...quantization.qconfig import QConfig | |||
from ...tensor import Tensor | |||
from ..module import Module | |||
@@ -73,19 +76,19 @@ class QATModule(Module): | |||
# do observer | |||
if observer is None: | |||
oup = target | |||
q_dict = None | |||
qparams = None | |||
else: | |||
oup = observer(target) | |||
q_dict = observer.get_qparams() | |||
qparams = observer.get_qparams() | |||
# do fake quant | |||
if fake_quant is not None: | |||
oup = fake_quant(oup, q_dict) | |||
oup = fake_quant(oup, qparams) | |||
# use qparams of fake_quant if have. | |||
if hasattr(fake_quant, "get_qparams"): | |||
q_dict = fake_quant.get_qparams() | |||
qparams = fake_quant.get_qparams() | |||
# set to tensor qparams. | |||
if q_dict is not None: | |||
oup.q_dict.update(q_dict) | |||
if qparams is not None: | |||
oup.qparams.update(qparams) | |||
return oup | |||
def apply_quant_weight(self, target: Tensor): | |||
@@ -118,7 +121,7 @@ class QATModule(Module): | |||
Get weight's quantization dtype as the method from ``qconfig``. | |||
""" | |||
return self._get_method_result( | |||
"get_dtype", self.weight_fake_quant, self.weight_observer | |||
"get_quantized_dtype", self.weight_fake_quant, self.weight_observer | |||
) | |||
def get_activation_dtype(self): | |||
@@ -126,7 +129,7 @@ class QATModule(Module): | |||
Get activation's quantization dtype as the method from ``qconfig``. | |||
""" | |||
return self._get_method_result( | |||
"get_dtype", self.act_fake_quant, self.act_observer | |||
"get_quantized_dtype", self.act_fake_quant, self.act_observer | |||
) | |||
def get_weight_qparams(self): | |||
@@ -7,8 +7,7 @@ | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from .fake_quant import FakeQuantize | |||
from .internal_fake_quant import * | |||
from .observer import HistogramObserver, Observer | |||
from .observer import Observer | |||
from .qconfig import ( | |||
QConfig, | |||
calibration_qconfig, | |||
@@ -20,4 +19,15 @@ from .qconfig import ( | |||
sync_ema_fakequant_qconfig, | |||
tqt_qconfig, | |||
) | |||
from .utils import QuantMode | |||
from .quantize import ( | |||
apply_easy_quant, | |||
disable_fake_quant, | |||
disable_observer, | |||
enable_fake_quant, | |||
enable_observer, | |||
propagate_qconfig, | |||
quantize, | |||
quantize_qat, | |||
reset_qconfig, | |||
) | |||
from .utils import QParams, QuantMode, create_qparams |
@@ -6,40 +6,48 @@ | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import math | |||
from typing import Union | |||
from .. import functional as F | |||
from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype | |||
from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes | |||
from ..logger import get_logger | |||
from ..module import Module | |||
from ..tensor import Parameter, Tensor | |||
from .utils import QuantMode, fake_quant_tensor, get_qparam_dict, tqt_forward | |||
from ..tensor import Parameter | |||
from .utils import ( | |||
QParams, | |||
QParamsModuleMixin, | |||
QuantMode, | |||
create_qparams, | |||
fake_quant_tensor, | |||
tqt_forward, | |||
) | |||
logger = get_logger(__name__) | |||
class _FakeQuantize(Module): | |||
r""" | |||
A Basic Fake Quant module. | |||
:param dtype: a string indicating the target quantization type of input. | |||
:param narrow_range: whether the absolute value of ``qmin`` is the same as ``qmax``, | |||
instead of 1 greater. Usually True for weight and False for activation. | |||
:param enable: whether do ``normal_forward`` or ``fake_quant_forward``. | |||
""" | |||
class _FakeQuantize(Module): | |||
def __init__( | |||
self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs | |||
self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs | |||
): | |||
super().__init__() | |||
if not dtype in _metadata_dict.keys(): | |||
raise ValueError( | |||
"unknown dtype: {}, only support {}".format( | |||
dtype, _metadata_dict.keys() | |||
if isinstance(dtype, str): | |||
if not dtype in _builtin_quant_dtypes: | |||
raise ValueError( | |||
"unknown dtype: {}, only support {}".format( | |||
dtype, _builtin_quant_dtypes.keys() | |||
) | |||
) | |||
dtype = _builtin_quant_dtypes[dtype] | |||
if "narrow_range" in kwargs: | |||
del kwargs["narrow_range"] | |||
logger.warning( | |||
"FakeQuantize currently has no narrow_range param " | |||
"so it is ignored here", | |||
exc_info=DeprecationWarning, | |||
) | |||
self.dtype = dtype | |||
self.narrow_range = narrow_range | |||
self.qmin = ( | |||
-_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin | |||
) | |||
self.qmax = _metadata_dict[dtype].qmax | |||
self.qmin = dtype.qmin | |||
self.qmax = dtype.qmax | |||
self.enabled = enable | |||
def enable(self): | |||
@@ -48,61 +56,64 @@ class _FakeQuantize(Module): | |||
def disable(self): | |||
self.enabled = False | |||
def fake_quant_forward(self, inp, q_dict=None): | |||
return inp | |||
def fake_quant_forward(self, inp, qparams: QParams = None): | |||
raise NotImplementedError | |||
def normal_foward(self, inp, q_dict=None): | |||
def normal_foward(self, inp, qparams: QParams = None): | |||
return inp | |||
def forward(self, inp, q_dict=None): | |||
def forward(self, inp, qparams: QParams = None): | |||
if self.enabled: | |||
return self.fake_quant_forward(inp, q_dict=q_dict) | |||
return self.fake_quant_forward(inp, qparams=qparams) | |||
else: | |||
return self.normal_foward(inp, q_dict=q_dict) | |||
return self.normal_foward(inp, qparams=qparams) | |||
class TQT(_FakeQuantize): | |||
class TQT(_FakeQuantize, QParamsModuleMixin): | |||
r""" | |||
TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds | |||
for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks. | |||
:param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target | |||
quantization dtype of input. | |||
:param enable: whether do ``normal_forward`` or ``fake_quant_forward``. | |||
""" | |||
def __init__( | |||
self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs | |||
self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs | |||
): | |||
super().__init__(dtype, narrow_range, enable, **kwargs) | |||
super().__init__(dtype, enable, **kwargs) | |||
self.scale = Parameter(0.0, dtype="float32") | |||
def fake_quant_forward(self, inp, q_dict=None): | |||
def fake_quant_forward(self, inp, qparams: QParams = None): | |||
# when enable, TQT will do fakequant forward, finetune the scale | |||
return tqt_forward(self.qmin, self.qmax, inp, self.scale) | |||
def get_qparams(self): | |||
q_dict = get_qparam_dict(QuantMode.SYMMERTIC) | |||
q_dict["scale"] = 2 ** self.scale.detach() | |||
return q_dict | |||
def set_qparams(self, q_dict): | |||
def set_qparams(self, qparams: QParams): | |||
assert ( | |||
q_dict["mode"] == QuantMode.SYMMERTIC | |||
qparams.mode == QuantMode.SYMMERTIC | |||
), "only symmetric quantization is supported by TQT" | |||
if "scale" not in q_dict or q_dict["scale"] is None: | |||
if qparams.scale is None: | |||
raise AssertionError("Can not get an initialized scale") | |||
self.scale._reset(F.log(q_dict["scale"]) / math.log(2)) | |||
self.scale[...] = F.log(qparams.scale) / math.log(2) | |||
def get_dtype(self): | |||
q_dict = self.get_qparams() | |||
scale = None if "scale" not in q_dict else q_dict["scale"].numpy() | |||
zero_point = ( | |||
None if "zero_point" not in q_dict else q_dict["zero_point"].numpy() | |||
) | |||
return get_quantized_dtype(self.dtype, scale, zero_point) | |||
def get_qparams(self): | |||
return create_qparams(QuantMode.SYMMERTIC, self.dtype, scale=2 ** self.scale) | |||
class FakeQuantize(_FakeQuantize): | |||
r""" | |||
A module to do quant and dequant according to observer's scale and zero_point. | |||
:param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target | |||
quantization dtype of input. | |||
:param enable: whether do ``normal_forward`` or ``fake_quant_forward``. | |||
""" | |||
def fake_quant_forward(self, inp, q_dict=None): | |||
return fake_quant_tensor(inp, self.qmin, self.qmax, q_dict) | |||
def fake_quant_forward(self, inp, qparams: QParams = None): | |||
assert ( | |||
qparams.dtype_meta is self.dtype | |||
), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format( | |||
qparams.dtype_meta, self.dtype | |||
) | |||
return fake_quant_tensor(inp, qparams) |
@@ -16,4 +16,6 @@ from ..autodiff import Function | |||
from .fake_quant import _FakeQuantize | |||
from .observer import MinMaxObserver | |||
from .qconfig import QConfig | |||
from .utils import QParams | |||
@@ -8,51 +8,51 @@ | |||
import math | |||
from abc import abstractmethod | |||
from copy import deepcopy | |||
from typing import Union | |||
import numpy as np | |||
from .. import functional as F | |||
from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype | |||
from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes | |||
from ..distributed import WORLD, get_rank, is_distributed | |||
from ..functional.distributed import all_reduce_max, all_reduce_min | |||
from ..logger import get_logger | |||
from ..module import Module | |||
from ..tensor import Tensor | |||
from .utils import QuantMode, get_qparam_dict | |||
from .utils import QParams, QParamsModuleMixin, QuantMode, create_qparams | |||
logger = get_logger(__name__) | |||
class Observer(Module): | |||
class Observer(Module, QParamsModuleMixin): | |||
r""" | |||
A base class for Observer Module. | |||
:param dtype: a string indicating to collect scale and zero_point of which dtype. | |||
:param narrow_range: whether the absolute value of ``qmin`` is the same as ``qmax``, | |||
instead of 1 greater. Usually True for weight and False for activation. | |||
""" | |||
def __init__(self, dtype: str, narrow_range: bool = False, **kwargs): | |||
def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs): | |||
super().__init__() | |||
if dtype not in _metadata_dict.keys(): | |||
raise ValueError( | |||
"unknown dtype: {}, only support {}".format( | |||
dtype, _metadata_dict.keys() | |||
if isinstance(dtype, str): | |||
if not dtype in _builtin_quant_dtypes: | |||
raise ValueError( | |||
"unknown dtype: {}, only support {}".format( | |||
dtype, _builtin_quant_dtypes.keys() | |||
) | |||
) | |||
dtype = _builtin_quant_dtypes[dtype] | |||
if "narrow_range" in kwargs: | |||
del kwargs["narrow_range"] | |||
logger.warning( | |||
"FakeQuantize currently has no narrow_range param " | |||
"so it is ignored here", | |||
exc_info=DeprecationWarning, | |||
) | |||
self.dtype = dtype | |||
self.narrow_range = narrow_range | |||
self.qmin = ( | |||
-_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin | |||
) | |||
self.qmax = _metadata_dict[dtype].qmax | |||
self.qmin = dtype.qmin | |||
self.qmax = dtype.qmax | |||
self.enabled = True | |||
def get_dtype(self): | |||
q_dict = self.get_qparams() | |||
numpy_scale = None if "scale" not in q_dict else q_dict["scale"].numpy() | |||
numpy_zero_point = ( | |||
None if "zero_point" not in q_dict else q_dict["zero_point"].numpy() | |||
) | |||
return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point) | |||
def enable(self): | |||
self.enabled = True | |||
@@ -70,21 +70,16 @@ class Observer(Module): | |||
def forward(self, x): | |||
pass | |||
@abstractmethod | |||
def get_qparams(self, **kwargs): | |||
pass | |||
class MinMaxObserver(Observer): | |||
def __init__( | |||
self, | |||
mode=QuantMode.SYMMERTIC, | |||
eps=0.00001, | |||
dtype="qint8", | |||
narrow_range: bool = False, | |||
mode: QuantMode = QuantMode.SYMMERTIC, | |||
eps: float = 0.00001, | |||
dtype: Union[str, QuantDtypeMeta] = "qint8", | |||
**kwargs | |||
): | |||
super().__init__(dtype, narrow_range, **kwargs) | |||
super().__init__(dtype, **kwargs) | |||
self.mode = mode | |||
self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32) | |||
self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32) | |||
@@ -93,26 +88,22 @@ class MinMaxObserver(Observer): | |||
def _calculate_qparams(self, inp_min_val, inp_max_val): | |||
min_val = F.minimum(0.0, inp_min_val) | |||
max_val = F.maximum(0.0, inp_max_val) | |||
q_dict = get_qparam_dict(self.mode) | |||
q_dict["min_val"] = inp_min_val | |||
q_dict["max_val"] = inp_max_val | |||
q_dict["enable_observer"] = self.enable | |||
if self.mode == QuantMode.SYMMERTIC: | |||
symmetric_max_vals = F.maximum(-min_val, max_val) | |||
# use maximun to avoid scale too small at the begin | |||
q_dict["scale"] = F.maximum( | |||
scale = F.maximum( | |||
symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit | |||
) | |||
# zero_point = self.zero_point | |||
zero_point = None | |||
else: | |||
# use maximun to avoid scale too small at the begin | |||
q_dict["scale"] = F.maximum( | |||
scale = F.maximum( | |||
(max_val - min_val) / (self.qmax - self.qmin), self.scale_limit | |||
) | |||
# caculate zero_point | |||
q_dict["zero_point"] = self.qmin - F.round(min_val / q_dict["scale"]) | |||
zero_point = self.qmin - F.round((min_val / scale)) | |||
return q_dict | |||
return create_qparams(self.mode, self.dtype, scale=scale, zero_point=zero_point) | |||
def get_qparams(self): | |||
return self._calculate_qparams(self.min_val, self.max_val) | |||
@@ -122,8 +113,8 @@ class MinMaxObserver(Observer): | |||
# stop gradient | |||
x = x_orig.detach() | |||
# find max and min | |||
self.min_val._reset(F.minimum(self.min_val, x.min())) | |||
self.max_val._reset(F.maximum(self.max_val, x.max())) | |||
self.min_val[...] = F.minimum(self.min_val, x.min()) | |||
self.max_val[...] = F.maximum(self.max_val, x.max()) | |||
return x_orig | |||
@@ -137,42 +128,43 @@ class SyncMinMaxObserver(MinMaxObserver): | |||
else: | |||
min_x = x.min() | |||
max_x = x.max() | |||
self.min_val._reset(F.minimum(self.min_val, min_x)) | |||
self.max_val._reset(F.maximum(self.max_val, max_x)) | |||
self.min_val[...] = F.minimum(self.min_val, min_x) | |||
self.max_val[...] = F.maximum(self.max_val, max_x) | |||
return x_orig | |||
class ExponentialMovingAverageObserver(MinMaxObserver): | |||
def __init__( | |||
self, | |||
momentum=0.9, | |||
mode=QuantMode.SYMMERTIC, | |||
eps=0.00001, | |||
dtype="qint8", | |||
narrow_range: bool = False, | |||
momentum: float = 0.9, | |||
mode: QuantMode = QuantMode.SYMMERTIC, | |||
eps: float = 0.00001, | |||
dtype: Union[str, QuantDtypeMeta] = "qint8", | |||
**kwargs | |||
): | |||
super().__init__(mode, eps, dtype, narrow_range, **kwargs) | |||
super().__init__(mode, eps, dtype, **kwargs) | |||
self.momentum = Tensor(momentum, dtype="float32") | |||
# used to avoid if-clauses in the first forward which is not supported | |||
# in trace mode. | |||
self.runtime_momentum = Tensor(0.0) | |||
def set_momentum(self, momentum): | |||
self.momentum = Tenosr(momentum, dtype="float32") | |||
self.momentum = Tensor(momentum, dtype="float32") | |||
def forward(self, x_orig): | |||
if self.enabled: | |||
# stop gradient | |||
x = x_orig.detach() | |||
# Exponential Moving Average | |||
self.min_val._reset( | |||
self.min_val[...] = ( | |||
self.min_val * self.runtime_momentum | |||
+ (1 - self.runtime_momentum) * x.min() | |||
) | |||
self.max_val._reset( | |||
self.max_val[...] = ( | |||
self.max_val * self.runtime_momentum | |||
+ (1 - self.runtime_momentum) * x.max() | |||
) | |||
self.runtime_momentum = self.momentum | |||
self.runtime_momentum[...] = self.momentum | |||
return x_orig | |||
@@ -187,33 +179,34 @@ class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver): | |||
else: | |||
min_x = x.min() | |||
max_x = x.max() | |||
self.min_val._reset( | |||
self.min_val[...] = ( | |||
self.min_val * self.runtime_momentum | |||
+ (1 - self.runtime_momentum) * min_x | |||
) | |||
self.max_val._reset( | |||
self.max_val[...] = ( | |||
self.max_val * self.runtime_momentum | |||
+ (1 - self.runtime_momentum) * max_x | |||
) | |||
self.runtime_momentum = self.momentum | |||
self.runtime_momentum[...] = self.momentum | |||
return x_orig | |||
class HistogramObserver(MinMaxObserver): | |||
def __init__( | |||
self, | |||
bins=2048, | |||
upsample_rate=128, | |||
mode=QuantMode.SYMMERTIC, | |||
eps=0.00001, | |||
dtype="qint8", | |||
narrow_range: bool = False, | |||
bins: int = 2048, | |||
upsample_rate: int = 128, | |||
mode: QuantMode = QuantMode.SYMMERTIC, | |||
eps: float = 0.00001, | |||
dtype: Union[str, QuantDtypeMeta] = "qint8", | |||
**kwargs | |||
): | |||
super().__init__(mode, eps, dtype, narrow_range, **kwargs) | |||
super().__init__(mode, eps, dtype, **kwargs) | |||
self.bins = bins | |||
self.upsample_rate = upsample_rate | |||
self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 | |||
self.dst_nbins = ( | |||
_builtin_quant_dtypes[dtype].qmax - _builtin_quant_dtypes[dtype].qmin + 1 | |||
) | |||
self.histogram = Tensor([-1] + [0.0] * (bins - 1), dtype="float32") | |||
def _non_linear_param_search(self): | |||
@@ -450,34 +443,45 @@ class HistogramObserver(MinMaxObserver): | |||
class PassiveObserver(Observer): | |||
r""" | |||
This class can be set :attr:`scale` derectly. | |||
An Observer that supports setting :attr:`scale` directly. | |||
""" | |||
def __init__(self, dtype: str, narrow_range: bool = False, **kwargs): | |||
super().__init__(dtype, narrow_range, **kwargs) | |||
self.q_dict = None | |||
def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs): | |||
super().__init__(dtype, **kwargs) | |||
self.qparams = None | |||
self.orig_scale = None | |||
@property | |||
def scale(self): | |||
return self.q_dict["scale"] | |||
return self.qparams.scale | |||
@scale.setter | |||
def scale(self, value): | |||
assert value > 0 | |||
self.q_dict["scale"][...] = Tensor(value) | |||
def scale(self, value: np.ndarray): | |||
assert np.all(value > 0) | |||
self.qparams.scale[...] = Tensor(value) | |||
def get_qparams(self): | |||
return self.q_dict | |||
return self.qparams | |||
def set_qparams(self, q_dict): | |||
self.q_dict = deepcopy(q_dict) | |||
if "scale" not in q_dict or q_dict["scale"] is None: | |||
def set_qparams(self, qparams: QParams): | |||
""" | |||
:param qparams: used to set initial scale. | |||
""" | |||
self.qparams = deepcopy(qparams) | |||
if qparams.scale is None: | |||
raise AssertionError("Can not get an initialized scale") | |||
self.orig_scale = q_dict["scale"].numpy() | |||
if qparams.dtype_meta is None: | |||
qparams.dtype_meta = self.dtype | |||
else: | |||
assert ( | |||
qparams.dtype_meta is self.dtype | |||
), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format( | |||
qparams.dtype_meta, self.dtype | |||
) | |||
self.orig_scale = qparams.scale.numpy() | |||
def forward(self, x): | |||
r""" | |||
Just return input because :attr:`q_dict` is set by :func:`~.apply_easy_quant`. | |||
Just return input because :attr:`qparams` is set by :func:`~.apply_easy_quant`. | |||
""" | |||
return x |
@@ -5,6 +5,7 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from collections import namedtuple | |||
from functools import partial | |||
from ..module import Module | |||
@@ -19,7 +20,13 @@ from .observer import ( | |||
) | |||
class QConfig: | |||
# use namedtuple to make class immutable, comparable and easy to print | |||
class QConfig( | |||
namedtuple( | |||
"QConfig", | |||
["weight_observer", "act_observer", "weight_fake_quant", "act_fake_quant"], | |||
) | |||
): | |||
r""" | |||
A config class indicating how to do quantize toward :class:`~.QATModule`'s | |||
``activation`` and ``weight``. See :meth:`~.QATModule.set_qconfig` for detail usage. | |||
@@ -37,90 +44,66 @@ class QConfig: | |||
# Default EMA QConfig for QAT. | |||
ema_fakequant_qconfig = QConfig( | |||
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||
act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False), | |||
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||
weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"), | |||
act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8"), | |||
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"), | |||
act_fake_quant=partial(FakeQuantize, dtype="qint8"), | |||
) | |||
Each parameter is a ``class`` rather than an instance. And we recommand using ``functools.partial`` | |||
to add initialization parameters of the ``class``, so that don't need to provide parameters in | |||
:meth:`~.QATModule.set_qconfig`. | |||
Usually we set ``narrow_range`` of weight related paramters to ``True`` and of activation related | |||
parameters to ``False``. For the result of multiplication and addition as ``a * b + c * d``, if | |||
four variables are all -128 of dtype ``qint8``, then the result will be ``2^15`` and cause overflow. | |||
Weights are commonly calculated in this way, so needed to narrow the range. | |||
Usually we choose narrow version dtype (like ``qint8_narrow``) for weight related | |||
paramters and normal version for activation related ones. For the result of | |||
multiplication and addition as ``a * b + c * d``, if four variables are all -128 of | |||
dtype ``qint8``, then the result will be ``2^15`` and cause overflow. | |||
Weights are commonly calculated in this way, so need to narrow qmin to -127. | |||
""" | |||
def __init__( | |||
self, weight_observer, act_observer, weight_fake_quant, act_fake_quant | |||
): | |||
def __new__(cls, weight_observer, act_observer, weight_fake_quant, act_fake_quant): | |||
if isinstance(act_observer, Module) or isinstance(weight_observer, Module): | |||
raise ValueError( | |||
"QConfig must not receive observer instance, please pass observer" | |||
" class generator using `partial(Observer, ...)` instead. Use" | |||
" partial(MyObserver, x=1) to override arguments to constructor if needed" | |||
) | |||
self.weight_observer = weight_observer | |||
self.act_observer = act_observer | |||
self.weight_fake_quant = weight_fake_quant | |||
self.act_fake_quant = act_fake_quant | |||
def __eq__(self, other): | |||
def eq(a, b): | |||
if isinstance(a, partial) and isinstance(b, partial): | |||
return all( | |||
[a.func == b.func, a.args == b.args, a.keywords == b.keywords] | |||
) | |||
else: | |||
return a == b | |||
return ( | |||
eq(self.weight_observer, other.weight_observer) | |||
and eq(self.act_observer, other.act_observer) | |||
and eq(self.weight_fake_quant, other.weight_fake_quant) | |||
and eq(self.act_fake_quant, other.act_fake_quant) | |||
return super().__new__( | |||
cls, weight_observer, act_observer, weight_fake_quant, act_fake_quant | |||
) | |||
min_max_fakequant_qconfig = QConfig( | |||
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||
act_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=False), | |||
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||
weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"), | |||
act_observer=partial(MinMaxObserver, dtype="qint8"), | |||
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"), | |||
act_fake_quant=partial(FakeQuantize, dtype="qint8"), | |||
) | |||
ema_fakequant_qconfig = QConfig( | |||
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||
act_observer=partial( | |||
ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False | |||
), | |||
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||
weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"), | |||
act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8"), | |||
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"), | |||
act_fake_quant=partial(FakeQuantize, dtype="qint8"), | |||
) | |||
sync_ema_fakequant_qconfig = QConfig( | |||
weight_observer=partial(SyncMinMaxObserver, dtype="qint8", narrow_range=True), | |||
act_observer=partial( | |||
SyncExponentialMovingAverageObserver, dtype="qint8", narrow_range=False | |||
), | |||
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||
weight_observer=partial(SyncMinMaxObserver, dtype="qint8_narrow"), | |||
act_observer=partial(SyncExponentialMovingAverageObserver, dtype="qint8"), | |||
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"), | |||
act_fake_quant=partial(FakeQuantize, dtype="qint8"), | |||
) | |||
ema_lowbit_fakequant_qconfig = QConfig( | |||
weight_observer=partial(MinMaxObserver, dtype="qint4", narrow_range=False), | |||
act_observer=partial( | |||
ExponentialMovingAverageObserver, dtype="qint4", narrow_range=False | |||
), | |||
weight_fake_quant=partial(FakeQuantize, dtype="qint4", narrow_range=False), | |||
act_fake_quant=partial(FakeQuantize, dtype="qint4", narrow_range=False), | |||
weight_observer=partial(MinMaxObserver, dtype="qint4"), | |||
act_observer=partial(ExponentialMovingAverageObserver, dtype="qint4"), | |||
weight_fake_quant=partial(FakeQuantize, dtype="qint4"), | |||
act_fake_quant=partial(FakeQuantize, dtype="qint4"), | |||
) | |||
calibration_qconfig = QConfig( | |||
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||
act_observer=partial(HistogramObserver, dtype="qint8", narrow_range=False), | |||
weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"), | |||
act_observer=partial(HistogramObserver, dtype="qint8"), | |||
weight_fake_quant=None, | |||
act_fake_quant=None, | |||
) | |||
@@ -128,15 +111,15 @@ calibration_qconfig = QConfig( | |||
tqt_qconfig = QConfig( | |||
weight_observer=None, | |||
act_observer=None, | |||
weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True), | |||
act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False), | |||
weight_fake_quant=partial(TQT, dtype="qint8_narrow"), | |||
act_fake_quant=partial(TQT, dtype="qint8"), | |||
) | |||
passive_qconfig = QConfig( | |||
weight_observer=partial(PassiveObserver, dtype="qint8", narrow_range=True), | |||
act_observer=partial(PassiveObserver, dtype="qint8", narrow_range=False), | |||
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||
weight_observer=partial(PassiveObserver, dtype="qint8_narrow"), | |||
act_observer=partial(PassiveObserver, dtype="qint8"), | |||
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"), | |||
act_fake_quant=partial(FakeQuantize, dtype="qint8"), | |||
) | |||
easyquant_qconfig = passive_qconfig |
@@ -18,6 +18,7 @@ from ..module import qat as QAT | |||
from ..module import quantized as Quantized | |||
from ..module.qat import QATModule | |||
from ..module.quantized import QuantizedModule | |||
from ..tensor import Tensor | |||
from .qconfig import QConfig, ema_fakequant_qconfig | |||
@@ -147,10 +148,10 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True): | |||
if not inplace: | |||
module = deepcopy(module) | |||
def safe_call(func, q_dict): | |||
def safe_call(func, qparams): | |||
inst = func() if func is not None else None | |||
if inst is not None and getattr(inst, "set_qparams", None) is not None: | |||
inst.set_qparams(q_dict) | |||
inst.set_qparams(qparams) | |||
return inst | |||
def is_qat(mod: Module): | |||
@@ -158,13 +159,13 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True): | |||
for m in list(module._flatten(predicate=is_qat)): | |||
if m.with_weight: | |||
weight_q_dict = m.get_weight_qparams() | |||
m.weight_observer = safe_call(qconfig.weight_observer, weight_q_dict) | |||
m.weight_fake_quant = safe_call(qconfig.weight_fake_quant, weight_q_dict) | |||
weight_params = m.get_weight_qparams() | |||
m.weight_observer = safe_call(qconfig.weight_observer, weight_params) | |||
m.weight_fake_quant = safe_call(qconfig.weight_fake_quant, weight_params) | |||
if m.with_act: | |||
act_q_dict = m.get_activation_qparams() | |||
m.act_observer = safe_call(qconfig.act_observer, act_q_dict) | |||
m.act_fake_quant = safe_call(qconfig.act_fake_quant, act_q_dict) | |||
act_params = m.get_activation_qparams() | |||
m.act_observer = safe_call(qconfig.act_observer, act_params) | |||
m.act_fake_quant = safe_call(qconfig.act_fake_quant, act_params) | |||
return module | |||
@@ -202,7 +203,9 @@ def hook_qat_module(module: Module, func: Callable): | |||
return hooks | |||
def apply_easy_quant(module, data, start=0.8, stop=1.2, num=40): | |||
def apply_easy_quant( | |||
module: Module, data: Tensor, start: float = 0.8, stop: float = 1.2, num: int = 40 | |||
): | |||
r""" | |||
Implementation of ``EasyQuant``: https://arxiv.org/pdf/2006.16669. | |||
Search for optimal scales. | |||
@@ -5,9 +5,10 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import abc | |||
from enum import Enum | |||
from functools import partial, update_wrapper, wraps | |||
from typing import Dict | |||
from typing import Union | |||
import numpy as np | |||
@@ -15,7 +16,11 @@ from .. import functional as F | |||
from ..autodiff import Function | |||
from ..core._imperative_rt.core2 import apply | |||
from ..core.ops import builtin | |||
from ..core.tensor.dtype import _metadata_dict | |||
from ..core.tensor.dtype import ( | |||
QuantDtypeMeta, | |||
_builtin_quant_dtypes, | |||
create_quantized_dtype, | |||
) | |||
from ..tensor import Tensor | |||
@@ -61,37 +66,100 @@ class QuantMode(Enum): | |||
ASYMMERTIC = 2 | |||
qparam_dict = { | |||
QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None}, | |||
QuantMode.ASYMMERTIC: { | |||
"mode": QuantMode.ASYMMERTIC, | |||
"scale": None, | |||
"zero_point": None, | |||
}, | |||
class QParams: | |||
""" | |||
To standardize FakeQuant, Observer and Tensor's qparams format. If custom | |||
qparams is needed, inherit this class and add custom ``__slots__``. | |||
""" | |||
__slots__ = "mode", "dtype_meta", "scale", "zero_point" | |||
def __init__( | |||
self, | |||
mode: QuantMode, | |||
dtype_meta: QuantDtypeMeta, | |||
scale: Tensor, | |||
zero_point: Tensor, | |||
): | |||
self.mode = mode | |||
self.dtype_meta = dtype_meta | |||
self.scale = scale | |||
self.zero_point = zero_point | |||
def update(self, qparams: "QParams"): | |||
for key in self.__slots__: | |||
setattr(self, key, getattr(qparams, key)) | |||
def __eq__(self, other): | |||
if len(self.__slots__) != len(other.__slots__): | |||
return False | |||
for key in self.__slots__: | |||
if not hasattr(other, key) or getattr(self, key) != getattr(other, key): | |||
return False | |||
return True | |||
def __repr__(self): | |||
content = ", ".join( | |||
["{}={}".format(key, getattr(self, key)) for key in self.__slots__] | |||
) | |||
return "QParams({})".format(content) | |||
class QParamsModuleMixin(abc.ABC): | |||
def get_quantized_dtype(self): | |||
qparams = self.get_qparams() | |||
dtype = qparams.dtype_meta | |||
scale = float(qparams.scale.numpy()) if qparams.scale is not None else None | |||
zero_point = ( | |||
int(qparams.zero_point.numpy()) if qparams.zero_point is not None else None | |||
) | |||
return create_quantized_dtype(dtype, scale, zero_point) | |||
@abc.abstractmethod | |||
def get_qparams(self) -> QParams: | |||
pass | |||
_builtin_qparams = { | |||
QuantMode.SYMMERTIC: partial(QParams, mode=QuantMode.SYMMERTIC), | |||
QuantMode.ASYMMERTIC: partial(QParams, mode=QuantMode.ASYMMERTIC), | |||
} | |||
def get_qparam_dict(mode: QuantMode): | |||
def create_qparams( | |||
mode: QuantMode = QuantMode.SYMMERTIC, | |||
dtype_meta: Union[str, QuantDtypeMeta] = None, | |||
scale: Tensor = None, | |||
zero_point: Tensor = None, | |||
): | |||
""" | |||
Return the quantization parameters dictionary according to the mode. | |||
Return :class:`~.QParams` according to the mode. | |||
""" | |||
return qparam_dict.get(mode, None) | |||
if isinstance(dtype_meta, str): | |||
dtype_meta = _builtin_quant_dtypes[dtype_meta] | |||
if mode is None: | |||
return QParams(mode, dtype_meta, scale, zero_point) | |||
assert isinstance(mode, QuantMode) | |||
return _builtin_qparams[mode]( | |||
dtype_meta=dtype_meta, scale=scale, zero_point=zero_point | |||
) | |||
def fake_quant_tensor(inp: Tensor, qmin: int, qmax: int, q_dict: Dict) -> Tensor: | |||
def fake_quant_tensor(inp: Tensor, qparams: QParams) -> Tensor: | |||
""" | |||
Apply fake quantization to the inp tensor. | |||
:param inp: the input tensor which need to be faked. | |||
:param qmin: the minimum value which the integer limit to. | |||
:param qmax: the maximum value which the integer limit to. | |||
:param q_dict: the quantization parameter dict. | |||
:param qparams: to get mode, qmin, qmax, scale and zero_point from. | |||
""" | |||
scale = q_dict["scale"] | |||
zero_point = Tensor([0.0], dtype=np.float32) | |||
if q_dict["mode"] == QuantMode.ASYMMERTIC: | |||
zero_point = q_dict["zero_point"] | |||
scale = qparams.scale | |||
if qparams.mode == QuantMode.ASYMMERTIC: | |||
zero_point = qparams.zero_point | |||
else: | |||
zero_point = Tensor([0.0], dtype=np.float32) | |||
qmin = qparams.dtype_meta.qmin | |||
qmax = qparams.dtype_meta.qmax | |||
op = builtin.FakeQuant(qmin=qmin, qmax=qmax) | |||
return apply(op, inp, scale, zero_point)[0] | |||
@@ -104,22 +172,34 @@ def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor: | |||
:param bias: the bias tensor which need to be faked. | |||
:param inp: the input tensor which contain the quantization parameters. | |||
:param qmax: the weight tensor which contain the quantization parameters. | |||
:param w_qat: the weight tensor which contain the quantization parameters. | |||
.. warning:: | |||
Only work for symmetric quantization method now. | |||
""" | |||
b_qat = bias | |||
if hasattr(inp, "q_dict") and b_qat is not None: | |||
if inp.q_dict["scale"] is not None and w_qat.q_dict["scale"] is not None: | |||
# use the same mode with weight. | |||
b_dict = get_qparam_dict(w_qat.q_dict["mode"]) | |||
b_dict["scale"] = inp.q_dict["scale"] * w_qat.q_dict["scale"] | |||
# TODO: add zero_point for ASYMMERTIC mode. | |||
qmax = _metadata_dict["qint32"].qmax | |||
qmin = _metadata_dict["qint32"].qmin | |||
b_qat = fake_quant_tensor(b_qat, qmin, qmax, b_dict) | |||
b_qat.q_dict.update(b_dict) | |||
if ( | |||
getattr(inp, "qparams", None) is not None | |||
and getattr(w_qat, "qparams", None) is not None | |||
and bias is not None | |||
): | |||
inp_params = inp.qparams | |||
w_params = w_qat.qparams | |||
if inp_params.scale is not None and w_params.scale is not None: | |||
assert inp_params.mode == w_params.mode, "incompatible QuantMode" | |||
# TODO: support quint8 dtype. | |||
assert ( | |||
inp_params.dtype_meta.np_dtype_str == "int8" | |||
and w_params.dtype_meta.np_dtype_str == "int8" | |||
), "fake_quant_bias only support int8 like dtype now" | |||
# use the same mode with weight. | |||
# TODO: avoid hardcode | |||
b_dtype = _builtin_quant_dtypes["qint32"] | |||
b_param = create_qparams( | |||
w_params.mode, b_dtype, scale=inp_params.scale * w_params.scale | |||
) | |||
b_qat = fake_quant_tensor(bias, b_param) | |||
b_qat.qparams.update(b_param) | |||
return b_qat |
@@ -22,6 +22,8 @@ from .logger import get_logger | |||
from .utils.deprecation import deprecated | |||
from .utils.naming import auto_naming | |||
logger = get_logger(__name__) | |||
class Tensor(_Tensor, ArrayMethodMixin): | |||
r""" | |||
@@ -30,7 +32,7 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
grad = None | |||
dmap_callback = None | |||
_q_dict = None | |||
_qparams = None | |||
def __new__( | |||
cls, data, dtype=None, device=None, is_const=False, no_cache=False, name=None | |||
@@ -50,7 +52,7 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
if isinstance(data, _Tensor): | |||
if dtype is not None: | |||
get_logger().warning( | |||
logger.warning( | |||
"dtype does not work when creating a new Tensor with another Tensor" | |||
) | |||
obj = _Tensor.__new__(cls, data) | |||
@@ -101,10 +103,12 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
return super().dtype | |||
@property | |||
def q_dict(self): | |||
if self._q_dict is None: | |||
self._q_dict = {"mode": None, "scale": None, "zero_point": None} | |||
return self._q_dict | |||
def qparams(self): | |||
from .quantization.utils import create_qparams # pylint: disable=all | |||
if self._qparams is None: | |||
self._qparams = create_qparams() | |||
return self._qparams | |||
def numpy(self) -> np.ndarray: | |||
r""" | |||
@@ -185,14 +189,29 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
def __getstate__(self): | |||
r""" __getstate__ will be called for pickle serialization or deep copy | |||
""" | |||
state = { | |||
"qdict": self.q_dict, | |||
"numpy": self.numpy(), | |||
"dtype": self.dtype, | |||
"device": self.device.logical_name, | |||
} | |||
if self._qparams is not None: | |||
state["qparams"] = self._qparams | |||
return state | |||
def __setstate__(self, state): | |||
self._q_dict = state.pop("qdict") | |||
from .quantization.utils import create_qparams # pylint: disable=all | |||
if "qdict" in state: | |||
qparams = state.pop("qdict") | |||
logger.warning( | |||
"Tensor's 'qdict' state is depreciated. Use 'qparams' instead" | |||
) | |||
elif "qparams" in state: | |||
qparams = state.pop("qparams") | |||
else: | |||
qparams = None | |||
self._reset(Tensor(state.pop("numpy"), state.pop("dtype"), state.pop("device"))) | |||
self._qparams = qparams | |||
tensor = Tensor | |||
@@ -14,7 +14,7 @@ import pytest | |||
import megengine.core.tensor.megbrain_graph as G | |||
from megengine.core.ops import builtin as ops | |||
from megengine.core.tensor.dtype import ( | |||
_metadata_dict, | |||
_builtin_quant_dtypes, | |||
convert_from_qint4, | |||
convert_from_qint8, | |||
convert_from_quint4, | |||
@@ -76,10 +76,10 @@ def _get_compiled_result(inp, dtype, shape, device, calc_func=None): | |||
def _check_result_attr(oup, dtype, dtype_str, is_unsigned=True): | |||
metadata = _metadata_dict[dtype_str] | |||
metadata = _builtin_quant_dtypes[dtype_str] | |||
assert "mgb_dtype" in oup.dtype.metadata | |||
assert is_quantize(oup.dtype) | |||
np.testing.assert_equal(oup.dtype.metadata["mgb_dtype"]["name"], metadata.name) | |||
np.testing.assert_equal(oup.dtype.metadata["mgb_dtype"]["name"], metadata.cname) | |||
np.testing.assert_allclose(get_scale(oup.dtype), get_scale(dtype)) | |||
if is_unsigned: | |||
np.testing.assert_equal(get_zero_point(oup.dtype), get_zero_point(dtype)) | |||
@@ -65,9 +65,9 @@ def test_tensor_serialization(): | |||
with TemporaryFile() as f: | |||
a = Tensor(0) | |||
a.q_dict["scale"] = Tensor(1.0) | |||
a.qparams.scale = Tensor(1.0) | |||
pickle.dump(a, f) | |||
f.seek(0) | |||
b = pickle.load(f) | |||
assert isinstance(b.q_dict["scale"], Tensor) | |||
np.testing.assert_equal(b.q_dict["scale"].numpy(), 1.0) | |||
assert isinstance(b.qparams.scale, Tensor) | |||
np.testing.assert_equal(b.qparams.scale.numpy(), 1.0) |
@@ -6,6 +6,8 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import copy | |||
import numpy as np | |||
from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8 | |||
@@ -86,3 +88,23 @@ def test_as_type(): | |||
b = a.astype(quint8(0.3, 128)) | |||
np.testing.assert_almost_equal(get_scale(b.dtype), 0.3) | |||
np.testing.assert_equal(get_zero_point(b.dtype), 128) | |||
def test_qparams(): | |||
x = Tensor(1) | |||
assert x.qparams.scale is None | |||
x.qparams.scale = Tensor(1.0) | |||
assert x.qparams.scale.numpy() == 1.0 | |||
x2 = copy.copy(x) | |||
assert x.qparams is x2.qparams and x2.qparams.scale.numpy() == 1.0 | |||
x3 = copy.deepcopy(x) | |||
assert x.qparams is not x3.qparams and x3.qparams.scale.numpy() == 1.0 | |||
def test_name(): | |||
x = Tensor(0) | |||
assert x.name == "" | |||
x.name = "x" | |||
assert x.name == "x" | |||
x = Tensor(0, name="x") | |||
assert x.name == "x" |
@@ -406,28 +406,3 @@ def test_copy_d2h(): | |||
def test_copy_d2d(): | |||
copy_test("gpu0", "gpu1") | |||
copy_test("gpu0:0", "gpu0:1") | |||
def test_name(): | |||
x = tensor(0) | |||
assert x.name == "" | |||
x.name = "x" | |||
assert x.name == "x" | |||
x = tensor(0, name="x") | |||
assert x.name == "x" | |||
def test_q_dict(): | |||
x = tensor(1) | |||
assert x.q_dict["scale"] is None | |||
x.q_dict["scale"] = tensor(1.0) | |||
y = tensor(1) | |||
assert y.q_dict["scale"] is None | |||
y.q_dict["scale"] = tensor(2.0) | |||
assert x.q_dict["scale"].numpy() == 1.0 | |||
assert y.q_dict["scale"].numpy() == 2.0 | |||
z = x + y | |||
assert z.q_dict["scale"] is None |
@@ -12,9 +12,15 @@ import pytest | |||
import megengine as mge | |||
from megengine import tensor | |||
from megengine.core.autodiff.grad import Function, Grad | |||
from megengine.core.tensor.dtype import QuantDtypeMeta | |||
from megengine.core.tensor.utils import make_shape_tuple | |||
from megengine.quantization.internal_fake_quant import * | |||
from megengine.quantization.utils import QuantMode, fake_quant_tensor, tqt_forward | |||
from megengine.quantization.utils import ( | |||
QuantMode, | |||
create_qparams, | |||
fake_quant_tensor, | |||
tqt_forward, | |||
) | |||
class TQT_numpy: | |||
@@ -111,16 +117,14 @@ def fake_quant_tensor_gt(inp, scale, zero_point, qmin, qmax): | |||
def test_fakequant(): | |||
qmin = -126 | |||
qmax = 129 | |||
test_dtype = QuantDtypeMeta("test_qint8", None, "int8", qmin, qmax) | |||
def run(zero_point, scale): | |||
q_dict = {} | |||
q_dict["mode"] = QuantMode.ASYMMERTIC | |||
q_dict["scale"] = scale | |||
q_dict["zero_point"] = zero_point | |||
qparams = create_qparams(QuantMode.ASYMMERTIC, test_dtype, scale, zero_point) | |||
inp_data = np.random.uniform(low=-512.0, high=512.0, size=(1, 32, 32, 32)) | |||
inp = tensor(inp_data, dtype=np.float32) | |||
# test forward | |||
oup = fake_quant_tensor(inp, qmin, qmax, q_dict).numpy() | |||
oup = fake_quant_tensor(inp, qparams).numpy() | |||
oup_gt = fake_quant_tensor_gt(inp, scale, zero_point, qmin, qmax).numpy() | |||
assert np.allclose(oup, oup_gt) | |||
assert oup.shape == oup_gt.shape | |||
@@ -128,7 +132,7 @@ def test_fakequant(): | |||
# test backward | |||
x = tensor(inp_data, dtype=np.float32) | |||
grad = Grad().wrt(x, callback=_save_to(x)) | |||
y = fake_quant_tensor(x, qmin, qmax, q_dict) | |||
y = fake_quant_tensor(x, qparams) | |||
grad(y, tensor(F.ones_like(x))) | |||
x1 = tensor(inp_data, dtype=np.float32) | |||
@@ -10,7 +10,13 @@ import megengine.module.qat as QAT | |||
import megengine.module.quantized as Q | |||
from megengine import Parameter, Tensor | |||
from megengine.core.tensor import dtype | |||
from megengine.quantization import FakeQuantize, MinMaxObserver, QConfig | |||
from megengine.quantization import ( | |||
FakeQuantize, | |||
MinMaxObserver, | |||
QConfig, | |||
QuantMode, | |||
create_qparams, | |||
) | |||
from megengine.quantization.quantize import ( | |||
disable_fake_quant, | |||
disable_observer, | |||
@@ -18,10 +24,10 @@ from megengine.quantization.quantize import ( | |||
) | |||
min_max_fakequant_qconfig = QConfig( | |||
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||
act_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=False), | |||
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||
weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"), | |||
act_observer=partial(MinMaxObserver, dtype="qint8"), | |||
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"), | |||
act_fake_quant=partial(FakeQuantize, dtype="qint8"), | |||
) | |||
inp_scale = np.float32(np.random.rand() + 1) | |||
@@ -111,7 +117,7 @@ def test_dequant_stub(): | |||
x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | |||
x = fake_quant_act(x, inp_scale) | |||
x.q_dict["scale"] = inp_scale | |||
x.qparams.scale = inp_scale | |||
normal = normal_net(x) | |||
qat_without_fakequant = qat_from_float(x) | |||
@@ -146,12 +152,12 @@ def test_elemwise(kind): | |||
x1_scale = np.float32(np.random.rand() + 1) | |||
x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | |||
x1 = fake_quant_act(x1, x1_scale) | |||
x1.q_dict["scale"] = x1_scale | |||
x1.qparams.scale = x1_scale | |||
x2_scale = np.float32(np.random.rand() + 1) | |||
x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | |||
x2 = fake_quant_act(x2, x2_scale) | |||
x2.q_dict["scale"] = x2_scale | |||
x2.qparams.scale = x2_scale | |||
x1_int8 = quant(x1, x1_scale) | |||
x2_int8 = quant(x2, x2_scale) | |||
@@ -187,7 +193,7 @@ def test_linear(): | |||
x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | |||
x = fake_quant_act(x, inp_scale) | |||
x.q_dict["scale"] = inp_scale | |||
x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale)) | |||
x_int8 = quant(x, inp_scale) | |||
@@ -230,7 +236,7 @@ def test_conv(module): | |||
x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32")) | |||
x = fake_quant_act(x, inp_scale) | |||
x.q_dict["scale"] = inp_scale | |||
x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale)) | |||
x_int8 = quant(x, inp_scale) | |||
@@ -6,6 +6,7 @@ import pytest | |||
import megengine as mge | |||
import megengine.distributed as dist | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.quantization import QuantMode, create_qparams | |||
from megengine.quantization.observer import ( | |||
ExponentialMovingAverageObserver, | |||
HistogramObserver, | |||
@@ -56,14 +57,14 @@ def test_histogram_observer(): | |||
def test_passive_observer(): | |||
q_dict = {"scale": mge.tensor(1.0)} | |||
qparams = create_qparams(QuantMode.SYMMERTIC, "qint8", mge.tensor(1.0)) | |||
m = PassiveObserver("qint8") | |||
m.set_qparams(q_dict) | |||
m.set_qparams(qparams) | |||
assert m.orig_scale == 1.0 | |||
assert m.scale == 1.0 | |||
m.scale = 2.0 | |||
assert m.scale == 2.0 | |||
assert m.get_qparams() == {"scale": mge.tensor(2.0)} | |||
assert m.scale.numpy() == 1.0 | |||
assert m.get_qparams().dtype_meta == qparams.dtype_meta | |||
assert m.get_qparams().scale == qparams.scale | |||
assert m.get_qparams() == qparams | |||
@pytest.mark.require_ngpu(2) | |||
@@ -6,6 +6,7 @@ import megengine.functional as F | |||
from megengine.core.tensor import dtype | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.functional.elemwise import _elemwise_multi_type, _elwise | |||
from megengine.quantization import QuantMode, create_qparams | |||
def quant(x, scale): | |||
@@ -26,13 +27,13 @@ def test_elemwise(kind): | |||
x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | |||
x1_scale = np.float32(np.random.rand() + 1) | |||
x1 = fake_quant(x1, x1_scale) | |||
x1.q_dict["scale"] = x1_scale | |||
x1.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", x1_scale)) | |||
x1_int8 = quant(x1, x1_scale) | |||
x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) | |||
x2_scale = np.float32(np.random.rand() + 1) | |||
x2 = fake_quant(x2, x2_scale) | |||
x2.q_dict["scale"] = x2_scale | |||
x2.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", x2_scale)) | |||
x2_int8 = quant(x2, x2_scale) | |||
output_scale = np.float32(np.random.rand() + 1) | |||
@@ -1,14 +0,0 @@ | |||
from functools import partial | |||
from megengine.quantization import QConfig, tqt_qconfig | |||
from megengine.quantization.fake_quant import TQT | |||
def test_equal(): | |||
qconfig = QConfig( | |||
weight_observer=None, | |||
act_observer=None, | |||
weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True), | |||
act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False), | |||
) | |||
assert qconfig == tqt_qconfig |
@@ -33,7 +33,7 @@ from megengine.quantization.quantize import ( | |||
) | |||
class Net(Float.Module): | |||
class FloatNet(Float.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.quant = Float.QuantStub() | |||
@@ -113,25 +113,25 @@ def test_reset_qconfig(): | |||
def test_enable_and_disable_observer(): | |||
net = init_qat_net() | |||
enable_observer(net) | |||
assert net.quant.act_observer.enabled == True | |||
assert net.linear.weight_observer.enabled == True | |||
assert net.linear.act_observer.enabled == True | |||
assert net.quant.act_observer.enabled is True | |||
assert net.linear.weight_observer.enabled is True | |||
assert net.linear.act_observer.enabled is True | |||
disable_observer(net) | |||
assert net.quant.act_observer.enabled == False | |||
assert net.linear.weight_observer.enabled == False | |||
assert net.linear.act_observer.enabled == False | |||
assert net.quant.act_observer.enabled is False | |||
assert net.linear.weight_observer.enabled is False | |||
assert net.linear.act_observer.enabled is False | |||
def test_enable_and_disable_fake_quant(): | |||
net = init_qat_net() | |||
disable_fake_quant(net) | |||
assert net.quant.act_fake_quant.enabled == False | |||
assert net.linear.weight_fake_quant.enabled == False | |||
assert net.linear.act_fake_quant.enabled == False | |||
assert net.quant.act_fake_quant.enabled is False | |||
assert net.linear.weight_fake_quant.enabled is False | |||
assert net.linear.act_fake_quant.enabled is False | |||
enable_fake_quant(net) | |||
assert net.quant.act_fake_quant.enabled == True | |||
assert net.linear.weight_fake_quant.enabled == True | |||
assert net.linear.act_fake_quant.enabled == True | |||
assert net.quant.act_fake_quant.enabled is True | |||
assert net.linear.weight_fake_quant.enabled is True | |||
assert net.linear.act_fake_quant.enabled is True | |||
def init_observer(module, data): | |||
@@ -144,7 +144,7 @@ def init_observer(module, data): | |||
def test_enable_and_disable_all(): | |||
x = Tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32)) | |||
net = Net() | |||
net = FloatNet() | |||
y1 = net(x).numpy() | |||
net = quantize_qat(net, min_max_fakequant_qconfig) | |||
@@ -162,7 +162,7 @@ def test_enable_and_disable_all(): | |||
def test_quantize_qat(): | |||
net = Net() | |||
net = FloatNet() | |||
qat_net = quantize_qat(net, inplace=False, qconfig=min_max_fakequant_qconfig) | |||
assert isinstance(qat_net.quant, QAT.QuantStub) | |||
assert isinstance(qat_net.linear, QAT.Linear) | |||