
feat(mge/quantization): add QParams and QuantDtypeMeta for quantization data structure

GitOrigin-RevId: df3416fe13
tags/v1.3.0
Megvii Engine Team, 4 years ago
commit 1d7dd00144
24 changed files with 570 additions and 411 deletions
  1. +0   -3   imperative/python/megengine/core/ops/builtin/__init__.py
  2. +118 -62  imperative/python/megengine/core/tensor/dtype.py
  3. +1   -1   imperative/python/megengine/functional/nn.py
  4. +0   -1   imperative/python/megengine/jit/tracing.py
  5. +4   -7   imperative/python/megengine/module/module.py
  6. +12  -9   imperative/python/megengine/module/qat/module.py
  7. +13  -3   imperative/python/megengine/quantization/__init__.py
  8. +61  -50  imperative/python/megengine/quantization/fake_quant.py
  9. +2   -0   imperative/python/megengine/quantization/internal_fake_quant.py
  10. +84  -80  imperative/python/megengine/quantization/observer.py
  11. +44  -61  imperative/python/megengine/quantization/qconfig.py
  12. +12  -9   imperative/python/megengine/quantization/quantize.py
  13. +111 -31  imperative/python/megengine/quantization/utils.py
  14. +28  -9   imperative/python/megengine/tensor.py
  15. +3   -3   imperative/python/test/unit/core/test_dtype_quant.py
  16. +3   -3   imperative/python/test/unit/core/test_serialization.py
  17. +22  -0   imperative/python/test/unit/core/test_tensor_wrapper.py
  18. +0   -25  imperative/python/test/unit/functional/test_tensor.py
  19. +11  -7   imperative/python/test/unit/quantization/test_fake_quant.py
  20. +16  -10  imperative/python/test/unit/quantization/test_module.py
  21. +7   -6   imperative/python/test/unit/quantization/test_observer.py
  22. +3   -2   imperative/python/test/unit/quantization/test_op.py
  23. +0   -14  imperative/python/test/unit/quantization/test_qconfig.py
  24. +15  -15  imperative/python/test/unit/quantization/test_quantize.py

+0 -3 imperative/python/megengine/core/ops/builtin/__init__.py

@@ -6,9 +6,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import warnings
from typing import Union

from ..._imperative_rt import OpDef, ops


__all__ = ["OpDef"]


+118 -62 imperative/python/megengine/core/tensor/dtype.py

@@ -5,22 +5,24 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

import collections
from collections import namedtuple
from typing import Union


import numpy as np


# normal dtype related
from .._imperative_rt import bfloat16, intb1, intb2, intb4
from .._imperative_rt.common import (
bfloat16,
get_scale,
get_zero_point,
intb1,
intb2,
intb4,
is_dtype_equal,
is_quantize,
)




# normal dtype related
def is_lowbit(dtype):
return (dtype is intb1) or (dtype is intb2) or (dtype is intb4)


@@ -30,34 +32,80 @@ def is_bfloat16(dtype):




# quantization dtype related
_QuantDtypeMetadata = collections.namedtuple(
"QuantDtypeMetadata", ["name", "np_dtype_str", "is_unsigned", "qmin", "qmax",]
)


_metadata_dict = {
"quint8": _QuantDtypeMetadata("Quantized8Asymm", "uint8", True, 0, 255),
"qint8": _QuantDtypeMetadata("QuantizedS8", "int8", False, -128, 127),
"quint4": _QuantDtypeMetadata("Quantized4Asymm", "uint8", True, 0, 15),
"qint4": _QuantDtypeMetadata("QuantizedS4", "int8", False, -8, 7),
"qint32": _QuantDtypeMetadata(
"QuantizedS32", "int32", False, -(2 ** 31), 2 ** 31 - 1,
# use namedtuple to make class immutable, comparable and easy to print
class QuantDtypeMeta(
namedtuple(
"QuantDtypeMeta",
["name", "cname", "np_dtype_str", "qmin", "qmax", "is_unsigned"],
)
):
r"""
Store metadata for quantize dtype. Could be used to create custom quant dtype
for QAT when the network don't need to be converted for inference, but only
to export network metadata for third-party platform inference.

:param name: a unique name string.
:param cname: used in :func:`~.create_quantized_dtype` for model dump and inference.
:param np_dtype_str: used in :func:`~.create_quantized_dtype` to generate ``np.dtype``.
:param qmin: a int number indicating quant dtype's lowerbound.
:param qmax: a int number indicating quant dtype's upperbound.
:param is_unsigned: a helper value that could be inference from np_dtype_str.
"""

def __new__(
cls,
name: str,
cname: str,
np_dtype_str: str,
qmin: int,
qmax: int,
is_unsigned: bool = None,
):
assert isinstance(np_dtype_str, str)
is_unsigned = np_dtype_str[0] == "u" if is_unsigned is None else is_unsigned
return super().__new__(cls, name, cname, np_dtype_str, qmin, qmax, is_unsigned)

def __copy__(self):
return self

def __deepcopy__(self, _):
"""
Ignore deepcopy so that a dtype meta can be treated as singleton, for more
strict check in :meth:`~.FakeQuantize.fake_quant_forward`.
"""
return self


_builtin_quant_dtypes = {
"quint8": QuantDtypeMeta("quint8", "Quantized8Asymm", "uint8", 0, 255),
"qint8": QuantDtypeMeta("qint8", "QuantizedS8", "int8", -128, 127),
"qint8_narrow": QuantDtypeMeta("qint8_narrow", "QuantizedS8", "int8", -127, 127),
"quint4": QuantDtypeMeta("quint4", "Quantized4Asymm", "uint8", 0, 15),
"qint4": QuantDtypeMeta("qint4", "QuantizedS4", "int8", -8, 7),
"qint32": QuantDtypeMeta(
"qint32", "QuantizedS32", "int32", -(2 ** 31), 2 ** 31 - 1,
),
# NOTE: int2 is not supported for model dump yet
"quint2": _QuantDtypeMetadata(None, "uint8", True, 0, 3),
"qint2": _QuantDtypeMetadata(None, "int8", False, -2, 1),
"quint2": QuantDtypeMeta("quint2", None, "uint8", 0, 3),
"qint2": QuantDtypeMeta("qint2", None, "int8", -2, 1),
}




def _check_zero_point(zp: int, dtype_str: str):
qmin = _metadata_dict[dtype_str].qmin
qmax = _metadata_dict[dtype_str].qmax
def _check_zero_point(zp: int, dtype_meta: QuantDtypeMeta):
qmin = dtype_meta.qmin
qmax = dtype_meta.qmax
if zp < qmin or zp > qmax:
raise ValueError(
"zero_point should be within [{}, {}] for {}".format(qmin, qmax, dtype_str)
"zero_point should be within [{}, {}] for {}".format(
qmin, qmax, dtype_meta.name
)
)




def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]):
def create_quantized_dtype(
dtype_meta: QuantDtypeMeta, scale: float, zp: Union[int, None]
):
r""" r"""
Get quantized dtype with metadata attribute according to _metadata_dict. Get quantized dtype with metadata attribute according to _metadata_dict.


@@ -65,32 +113,34 @@ def get_quantized_dtype(dtype_str: str, scale: float, zp: Union[int, None]):
not have ``zero_point``, to be consitent with tensor generated by calling not have ``zero_point``, to be consitent with tensor generated by calling
compiled function from `CompGraph.compile(inputs, outspec)`. compiled function from `CompGraph.compile(inputs, outspec)`.


:param dtype: a string indicating which dtype to return
:param dtype_meta: a QuantDtypeMeta indicating which dtype to return. the
``cname`` attribute cannot be ``None``.
:param scale: a number for scale to store in dtype's metadata :param scale: a number for scale to store in dtype's metadata
:param zp: a number for zero_point to store in dtype's metadata :param zp: a number for zero_point to store in dtype's metadata
""" """
metadata = _metadata_dict[dtype_str]
np_dtype_str = metadata.np_dtype_str
is_unsigned = metadata.is_unsigned
if is_unsigned:
if dtype_meta.cname is None:
raise ValueError("dtype {} without cname attr is not supported.")
if dtype_meta.is_unsigned:
if zp is None or int(zp) != zp:
raise ValueError("zero_point should be an integer")
zp = int(zp)
_check_zero_point(zp, dtype_str)
_check_zero_point(zp, dtype_meta)
return np.dtype(
np_dtype_str,
dtype_meta.np_dtype_str,
metadata={
"mgb_dtype": {
"name": metadata.name,
"name": dtype_meta.cname,
"scale": float(scale),
"zero_point": zp,
}
},
)
else:
# Don't trick to combine with is_unsigned. Metadata should not contain
# invalid field to keep consistent with c dtype.
return np.dtype(
np_dtype_str,
metadata={"mgb_dtype": {"name": metadata.name, "scale": float(scale)}},
dtype_meta.np_dtype_str,
metadata={"mgb_dtype": {"name": dtype_meta.cname, "scale": float(scale)}},
)




@@ -100,7 +150,7 @@ def quint8(scale, zero_point):
``zero_point`` (uint8). The real value represented by a quint8 data type is
float_val = scale * (uint8_val - zero_point)
"""
return get_quantized_dtype("quint8", scale, zero_point)
return create_quantized_dtype(_builtin_quant_dtypes["quint8"], scale, zero_point)




def qint8(scale):
@@ -108,7 +158,7 @@ def qint8(scale):
Construct a quantized int8 data type with ``scale`` (float). The real value
represented by a qint8 data type is float_val = scale * int8_val
"""
return get_quantized_dtype("qint8", scale, None)
return create_quantized_dtype(_builtin_quant_dtypes["qint8"], scale, None)




def qint32(scale):
@@ -116,7 +166,7 @@ def qint32(scale):
Construct a quantized int32 data type with ``scale`` (float). The real value
represented by a qint32 data type is float_val = scale * int32_val
"""
return get_quantized_dtype("qint32", scale, None)
return create_quantized_dtype(_builtin_quant_dtypes["qint32"], scale, None)




def quint4(scale, zero_point):
@@ -125,7 +175,7 @@ def quint4(scale, zero_point):
``zero_point`` (uint8). The real value represented by a quint4 data type is
float_val = scale * (uint4_val - zero_point)
"""
return get_quantized_dtype("quint4", scale, zero_point)
return create_quantized_dtype(_builtin_quant_dtypes["quint4"], scale, zero_point)




def qint4(scale):
@@ -133,42 +183,48 @@ def qint4(scale):
Construct a quantized int4 data type with ``scale`` (float). The real value
represented by a qint4 data type is float_val = scale * int4_val
"""
return get_quantized_dtype("qint4", scale, None)
return create_quantized_dtype(_builtin_quant_dtypes["qint4"], scale, None)




def _convert_to_quantized_dtype(arr: np.ndarray, dtype: np.dtype, dtype_str: str):
metadata = _metadata_dict[dtype_str]
arr_metadata = dtype.metadata["mgb_dtype"]
def _convert_to_quantized_dtype(
arr: np.ndarray, dtype: np.dtype, dtype_meta: QuantDtypeMeta
):
if not isinstance(arr, np.ndarray):
raise ValueError("arr parameter should be instance of np.ndarray")
if not is_quantize(dtype) or arr_metadata["name"] != metadata.name:
raise ValueError("dtype parameter should be a {} dtype".format(dtype_str))
is_unsigned = metadata.is_unsigned
if is_unsigned:
if (
not is_quantize(dtype)
or dtype.metadata["mgb_dtype"]["name"] != dtype_meta.cname
):
raise ValueError("dtype parameter should be a {} dtype".format(dtype_meta))
arr_metadata = dtype.metadata["mgb_dtype"]
if dtype_meta.is_unsigned:
scale, zp = (
arr_metadata["scale"],
arr_metadata["zero_point"],
)
return (
(np.round(arr / scale) + zp)
.clip(metadata.qmin, metadata.qmax)
.clip(dtype_meta.qmin, dtype_meta.qmax)
.astype(dtype)
)
else:
# don't trick to combine with is_unsigned, seeing ``get_quantized_dtype``
scale = arr_metadata["scale"]
return np.round(arr / scale).clip(metadata.qmin, metadata.qmax).astype(dtype)
return (
np.round(arr / scale).clip(dtype_meta.qmin, dtype_meta.qmax).astype(dtype)
)




def _convert_from_quantized_dtype(arr: np.ndarray, dtype_str: str):
metadata = _metadata_dict[dtype_str]
arr_metadata = arr.dtype.metadata["mgb_dtype"]
def _convert_from_quantized_dtype(arr: np.ndarray, dtype_meta: QuantDtypeMeta):
if not isinstance(arr, np.ndarray):
raise ValueError("arr parameter should be instance of np.ndarray")
if not is_quantize(arr.dtype) or arr_metadata["name"] != metadata.name:
raise ValueError("arr's dtype should be a {} dtype".format(dtype_str))
is_unsigned = metadata.is_unsigned
if is_unsigned:
if (
not is_quantize(arr.dtype)
or arr.dtype.metadata["mgb_dtype"]["name"] != dtype_meta.cname
):
raise ValueError("arr's dtype should be a {} dtype".format(dtype_meta))
arr_metadata = arr.dtype.metadata["mgb_dtype"]
if dtype_meta.is_unsigned:
scale, zp = (
arr_metadata["scale"],
arr_metadata["zero_point"],
@@ -187,7 +243,7 @@ def convert_to_quint8(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a quint8.
"""
return _convert_to_quantized_dtype(arr, q, "quint8")
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["quint8"])




def convert_from_quint8(arr: np.ndarray):
@@ -196,7 +252,7 @@ def convert_from_quint8(arr: np.ndarray):


:param arr: Input ndarray.
"""
return _convert_from_quantized_dtype(arr, "quint8")
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["quint8"])




def convert_to_qint8(arr: np.ndarray, q: np.dtype):
@@ -206,7 +262,7 @@ def convert_to_qint8(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a qint8.
"""
return _convert_to_quantized_dtype(arr, q, "qint8")
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint8"])




def convert_from_qint8(arr: np.ndarray):
@@ -215,7 +271,7 @@ def convert_from_qint8(arr: np.ndarray):


:param arr: Input ndarray.
"""
return _convert_from_quantized_dtype(arr, "qint8")
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint8"])




def convert_to_qint32(arr: np.ndarray, q: np.dtype):
@@ -225,7 +281,7 @@ def convert_to_qint32(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a qint32.
"""
return _convert_to_quantized_dtype(arr, q, "qint32")
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint32"])




def convert_from_qint32(arr):
@@ -234,7 +290,7 @@ def convert_from_qint32(arr):


:param arr: Input ndarray.
"""
return _convert_from_quantized_dtype(arr, "qint32")
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint32"])




def convert_to_quint4(arr: np.ndarray, q: np.dtype):
@@ -244,7 +300,7 @@ def convert_to_quint4(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a quint4.
"""
return _convert_to_quantized_dtype(arr, q, "quint4")
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["quint4"])




def convert_from_quint4(arr: np.ndarray):
@@ -253,7 +309,7 @@ def convert_from_quint4(arr: np.ndarray):


:param arr: Input ndarray.
"""
return _convert_from_quantized_dtype(arr, "quint4")
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["quint4"])




def convert_to_qint4(arr: np.ndarray, q: np.dtype):
@@ -263,7 +319,7 @@ def convert_to_qint4(arr: np.ndarray, q: np.dtype):
:param arr: Input ndarray.
:param q: Target data type, should be a qint4.
"""
return _convert_to_quantized_dtype(arr, q, "qint4")
return _convert_to_quantized_dtype(arr, q, _builtin_quant_dtypes["qint4"])




def convert_from_qint4(arr: np.ndarray):
@@ -272,4 +328,4 @@ def convert_from_qint4(arr: np.ndarray):


:param arr: Input ndarray.
"""
return _convert_from_quantized_dtype(arr, "qint4")
return _convert_from_quantized_dtype(arr, _builtin_quant_dtypes["qint4"])

+1 -1 imperative/python/megengine/functional/nn.py

@@ -203,7 +203,7 @@ def conv_transpose2d(
assert compute_mode == "DEFAULT" or compute_mode.name == "DEFAULT"


if groups != 1:
raise NotImplementedError("TODO")
raise NotImplementedError("group transposed conv2d is not supported yet.")


stride_h, stride_w = expand_hw(stride)
pad_h, pad_w = expand_hw(padding)


+0 -1 imperative/python/megengine/jit/tracing.py

@@ -13,7 +13,6 @@ import itertools
import json
import os
import typing
import warnings
import weakref


import numpy as np


+4 -7 imperative/python/megengine/module/module.py

@@ -5,7 +5,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import warnings
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union
@@ -204,10 +203,9 @@ class Module(metaclass=ABCMeta):


if "requires_grad" in kwargs: if "requires_grad" in kwargs:
del kwargs["requires_grad"] del kwargs["requires_grad"]
warnings.warn(
logger.warning(
"Tensor currently has no requires_grad attribute " "Tensor currently has no requires_grad attribute "
"so requires_grad argument is ignored here",
DeprecationWarning,
"so requires_grad argument is ignored here"
) )


def predicate(obj) -> bool:
@@ -232,10 +230,9 @@ class Module(metaclass=ABCMeta):


if "requires_grad" in kwargs: if "requires_grad" in kwargs:
del kwargs["requires_grad"] del kwargs["requires_grad"]
warnings.warn(
logger.warning(
"Tensor currently has no requires_grad attribute " "Tensor currently has no requires_grad attribute "
"so requires_grad argument is ignored here",
DeprecationWarning,
"so requires_grad argument is ignored here"
) )


def predicate(obj) -> bool:


+12 -9 imperative/python/megengine/module/qat/module.py

@@ -7,7 +7,10 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod from abc import abstractmethod


from ...quantization import FakeQuantize, Observer, QConfig
# avoid circular reference
from ...quantization.fake_quant import FakeQuantize
from ...quantization.observer import Observer
from ...quantization.qconfig import QConfig
from ...tensor import Tensor
from ..module import Module


@@ -73,19 +76,19 @@ class QATModule(Module):
# do observer
if observer is None:
oup = target
q_dict = None
qparams = None
else:
oup = observer(target)
q_dict = observer.get_qparams()
qparams = observer.get_qparams()
# do fake quant
if fake_quant is not None:
oup = fake_quant(oup, q_dict)
oup = fake_quant(oup, qparams)
# use qparams of fake_quant if have.
if hasattr(fake_quant, "get_qparams"):
q_dict = fake_quant.get_qparams()
qparams = fake_quant.get_qparams()
# set to tensor qparams.
if q_dict is not None:
oup.q_dict.update(q_dict)
if qparams is not None:
oup.qparams.update(qparams)
return oup


def apply_quant_weight(self, target: Tensor):
@@ -118,7 +121,7 @@ class QATModule(Module):
Get weight's quantization dtype as the method from ``qconfig``.
"""
return self._get_method_result(
"get_dtype", self.weight_fake_quant, self.weight_observer
"get_quantized_dtype", self.weight_fake_quant, self.weight_observer
)


def get_activation_dtype(self):
@@ -126,7 +129,7 @@ class QATModule(Module):
Get activation's quantization dtype as the method from ``qconfig``.
"""
return self._get_method_result(
"get_dtype", self.act_fake_quant, self.act_observer
"get_quantized_dtype", self.act_fake_quant, self.act_observer
)


def get_weight_qparams(self):
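For context, a sketch of how these renamed hooks are reached from a QAT module (the Linear setup is just an assumed example):

import megengine.module as M
from megengine.module.qat import Linear as QATLinear
from megengine.quantization import ema_fakequant_qconfig

qat_fc = QATLinear.from_float_module(M.Linear(16, 16))
qat_fc.set_qconfig(ema_fakequant_qconfig)
w_qparams = qat_fc.get_weight_qparams()  # now returns a QParams object instead of a dict
w_dtype = qat_fc.get_weight_dtype()      # delegates to the observer/fake_quant's get_quantized_dtype()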


+13 -3 imperative/python/megengine/quantization/__init__.py

@@ -7,8 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.


from .fake_quant import FakeQuantize
from .internal_fake_quant import *
from .observer import HistogramObserver, Observer
from .observer import Observer
from .qconfig import (
QConfig,
calibration_qconfig,
@@ -20,4 +19,15 @@ from .qconfig import (
sync_ema_fakequant_qconfig,
tqt_qconfig,
)
from .utils import QuantMode
from .quantize import (
apply_easy_quant,
disable_fake_quant,
disable_observer,
enable_fake_quant,
enable_observer,
propagate_qconfig,
quantize,
quantize_qat,
reset_qconfig,
)
from .utils import QParams, QuantMode, create_qparams
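With these re-exports, downstream code can import the whole quantization surface from the package root, e.g.:

from megengine.quantization import (
    QParams,
    QuantMode,
    create_qparams,
    disable_observer,
    enable_fake_quant,
    quantize,
    quantize_qat,
)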

+61 -50 imperative/python/megengine/quantization/fake_quant.py

@@ -6,40 +6,48 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
from typing import Union


from .. import functional as F
from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
from ..logger import get_logger
from ..module import Module
from ..tensor import Parameter, Tensor
from .utils import QuantMode, fake_quant_tensor, get_qparam_dict, tqt_forward
from ..tensor import Parameter
from .utils import (
QParams,
QParamsModuleMixin,
QuantMode,
create_qparams,
fake_quant_tensor,
tqt_forward,
)


logger = get_logger(__name__)


class _FakeQuantize(Module):
r"""
A Basic Fake Quant module.

:param dtype: a string indicating the target quantization type of input.
:param narrow_range: whether the absolute value of ``qmin`` is the same as ``qmax``,
instead of 1 greater. Usually True for weight and False for activation.
:param enable: whether do ``normal_forward`` or ``fake_quant_forward``.
"""


class _FakeQuantize(Module):
def __init__(
self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs
self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
):
super().__init__()
if not dtype in _metadata_dict.keys():
raise ValueError(
"unknown dtype: {}, only support {}".format(
dtype, _metadata_dict.keys()
if isinstance(dtype, str):
if not dtype in _builtin_quant_dtypes:
raise ValueError(
"unknown dtype: {}, only support {}".format(
dtype, _builtin_quant_dtypes.keys()
)
)
dtype = _builtin_quant_dtypes[dtype]
if "narrow_range" in kwargs:
del kwargs["narrow_range"]
logger.warning(
"FakeQuantize currently has no narrow_range param "
"so it is ignored here",
exc_info=DeprecationWarning,
)
self.dtype = dtype
self.narrow_range = narrow_range
self.qmin = (
-_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin
)
self.qmax = _metadata_dict[dtype].qmax
self.qmin = dtype.qmin
self.qmax = dtype.qmax
self.enabled = enable


def enable(self):
@@ -48,61 +56,64 @@ class _FakeQuantize(Module):
def disable(self):
self.enabled = False


def fake_quant_forward(self, inp, q_dict=None):
return inp
def fake_quant_forward(self, inp, qparams: QParams = None):
raise NotImplementedError


def normal_foward(self, inp, q_dict=None):
def normal_foward(self, inp, qparams: QParams = None):
return inp


def forward(self, inp, q_dict=None):
def forward(self, inp, qparams: QParams = None):
if self.enabled:
return self.fake_quant_forward(inp, q_dict=q_dict)
return self.fake_quant_forward(inp, qparams=qparams)
else:
return self.normal_foward(inp, q_dict=q_dict)
return self.normal_foward(inp, qparams=qparams)




class TQT(_FakeQuantize):
class TQT(_FakeQuantize, QParamsModuleMixin):
r""" r"""
TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds
for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks. for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks.

:param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
quantization dtype of input.
:param enable: whether do ``normal_forward`` or ``fake_quant_forward``.
""" """


def __init__(
self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs
self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
):
super().__init__(dtype, narrow_range, enable, **kwargs)
super().__init__(dtype, enable, **kwargs)
self.scale = Parameter(0.0, dtype="float32")


def fake_quant_forward(self, inp, q_dict=None):
def fake_quant_forward(self, inp, qparams: QParams = None):
# when enable, TQT will do fakequant forward, finetune the scale
return tqt_forward(self.qmin, self.qmax, inp, self.scale)


def get_qparams(self):
q_dict = get_qparam_dict(QuantMode.SYMMERTIC)
q_dict["scale"] = 2 ** self.scale.detach()
return q_dict

def set_qparams(self, q_dict):
def set_qparams(self, qparams: QParams):
assert (
q_dict["mode"] == QuantMode.SYMMERTIC
qparams.mode == QuantMode.SYMMERTIC
), "only symmetric quantization is supported by TQT"
if "scale" not in q_dict or q_dict["scale"] is None:
if qparams.scale is None:
raise AssertionError("Can not get an initialized scale")
self.scale._reset(F.log(q_dict["scale"]) / math.log(2))
self.scale[...] = F.log(qparams.scale) / math.log(2)


def get_dtype(self):
q_dict = self.get_qparams()
scale = None if "scale" not in q_dict else q_dict["scale"].numpy()
zero_point = (
None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()
)
return get_quantized_dtype(self.dtype, scale, zero_point)
def get_qparams(self):
return create_qparams(QuantMode.SYMMERTIC, self.dtype, scale=2 ** self.scale)




class FakeQuantize(_FakeQuantize):
r"""
A module to do quant and dequant according to observer's scale and zero_point.

:param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
quantization dtype of input.
:param enable: whether do ``normal_forward`` or ``fake_quant_forward``.
"""


def fake_quant_forward(self, inp, q_dict=None):
return fake_quant_tensor(inp, self.qmin, self.qmax, q_dict)
def fake_quant_forward(self, inp, qparams: QParams = None):
assert (
qparams.dtype_meta is self.dtype
), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
qparams.dtype_meta, self.dtype
)
return fake_quant_tensor(inp, qparams)
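A sketch of the new call contract (illustrative values): FakeQuantize now consumes a QParams, and its dtype_meta must be the very same QuantDtypeMeta object as the module's dtype, which is why dtype metas behave as singletons under (deep)copy:

import numpy as np
from megengine import Tensor
from megengine.quantization import QuantMode, create_qparams
from megengine.quantization.fake_quant import FakeQuantize

fq = FakeQuantize(dtype="qint8")  # a QuantDtypeMeta also works here
qparams = create_qparams(QuantMode.SYMMERTIC, "qint8", scale=Tensor(0.05))
x = Tensor(np.random.randn(4).astype("float32"))
y = fq(x, qparams)  # values snapped to the qint8 grid, dtype stays float32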

+2 -0 imperative/python/megengine/quantization/internal_fake_quant.py

@@ -16,4 +16,6 @@ from ..autodiff import Function
from .fake_quant import _FakeQuantize
from .observer import MinMaxObserver
from .qconfig import QConfig
from .utils import QParams




+84 -80 imperative/python/megengine/quantization/observer.py

@@ -8,51 +8,51 @@
import math
from abc import abstractmethod
from copy import deepcopy
from typing import Union


import numpy as np


from .. import functional as F
from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
from ..distributed import WORLD, get_rank, is_distributed
from ..functional.distributed import all_reduce_max, all_reduce_min
from ..logger import get_logger
from ..module import Module
from ..tensor import Tensor
from .utils import QuantMode, get_qparam_dict
from .utils import QParams, QParamsModuleMixin, QuantMode, create_qparams


logger = get_logger(__name__)


class Observer(Module):

class Observer(Module, QParamsModuleMixin):
r""" r"""
A base class for Observer Module. A base class for Observer Module.


:param dtype: a string indicating to collect scale and zero_point of which dtype. :param dtype: a string indicating to collect scale and zero_point of which dtype.
:param narrow_range: whether the absolute value of ``qmin`` is the same as ``qmax``,
instead of 1 greater. Usually True for weight and False for activation.
""" """


def __init__(self, dtype: str, narrow_range: bool = False, **kwargs):
def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
super().__init__()
if dtype not in _metadata_dict.keys():
raise ValueError(
"unknown dtype: {}, only support {}".format(
dtype, _metadata_dict.keys()
if isinstance(dtype, str):
if not dtype in _builtin_quant_dtypes:
raise ValueError(
"unknown dtype: {}, only support {}".format(
dtype, _builtin_quant_dtypes.keys()
)
)
dtype = _builtin_quant_dtypes[dtype]
if "narrow_range" in kwargs:
del kwargs["narrow_range"]
logger.warning(
"FakeQuantize currently has no narrow_range param "
"so it is ignored here",
exc_info=DeprecationWarning,
)
self.dtype = dtype
self.narrow_range = narrow_range
self.qmin = (
-_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin
)
self.qmax = _metadata_dict[dtype].qmax
self.qmin = dtype.qmin
self.qmax = dtype.qmax
self.enabled = True


def get_dtype(self):
q_dict = self.get_qparams()
numpy_scale = None if "scale" not in q_dict else q_dict["scale"].numpy()
numpy_zero_point = (
None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()
)
return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)

def enable(self):
self.enabled = True


@@ -70,21 +70,16 @@ class Observer(Module):
def forward(self, x):
pass


@abstractmethod
def get_qparams(self, **kwargs):
pass



class MinMaxObserver(Observer):
def __init__(
self,
mode=QuantMode.SYMMERTIC,
eps=0.00001,
dtype="qint8",
narrow_range: bool = False,
mode: QuantMode = QuantMode.SYMMERTIC,
eps: float = 0.00001,
dtype: Union[str, QuantDtypeMeta] = "qint8",
**kwargs
):
super().__init__(dtype, narrow_range, **kwargs)
super().__init__(dtype, **kwargs)
self.mode = mode
self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32)
self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32)
@@ -93,26 +88,22 @@ class MinMaxObserver(Observer):
def _calculate_qparams(self, inp_min_val, inp_max_val):
min_val = F.minimum(0.0, inp_min_val)
max_val = F.maximum(0.0, inp_max_val)
q_dict = get_qparam_dict(self.mode)
q_dict["min_val"] = inp_min_val
q_dict["max_val"] = inp_max_val
q_dict["enable_observer"] = self.enable
if self.mode == QuantMode.SYMMERTIC:
symmetric_max_vals = F.maximum(-min_val, max_val)
# use maximum to avoid scale too small at the beginning
q_dict["scale"] = F.maximum(
scale = F.maximum(
symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
)
# zero_point = self.zero_point
zero_point = None
else:
# use maximum to avoid scale too small at the beginning
q_dict["scale"] = F.maximum(
scale = F.maximum(
(max_val - min_val) / (self.qmax - self.qmin), self.scale_limit
)
# calculate zero_point
q_dict["zero_point"] = self.qmin - F.round(min_val / q_dict["scale"])
zero_point = self.qmin - F.round((min_val / scale))


return q_dict
return create_qparams(self.mode, self.dtype, scale=scale, zero_point=zero_point)


def get_qparams(self):
return self._calculate_qparams(self.min_val, self.max_val)
@@ -122,8 +113,8 @@ class MinMaxObserver(Observer):
# stop gradient
x = x_orig.detach()
# find max and min
self.min_val._reset(F.minimum(self.min_val, x.min()))
self.max_val._reset(F.maximum(self.max_val, x.max()))
self.min_val[...] = F.minimum(self.min_val, x.min())
self.max_val[...] = F.maximum(self.max_val, x.max())
return x_orig




@@ -137,42 +128,43 @@ class SyncMinMaxObserver(MinMaxObserver):
else:
min_x = x.min()
max_x = x.max()
self.min_val._reset(F.minimum(self.min_val, min_x))
self.max_val._reset(F.maximum(self.max_val, max_x))
self.min_val[...] = F.minimum(self.min_val, min_x)
self.max_val[...] = F.maximum(self.max_val, max_x)
return x_orig




class ExponentialMovingAverageObserver(MinMaxObserver):
def __init__(
self,
momentum=0.9,
mode=QuantMode.SYMMERTIC,
eps=0.00001,
dtype="qint8",
narrow_range: bool = False,
momentum: float = 0.9,
mode: QuantMode = QuantMode.SYMMERTIC,
eps: float = 0.00001,
dtype: Union[str, QuantDtypeMeta] = "qint8",
**kwargs
):
super().__init__(mode, eps, dtype, narrow_range, **kwargs)
super().__init__(mode, eps, dtype, **kwargs)
self.momentum = Tensor(momentum, dtype="float32")
# used to avoid if-clauses in the first forward which is not supported
# in trace mode.
self.runtime_momentum = Tensor(0.0) self.runtime_momentum = Tensor(0.0)


def set_momentum(self, momentum):
self.momentum = Tenosr(momentum, dtype="float32")
self.momentum = Tensor(momentum, dtype="float32")


def forward(self, x_orig):
if self.enabled:
# stop gradient
x = x_orig.detach()
# Exponential Moving Average
self.min_val._reset(
self.min_val[...] = (
self.min_val * self.runtime_momentum
+ (1 - self.runtime_momentum) * x.min()
)
self.max_val._reset(
self.max_val[...] = (
self.max_val * self.runtime_momentum
+ (1 - self.runtime_momentum) * x.max()
)
self.runtime_momentum = self.momentum
self.runtime_momentum[...] = self.momentum


return x_orig


@@ -187,33 +179,34 @@ class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
else:
min_x = x.min()
max_x = x.max()
self.min_val._reset(
self.min_val[...] = (
self.min_val * self.runtime_momentum
+ (1 - self.runtime_momentum) * min_x
)
self.max_val._reset(
self.max_val[...] = (
self.max_val * self.runtime_momentum
+ (1 - self.runtime_momentum) * max_x
)
self.runtime_momentum = self.momentum
self.runtime_momentum[...] = self.momentum
return x_orig




class HistogramObserver(MinMaxObserver):
def __init__(
self,
bins=2048,
upsample_rate=128,
mode=QuantMode.SYMMERTIC,
eps=0.00001,
dtype="qint8",
narrow_range: bool = False,
bins: int = 2048,
upsample_rate: int = 128,
mode: QuantMode = QuantMode.SYMMERTIC,
eps: float = 0.00001,
dtype: Union[str, QuantDtypeMeta] = "qint8",
**kwargs
):
super().__init__(mode, eps, dtype, narrow_range, **kwargs)
super().__init__(mode, eps, dtype, **kwargs)
self.bins = bins
self.upsample_rate = upsample_rate
self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
self.dst_nbins = (
_builtin_quant_dtypes[dtype].qmax - _builtin_quant_dtypes[dtype].qmin + 1
)
self.histogram = Tensor([-1] + [0.0] * (bins - 1), dtype="float32")


def _non_linear_param_search(self):
@@ -450,34 +443,45 @@ class HistogramObserver(MinMaxObserver):


class PassiveObserver(Observer):
r"""
This class can be set :attr:`scale` derectly.
An Observer that supports setting :attr:`scale` directly.
"""


def __init__(self, dtype: str, narrow_range: bool = False, **kwargs):
super().__init__(dtype, narrow_range, **kwargs)
self.q_dict = None
def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
super().__init__(dtype, **kwargs)
self.qparams = None
self.orig_scale = None


@property
def scale(self):
return self.q_dict["scale"]
return self.qparams.scale


@scale.setter
def scale(self, value):
assert value > 0
self.q_dict["scale"][...] = Tensor(value)
def scale(self, value: np.ndarray):
assert np.all(value > 0)
self.qparams.scale[...] = Tensor(value)


def get_qparams(self):
return self.q_dict
return self.qparams


def set_qparams(self, q_dict):
self.q_dict = deepcopy(q_dict)
if "scale" not in q_dict or q_dict["scale"] is None:
def set_qparams(self, qparams: QParams):
"""
:param qparams: used to set initial scale.
"""
self.qparams = deepcopy(qparams)
if qparams.scale is None:
raise AssertionError("Can not get an initialized scale") raise AssertionError("Can not get an initialized scale")
self.orig_scale = q_dict["scale"].numpy()
if qparams.dtype_meta is None:
qparams.dtype_meta = self.dtype
else:
assert (
qparams.dtype_meta is self.dtype
), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
qparams.dtype_meta, self.dtype
)
self.orig_scale = qparams.scale.numpy()


def forward(self, x):
r"""
Just return input because :attr:`q_dict` is set by :func:`~.apply_easy_quant`.
Just return input because :attr:`qparams` is set by :func:`~.apply_easy_quant`.
"""
return x
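Observer usage after the refactor, sketched with arbitrary data (narrow_range is now expressed through the dtype name):

import numpy as np
from megengine import Tensor
from megengine.quantization.observer import MinMaxObserver

obs = MinMaxObserver(dtype="qint8_narrow")  # replaces dtype="qint8", narrow_range=True
obs(Tensor(np.random.randn(128).astype("float32")))
qparams = obs.get_qparams()           # a QParams carrying mode, dtype_meta, scale, zero_point
np_dtype = obs.get_quantized_dtype()  # numpy dtype, provided by QParamsModuleMixin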

+44 -61 imperative/python/megengine/quantization/qconfig.py

@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from collections import namedtuple
from functools import partial


from ..module import Module
@@ -19,7 +20,13 @@ from .observer import (
)




class QConfig:
# use namedtuple to make class immutable, comparable and easy to print
class QConfig(
namedtuple(
"QConfig",
["weight_observer", "act_observer", "weight_fake_quant", "act_fake_quant"],
)
):
r""" r"""
A config class indicating how to do quantize toward :class:`~.QATModule`'s A config class indicating how to do quantize toward :class:`~.QATModule`'s
``activation`` and ``weight``. See :meth:`~.QATModule.set_qconfig` for detail usage. ``activation`` and ``weight``. See :meth:`~.QATModule.set_qconfig` for detail usage.
@@ -37,90 +44,66 @@ class QConfig:


# Default EMA QConfig for QAT.
ema_fakequant_qconfig = QConfig(
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False),
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8"),
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)


Each parameter is a ``class`` rather than an instance. And we recommend using ``functools.partial``
to add initialization parameters of the ``class``, so that you don't need to provide parameters in
:meth:`~.QATModule.set_qconfig`.


Usually we set ``narrow_range`` of weight related paramters to ``True`` and of activation related
parameters to ``False``. For the result of multiplication and addition as ``a * b + c * d``, if
four variables are all -128 of dtype ``qint8``, then the result will be ``2^15`` and cause overflow.
Weights are commonly calculated in this way, so needed to narrow the range.
Usually we choose narrow version dtype (like ``qint8_narrow``) for weight related
paramters and normal version for activation related ones. For the result of
multiplication and addition as ``a * b + c * d``, if four variables are all -128 of
dtype ``qint8``, then the result will be ``2^15`` and cause overflow.
Weights are commonly calculated in this way, so need to narrow qmin to -127.
""" """


def __init__(
self, weight_observer, act_observer, weight_fake_quant, act_fake_quant
):
def __new__(cls, weight_observer, act_observer, weight_fake_quant, act_fake_quant):
if isinstance(act_observer, Module) or isinstance(weight_observer, Module):
raise ValueError(
"QConfig must not receive observer instance, please pass observer"
" class generator using `partial(Observer, ...)` instead. Use"
" partial(MyObserver, x=1) to override arguments to constructor if needed"
)
self.weight_observer = weight_observer
self.act_observer = act_observer
self.weight_fake_quant = weight_fake_quant
self.act_fake_quant = act_fake_quant

def __eq__(self, other):
def eq(a, b):
if isinstance(a, partial) and isinstance(b, partial):
return all(
[a.func == b.func, a.args == b.args, a.keywords == b.keywords]
)
else:
return a == b

return (
eq(self.weight_observer, other.weight_observer)
and eq(self.act_observer, other.act_observer)
and eq(self.weight_fake_quant, other.weight_fake_quant)
and eq(self.act_fake_quant, other.act_fake_quant)
return super().__new__(
cls, weight_observer, act_observer, weight_fake_quant, act_fake_quant
)




min_max_fakequant_qconfig = QConfig(
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
act_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=False),
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
act_observer=partial(MinMaxObserver, dtype="qint8"),
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)


ema_fakequant_qconfig = QConfig(
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
act_observer=partial(
ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False
),
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8"),
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)


sync_ema_fakequant_qconfig = QConfig(
weight_observer=partial(SyncMinMaxObserver, dtype="qint8", narrow_range=True),
act_observer=partial(
SyncExponentialMovingAverageObserver, dtype="qint8", narrow_range=False
),
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
weight_observer=partial(SyncMinMaxObserver, dtype="qint8_narrow"),
act_observer=partial(SyncExponentialMovingAverageObserver, dtype="qint8"),
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)


ema_lowbit_fakequant_qconfig = QConfig(
weight_observer=partial(MinMaxObserver, dtype="qint4", narrow_range=False),
act_observer=partial(
ExponentialMovingAverageObserver, dtype="qint4", narrow_range=False
),
weight_fake_quant=partial(FakeQuantize, dtype="qint4", narrow_range=False),
act_fake_quant=partial(FakeQuantize, dtype="qint4", narrow_range=False),
weight_observer=partial(MinMaxObserver, dtype="qint4"),
act_observer=partial(ExponentialMovingAverageObserver, dtype="qint4"),
weight_fake_quant=partial(FakeQuantize, dtype="qint4"),
act_fake_quant=partial(FakeQuantize, dtype="qint4"),
)


calibration_qconfig = QConfig(
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
act_observer=partial(HistogramObserver, dtype="qint8", narrow_range=False),
weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
act_observer=partial(HistogramObserver, dtype="qint8"),
weight_fake_quant=None,
act_fake_quant=None,
)
@@ -128,15 +111,15 @@ calibration_qconfig = QConfig(
tqt_qconfig = QConfig(
weight_observer=None,
act_observer=None,
weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False),
weight_fake_quant=partial(TQT, dtype="qint8_narrow"),
act_fake_quant=partial(TQT, dtype="qint8"),
)


passive_qconfig = QConfig(
weight_observer=partial(PassiveObserver, dtype="qint8", narrow_range=True),
act_observer=partial(PassiveObserver, dtype="qint8", narrow_range=False),
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
weight_observer=partial(PassiveObserver, dtype="qint8_narrow"),
act_observer=partial(PassiveObserver, dtype="qint8"),
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)


easyquant_qconfig = passive_qconfig
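A custom config is built exactly like the builtins above; note that with the namedtuple base, equality is now per-field (two separately constructed partials compare by identity):

from functools import partial
from megengine.quantization.fake_quant import FakeQuantize
from megengine.quantization.observer import ExponentialMovingAverageObserver, MinMaxObserver
from megengine.quantization.qconfig import QConfig

my_qconfig = QConfig(
    weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
    act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8"),
    weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
    act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)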

+12 -9 imperative/python/megengine/quantization/quantize.py

@@ -18,6 +18,7 @@ from ..module import qat as QAT
from ..module import quantized as Quantized
from ..module.qat import QATModule
from ..module.quantized import QuantizedModule
from ..tensor import Tensor
from .qconfig import QConfig, ema_fakequant_qconfig




@@ -147,10 +148,10 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True):
if not inplace:
module = deepcopy(module)


def safe_call(func, q_dict):
def safe_call(func, qparams):
inst = func() if func is not None else None
if inst is not None and getattr(inst, "set_qparams", None) is not None:
inst.set_qparams(q_dict)
inst.set_qparams(qparams)
return inst


def is_qat(mod: Module):
@@ -158,13 +159,13 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True):


for m in list(module._flatten(predicate=is_qat)):
if m.with_weight:
weight_q_dict = m.get_weight_qparams()
m.weight_observer = safe_call(qconfig.weight_observer, weight_q_dict)
m.weight_fake_quant = safe_call(qconfig.weight_fake_quant, weight_q_dict)
weight_params = m.get_weight_qparams()
m.weight_observer = safe_call(qconfig.weight_observer, weight_params)
m.weight_fake_quant = safe_call(qconfig.weight_fake_quant, weight_params)
if m.with_act:
act_q_dict = m.get_activation_qparams()
m.act_observer = safe_call(qconfig.act_observer, act_q_dict)
m.act_fake_quant = safe_call(qconfig.act_fake_quant, act_q_dict)
act_params = m.get_activation_qparams()
m.act_observer = safe_call(qconfig.act_observer, act_params)
m.act_fake_quant = safe_call(qconfig.act_fake_quant, act_params)


return module


@@ -202,7 +203,9 @@ def hook_qat_module(module: Module, func: Callable):
return hooks




def apply_easy_quant(module, data, start=0.8, stop=1.2, num=40):
def apply_easy_quant(
module: Module, data: Tensor, start: float = 0.8, stop: float = 1.2, num: int = 40
):
r""" r"""
Implementation of ``EasyQuant``: https://arxiv.org/pdf/2006.16669. Implementation of ``EasyQuant``: https://arxiv.org/pdf/2006.16669.
Search for optimal scales. Search for optimal scales.
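A typical calibration flow with the new signature might look like this (float_net and calib_batch are placeholders for your module and data):

from megengine.quantization import apply_easy_quant, quantize_qat
from megengine.quantization.qconfig import easyquant_qconfig

qat_net = quantize_qat(float_net, qconfig=easyquant_qconfig)
apply_easy_quant(qat_net, calib_batch, start=0.8, stop=1.2, num=40)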


+111 -31 imperative/python/megengine/quantization/utils.py

@@ -5,9 +5,10 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
from enum import Enum
from functools import partial, update_wrapper, wraps
from typing import Dict
from typing import Union


import numpy as np


@@ -15,7 +16,11 @@ from .. import functional as F
from ..autodiff import Function
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.tensor.dtype import _metadata_dict
from ..core.tensor.dtype import (
QuantDtypeMeta,
_builtin_quant_dtypes,
create_quantized_dtype,
)
from ..tensor import Tensor




@@ -61,37 +66,100 @@ class QuantMode(Enum):
ASYMMERTIC = 2




qparam_dict = {
QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None},
QuantMode.ASYMMERTIC: {
"mode": QuantMode.ASYMMERTIC,
"scale": None,
"zero_point": None,
},
class QParams:
"""
To standardize FakeQuant, Observer and Tensor's qparams format. If custom
qparams is needed, inherit this class and add custom ``__slots__``.
"""

__slots__ = "mode", "dtype_meta", "scale", "zero_point"

def __init__(
self,
mode: QuantMode,
dtype_meta: QuantDtypeMeta,
scale: Tensor,
zero_point: Tensor,
):
self.mode = mode
self.dtype_meta = dtype_meta
self.scale = scale
self.zero_point = zero_point

def update(self, qparams: "QParams"):
for key in self.__slots__:
setattr(self, key, getattr(qparams, key))

def __eq__(self, other):
if len(self.__slots__) != len(other.__slots__):
return False
for key in self.__slots__:
if not hasattr(other, key) or getattr(self, key) != getattr(other, key):
return False
return True

def __repr__(self):
content = ", ".join(
["{}={}".format(key, getattr(self, key)) for key in self.__slots__]
)
return "QParams({})".format(content)


class QParamsModuleMixin(abc.ABC):
def get_quantized_dtype(self):
qparams = self.get_qparams()
dtype = qparams.dtype_meta
scale = float(qparams.scale.numpy()) if qparams.scale is not None else None
zero_point = (
int(qparams.zero_point.numpy()) if qparams.zero_point is not None else None
)
return create_quantized_dtype(dtype, scale, zero_point)

@abc.abstractmethod
def get_qparams(self) -> QParams:
pass


_builtin_qparams = {
QuantMode.SYMMERTIC: partial(QParams, mode=QuantMode.SYMMERTIC),
QuantMode.ASYMMERTIC: partial(QParams, mode=QuantMode.ASYMMERTIC),
}




def get_qparam_dict(mode: QuantMode):
def create_qparams(
mode: QuantMode = QuantMode.SYMMERTIC,
dtype_meta: Union[str, QuantDtypeMeta] = None,
scale: Tensor = None,
zero_point: Tensor = None,
):
""" """
Return the quantization parameters dictionary according to the mode.
Return :class:`~.QParams` according to the mode.
""" """
return qparam_dict.get(mode, None)
if isinstance(dtype_meta, str):
dtype_meta = _builtin_quant_dtypes[dtype_meta]
if mode is None:
return QParams(mode, dtype_meta, scale, zero_point)
assert isinstance(mode, QuantMode)
return _builtin_qparams[mode](
dtype_meta=dtype_meta, scale=scale, zero_point=zero_point
)
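For reference, a short sketch of the factory in use; it assumes "qint8" and "quint8" are both registered in ``_builtin_quant_dtypes``, which the tests in this diff suggest:

import megengine as mge
from megengine.quantization.utils import QuantMode, create_qparams

# string dtype names are looked up in _builtin_quant_dtypes
sym = create_qparams(QuantMode.SYMMERTIC, "qint8", scale=mge.Tensor(0.05))
asym = create_qparams(
    QuantMode.ASYMMERTIC,
    "quint8",
    scale=mge.Tensor(0.05),
    zero_point=mge.Tensor(128.0),
)
print(sym)  # QParams(mode=QuantMode.SYMMERTIC, dtype_meta=..., scale=..., zero_point=None)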




def fake_quant_tensor(inp: Tensor, qmin: int, qmax: int, q_dict: Dict) -> Tensor:
def fake_quant_tensor(inp: Tensor, qparams: QParams) -> Tensor:
""" """
Apply fake quantization to the inp tensor. Apply fake quantization to the inp tensor.


:param inp: the input tensor which need to be faked. :param inp: the input tensor which need to be faked.
:param qmin: the minimum value which the integer limit to.
:param qmax: the maximum value which the integer limit to.
:param q_dict: the quantization parameter dict.
    :param qparams: the QParams providing mode, qmin, qmax, scale and zero_point.


""" """
scale = q_dict["scale"]
zero_point = Tensor([0.0], dtype=np.float32)
if q_dict["mode"] == QuantMode.ASYMMERTIC:
zero_point = q_dict["zero_point"]
scale = qparams.scale
if qparams.mode == QuantMode.ASYMMERTIC:
zero_point = qparams.zero_point
else:
zero_point = Tensor([0.0], dtype=np.float32)
qmin = qparams.dtype_meta.qmin
qmax = qparams.dtype_meta.qmax


    op = builtin.FakeQuant(qmin=qmin, qmax=qmax)
    return apply(op, inp, scale, zero_point)[0]
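Putting the pieces together, a minimal fake-quantization round trip against the new signature might look like this sketch (the scale value is arbitrary):

import numpy as np

import megengine as mge
from megengine.quantization.utils import QuantMode, create_qparams, fake_quant_tensor

x = mge.Tensor(np.random.uniform(-1.0, 1.0, (2, 3)).astype(np.float32))
qparams = create_qparams(QuantMode.SYMMERTIC, "qint8", scale=mge.Tensor(2.0 / 255))
y = fake_quant_tensor(x, qparams)  # still float32, but snapped onto the qint8 grid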
@@ -104,22 +172,34 @@ def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:


    :param bias: the bias tensor which needs to be faked.
    :param inp: the input tensor which contains the quantization parameters.
:param qmax: the weight tensor which contain the quantization parameters.
    :param w_qat: the weight tensor which contains the quantization parameters.


    .. warning::
        Only works for the symmetric quantization method for now.


""" """
b_qat = bias b_qat = bias
if hasattr(inp, "q_dict") and b_qat is not None:
if inp.q_dict["scale"] is not None and w_qat.q_dict["scale"] is not None:
# use the same mode with weight.
b_dict = get_qparam_dict(w_qat.q_dict["mode"])
b_dict["scale"] = inp.q_dict["scale"] * w_qat.q_dict["scale"]
# TODO: add zero_point for ASYMMERTIC mode.
qmax = _metadata_dict["qint32"].qmax
qmin = _metadata_dict["qint32"].qmin
b_qat = fake_quant_tensor(b_qat, qmin, qmax, b_dict)
b_qat.q_dict.update(b_dict)
if (
getattr(inp, "qparams", None) is not None
and getattr(w_qat, "qparams", None) is not None
and bias is not None
):
inp_params = inp.qparams
w_params = w_qat.qparams
if inp_params.scale is not None and w_params.scale is not None:
assert inp_params.mode == w_params.mode, "incompatible QuantMode"
# TODO: support quint8 dtype.
assert (
inp_params.dtype_meta.np_dtype_str == "int8"
and w_params.dtype_meta.np_dtype_str == "int8"
), "fake_quant_bias only support int8 like dtype now"


# use the same mode with weight.
# TODO: avoid hardcode
b_dtype = _builtin_quant_dtypes["qint32"]
b_param = create_qparams(
w_params.mode, b_dtype, scale=inp_params.scale * w_params.scale
)
b_qat = fake_quant_tensor(bias, b_param)
b_qat.qparams.update(b_param)
    return b_qat
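The invariant worth calling out: the bias scale is derived rather than observed, i.e. scale_bias = scale_input * scale_weight, faked against qint32's range. A hedged usage sketch, with all scale values arbitrary:

import numpy as np

import megengine as mge
from megengine.quantization.utils import QuantMode, create_qparams, fake_quant_bias

inp = mge.Tensor(np.random.randn(4).astype(np.float32))
inp.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", mge.Tensor(0.1)))

w_qat = mge.Tensor(np.random.randn(4).astype(np.float32))
w_qat.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", mge.Tensor(0.02)))

bias = mge.Tensor(np.random.randn(4).astype(np.float32))
b_qat = fake_quant_bias(bias, inp, w_qat)
# b_qat.qparams.scale == 0.1 * 0.02, quantized against qint32's qmin/qmax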

+ 28
- 9
imperative/python/megengine/tensor.py View File

@@ -22,6 +22,8 @@ from .logger import get_logger
from .utils.deprecation import deprecated
from .utils.naming import auto_naming


logger = get_logger(__name__)



class Tensor(_Tensor, ArrayMethodMixin):
    r"""
@@ -30,7 +32,7 @@ class Tensor(_Tensor, ArrayMethodMixin):


    grad = None
    dmap_callback = None
_q_dict = None
_qparams = None


    def __new__(
        cls, data, dtype=None, device=None, is_const=False, no_cache=False, name=None
@@ -50,7 +52,7 @@ class Tensor(_Tensor, ArrayMethodMixin):


        if isinstance(data, _Tensor):
            if dtype is not None:
                get_logger().warning(
                logger.warning(
                    "dtype does not work when creating a new Tensor with another Tensor"
                )
            obj = _Tensor.__new__(cls, data)
@@ -101,10 +103,12 @@ class Tensor(_Tensor, ArrayMethodMixin):
        return super().dtype


    @property
def q_dict(self):
if self._q_dict is None:
self._q_dict = {"mode": None, "scale": None, "zero_point": None}
return self._q_dict
def qparams(self):
from .quantization.utils import create_qparams # pylint: disable=all

if self._qparams is None:
self._qparams = create_qparams()
return self._qparams


    def numpy(self) -> np.ndarray:
        r"""
@@ -185,14 +189,29 @@ class Tensor(_Tensor, ArrayMethodMixin):
    def __getstate__(self):
        r""" __getstate__ will be called for pickle serialization or deep copy
        """

        state = {
"qdict": self.q_dict,
"numpy": self.numpy(),
"dtype": self.dtype,
"device": self.device.logical_name,
        }
if self._qparams is not None:
state["qparams"] = self._qparams
        return state


    def __setstate__(self, state):
self._q_dict = state.pop("qdict")
from .quantization.utils import create_qparams # pylint: disable=all

if "qdict" in state:
qparams = state.pop("qdict")
            logger.warning(
                "Tensor's 'qdict' state is deprecated. Use 'qparams' instead"
            )
elif "qparams" in state:
qparams = state.pop("qparams")
else:
qparams = None
self._reset(Tensor(state.pop("numpy"), state.pop("dtype"), state.pop("device")))
self._qparams = qparams




tensor = Tensor
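The net effect on serialization: ``qparams`` now round-trips through pickle, and old pickles carrying ``qdict`` still load, with a warning. A small sketch:

import pickle

import numpy as np

import megengine as mge

a = mge.Tensor(np.zeros((2, 2), dtype=np.float32))
a.qparams.scale = mge.Tensor(0.5)

buf = pickle.dumps(a)  # __getstate__ stores "qparams" only when it was actually set
b = pickle.loads(buf)  # __setstate__ also accepts the legacy "qdict" key
assert float(b.qparams.scale.numpy()) == 0.5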


+ 3
- 3
imperative/python/test/unit/core/test_dtype_quant.py View File

@@ -14,7 +14,7 @@ import pytest
import megengine.core.tensor.megbrain_graph as G
from megengine.core.ops import builtin as ops
from megengine.core.tensor.dtype import (
_metadata_dict,
_builtin_quant_dtypes,
    convert_from_qint4,
    convert_from_qint8,
    convert_from_quint4,
@@ -76,10 +76,10 @@ def _get_compiled_result(inp, dtype, shape, device, calc_func=None):




def _check_result_attr(oup, dtype, dtype_str, is_unsigned=True):
metadata = _metadata_dict[dtype_str]
metadata = _builtin_quant_dtypes[dtype_str]
    assert "mgb_dtype" in oup.dtype.metadata
    assert is_quantize(oup.dtype)
np.testing.assert_equal(oup.dtype.metadata["mgb_dtype"]["name"], metadata.name)
np.testing.assert_equal(oup.dtype.metadata["mgb_dtype"]["name"], metadata.cname)
    np.testing.assert_allclose(get_scale(oup.dtype), get_scale(dtype))
    if is_unsigned:
        np.testing.assert_equal(get_zero_point(oup.dtype), get_zero_point(dtype))


+ 3
- 3
imperative/python/test/unit/core/test_serialization.py View File

@@ -65,9 +65,9 @@ def test_tensor_serialization():


    with TemporaryFile() as f:
        a = Tensor(0)
a.q_dict["scale"] = Tensor(1.0)
a.qparams.scale = Tensor(1.0)
        pickle.dump(a, f)
        f.seek(0)
        b = pickle.load(f)
assert isinstance(b.q_dict["scale"], Tensor)
np.testing.assert_equal(b.q_dict["scale"].numpy(), 1.0)
assert isinstance(b.qparams.scale, Tensor)
np.testing.assert_equal(b.qparams.scale.numpy(), 1.0)

+ 22
- 0
imperative/python/test/unit/core/test_tensor_wrapper.py View File

@@ -6,6 +6,8 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy

import numpy as np


from megengine.core.tensor.dtype import get_scale, get_zero_point, qint8, quint8
@@ -86,3 +88,23 @@ def test_as_type():
    b = a.astype(quint8(0.3, 128))
    np.testing.assert_almost_equal(get_scale(b.dtype), 0.3)
    np.testing.assert_equal(get_zero_point(b.dtype), 128)


def test_qparams():
x = Tensor(1)
assert x.qparams.scale is None
x.qparams.scale = Tensor(1.0)
assert x.qparams.scale.numpy() == 1.0
x2 = copy.copy(x)
assert x.qparams is x2.qparams and x2.qparams.scale.numpy() == 1.0
x3 = copy.deepcopy(x)
assert x.qparams is not x3.qparams and x3.qparams.scale.numpy() == 1.0


def test_name():
x = Tensor(0)
assert x.name == ""
x.name = "x"
assert x.name == "x"
x = Tensor(0, name="x")
assert x.name == "x"

+ 0
- 25
imperative/python/test/unit/functional/test_tensor.py View File

@@ -406,28 +406,3 @@ def test_copy_d2h():
def test_copy_d2d():
    copy_test("gpu0", "gpu1")
    copy_test("gpu0:0", "gpu0:1")


def test_name():
x = tensor(0)
assert x.name == ""
x.name = "x"
assert x.name == "x"
x = tensor(0, name="x")
assert x.name == "x"


def test_q_dict():
x = tensor(1)
assert x.q_dict["scale"] is None
x.q_dict["scale"] = tensor(1.0)

y = tensor(1)
assert y.q_dict["scale"] is None
y.q_dict["scale"] = tensor(2.0)

assert x.q_dict["scale"].numpy() == 1.0
assert y.q_dict["scale"].numpy() == 2.0

z = x + y
assert z.q_dict["scale"] is None

+ 11
- 7
imperative/python/test/unit/quantization/test_fake_quant.py View File

@@ -12,9 +12,15 @@ import pytest
import megengine as mge
from megengine import tensor
from megengine.core.autodiff.grad import Function, Grad
from megengine.core.tensor.dtype import QuantDtypeMeta
from megengine.core.tensor.utils import make_shape_tuple
from megengine.quantization.internal_fake_quant import *
from megengine.quantization.utils import QuantMode, fake_quant_tensor, tqt_forward
from megengine.quantization.utils import (
QuantMode,
create_qparams,
fake_quant_tensor,
tqt_forward,
)




class TQT_numpy:
@@ -111,16 +117,14 @@ def fake_quant_tensor_gt(inp, scale, zero_point, qmin, qmax):
def test_fakequant():
    qmin = -126
    qmax = 129
test_dtype = QuantDtypeMeta("test_qint8", None, "int8", qmin, qmax)


    def run(zero_point, scale):
q_dict = {}
q_dict["mode"] = QuantMode.ASYMMERTIC
q_dict["scale"] = scale
q_dict["zero_point"] = zero_point
qparams = create_qparams(QuantMode.ASYMMERTIC, test_dtype, scale, zero_point)
        inp_data = np.random.uniform(low=-512.0, high=512.0, size=(1, 32, 32, 32))
        inp = tensor(inp_data, dtype=np.float32)
        # test forward
oup = fake_quant_tensor(inp, qmin, qmax, q_dict).numpy()
oup = fake_quant_tensor(inp, qparams).numpy()
        oup_gt = fake_quant_tensor_gt(inp, scale, zero_point, qmin, qmax).numpy()
        assert np.allclose(oup, oup_gt)
        assert oup.shape == oup_gt.shape
@@ -128,7 +132,7 @@ def test_fakequant():
        # test backward
        x = tensor(inp_data, dtype=np.float32)
        grad = Grad().wrt(x, callback=_save_to(x))
y = fake_quant_tensor(x, qmin, qmax, q_dict)
y = fake_quant_tensor(x, qparams)
        grad(y, tensor(F.ones_like(x)))


        x1 = tensor(inp_data, dtype=np.float32)


+ 16
- 10
imperative/python/test/unit/quantization/test_module.py View File

@@ -10,7 +10,13 @@ import megengine.module.qat as QAT
import megengine.module.quantized as Q
from megengine import Parameter, Tensor
from megengine.core.tensor import dtype
from megengine.quantization import FakeQuantize, MinMaxObserver, QConfig
from megengine.quantization import (
FakeQuantize,
MinMaxObserver,
QConfig,
QuantMode,
create_qparams,
)
from megengine.quantization.quantize import (
    disable_fake_quant,
    disable_observer,
@@ -18,10 +24,10 @@ from megengine.quantization.quantize import (
)


min_max_fakequant_qconfig = QConfig(
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
act_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=False),
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
act_observer=partial(MinMaxObserver, dtype="qint8"),
weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)


inp_scale = np.float32(np.random.rand() + 1)
@@ -111,7 +117,7 @@ def test_dequant_stub():


    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x = fake_quant_act(x, inp_scale)
x.q_dict["scale"] = inp_scale
x.qparams.scale = inp_scale


    normal = normal_net(x)
    qat_without_fakequant = qat_from_float(x)
@@ -146,12 +152,12 @@ def test_elemwise(kind):
    x1_scale = np.float32(np.random.rand() + 1)
    x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x1 = fake_quant_act(x1, x1_scale)
x1.q_dict["scale"] = x1_scale
x1.qparams.scale = x1_scale


    x2_scale = np.float32(np.random.rand() + 1)
    x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x2 = fake_quant_act(x2, x2_scale)
x2.q_dict["scale"] = x2_scale
x2.qparams.scale = x2_scale


    x1_int8 = quant(x1, x1_scale)
    x2_int8 = quant(x2, x2_scale)
@@ -187,7 +193,7 @@ def test_linear():


    x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x = fake_quant_act(x, inp_scale)
x.q_dict["scale"] = inp_scale
x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))


    x_int8 = quant(x, inp_scale)


@@ -230,7 +236,7 @@ def test_conv(module):


    x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32"))
    x = fake_quant_act(x, inp_scale)
x.q_dict["scale"] = inp_scale
x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale))


    x_int8 = quant(x, inp_scale)




+ 7
- 6
imperative/python/test/unit/quantization/test_observer.py View File

@@ -6,6 +6,7 @@ import pytest
import megengine as mge
import megengine.distributed as dist
from megengine.distributed.helper import get_device_count_by_fork
from megengine.quantization import QuantMode, create_qparams
from megengine.quantization.observer import (
    ExponentialMovingAverageObserver,
    HistogramObserver,
@@ -56,14 +57,14 @@ def test_histogram_observer():




def test_passive_observer():
q_dict = {"scale": mge.tensor(1.0)}
qparams = create_qparams(QuantMode.SYMMERTIC, "qint8", mge.tensor(1.0))
    m = PassiveObserver("qint8")
m.set_qparams(q_dict)
m.set_qparams(qparams)
    assert m.orig_scale == 1.0
assert m.scale == 1.0
m.scale = 2.0
assert m.scale == 2.0
assert m.get_qparams() == {"scale": mge.tensor(2.0)}
assert m.scale.numpy() == 1.0
assert m.get_qparams().dtype_meta == qparams.dtype_meta
assert m.get_qparams().scale == qparams.scale
assert m.get_qparams() == qparams




@pytest.mark.require_ngpu(2)


+ 3
- 2
imperative/python/test/unit/quantization/test_op.py View File

@@ -6,6 +6,7 @@ import megengine.functional as F
from megengine.core.tensor import dtype
from megengine.distributed.helper import get_device_count_by_fork
from megengine.functional.elemwise import _elemwise_multi_type, _elwise
from megengine.quantization import QuantMode, create_qparams




def quant(x, scale):
@@ -26,13 +27,13 @@ def test_elemwise(kind):
    x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x1_scale = np.float32(np.random.rand() + 1)
    x1 = fake_quant(x1, x1_scale)
x1.q_dict["scale"] = x1_scale
x1.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", x1_scale))
    x1_int8 = quant(x1, x1_scale)


    x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))
    x2_scale = np.float32(np.random.rand() + 1)
    x2 = fake_quant(x2, x2_scale)
x2.q_dict["scale"] = x2_scale
x2.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", x2_scale))
    x2_int8 = quant(x2, x2_scale)


    output_scale = np.float32(np.random.rand() + 1)


+ 0
- 14
imperative/python/test/unit/quantization/test_qconfig.py View File

@@ -1,14 +0,0 @@
from functools import partial

from megengine.quantization import QConfig, tqt_qconfig
from megengine.quantization.fake_quant import TQT


def test_equal():
qconfig = QConfig(
weight_observer=None,
act_observer=None,
weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False),
)
assert qconfig == tqt_qconfig

+ 15
- 15
imperative/python/test/unit/quantization/test_quantize.py View File

@@ -33,7 +33,7 @@ from megengine.quantization.quantize import (
)




class Net(Float.Module):
class FloatNet(Float.Module):
    def __init__(self):
        super().__init__()
        self.quant = Float.QuantStub()
@@ -113,25 +113,25 @@ def test_reset_qconfig():
def test_enable_and_disable_observer():
    net = init_qat_net()
    enable_observer(net)
assert net.quant.act_observer.enabled == True
assert net.linear.weight_observer.enabled == True
assert net.linear.act_observer.enabled == True
assert net.quant.act_observer.enabled is True
assert net.linear.weight_observer.enabled is True
assert net.linear.act_observer.enabled is True
    disable_observer(net)
assert net.quant.act_observer.enabled == False
assert net.linear.weight_observer.enabled == False
assert net.linear.act_observer.enabled == False
assert net.quant.act_observer.enabled is False
assert net.linear.weight_observer.enabled is False
assert net.linear.act_observer.enabled is False




def test_enable_and_disable_fake_quant():
    net = init_qat_net()
    disable_fake_quant(net)
assert net.quant.act_fake_quant.enabled == False
assert net.linear.weight_fake_quant.enabled == False
assert net.linear.act_fake_quant.enabled == False
assert net.quant.act_fake_quant.enabled is False
assert net.linear.weight_fake_quant.enabled is False
assert net.linear.act_fake_quant.enabled is False
    enable_fake_quant(net)
assert net.quant.act_fake_quant.enabled == True
assert net.linear.weight_fake_quant.enabled == True
assert net.linear.act_fake_quant.enabled == True
assert net.quant.act_fake_quant.enabled is True
assert net.linear.weight_fake_quant.enabled is True
assert net.linear.act_fake_quant.enabled is True




def init_observer(module, data):
@@ -144,7 +144,7 @@ def init_observer(module, data):


def test_enable_and_disable_all():
    x = Tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
net = Net()
net = FloatNet()
    y1 = net(x).numpy()
    net = quantize_qat(net, min_max_fakequant_qconfig)


@@ -162,7 +162,7 @@ def test_enable_and_disable_all():




def test_quantize_qat():
net = Net()
net = FloatNet()
    qat_net = quantize_qat(net, inplace=False, qconfig=min_max_fakequant_qconfig)
    assert isinstance(qat_net.quant, QAT.QuantStub)
    assert isinstance(qat_net.linear, QAT.Linear)

