# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
from typing import Union

from .. import functional as F
from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
from ..logger import get_logger
from ..module import Module
from ..tensor import Parameter
from .utils import (
    QParams,
    QParamsModuleMixin,
    QuantMode,
    create_qparams,
    fake_quant_tensor,
    tqt_forward,
)

logger = get_logger(__name__)


class _FakeQuantize(Module):
    def __init__(
        self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
    ):
        super().__init__()
        if isinstance(dtype, str):
            if dtype not in _builtin_quant_dtypes:
                raise ValueError(
                    "unknown dtype: {}, only support {}".format(
                        dtype, _builtin_quant_dtypes.keys()
                    )
                )
            dtype = _builtin_quant_dtypes[dtype]
        if "narrow_range" in kwargs:
            del kwargs["narrow_range"]
            logger.warning(
                "FakeQuantize currently has no narrow_range param "
                "so it is ignored here",
                exc_info=DeprecationWarning,
            )
        self.dtype = dtype
        self.qmin = dtype.qmin
        self.qmax = dtype.qmax
        self.enabled = enable

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False

    def fake_quant_forward(self, inp, qparams: QParams = None):
        raise NotImplementedError

    def normal_forward(self, inp, qparams: QParams = None):
        return inp

    def forward(self, inp, qparams: QParams = None):
        if self.enabled:
            return self.fake_quant_forward(inp, qparams=qparams)
        else:
            return self.normal_forward(inp, qparams=qparams)
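
# A short sketch of the enable/disable dispatch above (illustrative only; it
# assumes the concrete ``FakeQuantize`` subclass defined later in this module,
# the builtin "qint8" dtype, and hand-written qparams instead of an observer):
#
#     import megengine as mge
#
#     fq = FakeQuantize(dtype="qint8")
#     x = mge.tensor([0.05, -0.3, 1.2])
#     qparams = create_qparams(QuantMode.SYMMERTIC, fq.dtype, scale=mge.tensor(0.1))
#     fq(x, qparams=qparams)   # enabled by default -> fake_quant_forward
#     fq.disable()
#     fq(x, qparams=qparams)   # disabled -> normal_forward, returns x unchanged
#     fq.enable()              # switch fake quantization back on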


class TQT(_FakeQuantize, QParamsModuleMixin):
    r"""
    TQT: Trained Quantization Thresholds for Accurate and Efficient Fixed-Point
    Inference of Deep Neural Networks (https://arxiv.org/abs/1903.08066).

    :param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
        quantization dtype of input.
    :param enable: whether to do ``normal_forward`` or ``fake_quant_forward``.
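
    Example (an illustrative sketch rather than a canonical recipe; it assumes
    the builtin ``"qint8"`` dtype and an observer-provided symmetric scale of
    ``0.1``)::

        import megengine as mge
        from megengine.quantization.fake_quant import TQT
        from megengine.quantization.utils import QuantMode, create_qparams

        tqt = TQT(dtype="qint8")
        # seed the trainable threshold from an observer's symmetric scale;
        # internally the scale is kept in the log2 domain
        tqt.set_qparams(
            create_qparams(QuantMode.SYMMERTIC, tqt.dtype, scale=mge.tensor(0.1))
        )
        x = mge.tensor([0.05, -0.3, 1.2, -2.0])
        y = tqt(x)  # fake-quantized output with a gradient w.r.t. ``tqt.scale``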
- """
-
- def __init__(
- self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
- ):
- super().__init__(dtype, enable, **kwargs)
- self.scale = Parameter(0.0, dtype="float32")
-
- def fake_quant_forward(self, inp, qparams: QParams = None):
- # when enable, TQT will do fakequant forward, finetune the scale
- return tqt_forward(self.qmin, self.qmax, inp, self.scale)
-
- def set_qparams(self, qparams: QParams):
- assert (
- qparams.mode == QuantMode.SYMMERTIC
- ), "only symmetric quantization is supported by TQT"
- if qparams.scale is None:
- raise AssertionError("Can not get an initialized scale")
- self.scale[...] = F.log(qparams.scale) / math.log(2)
-
- def get_qparams(self):
- return create_qparams(QuantMode.SYMMERTIC, self.dtype, scale=2 ** self.scale)


class FakeQuantize(_FakeQuantize):
    r"""
    A module that fake-quantizes its input (quantize then dequantize) according
    to the observer's scale and zero_point.

    :param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
        quantization dtype of input.
    :param enable: whether to do ``normal_forward`` or ``fake_quant_forward``.
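
    Example (an illustrative sketch, not a canonical recipe; it assumes the
    builtin ``"qint8"`` dtype and hand-written qparams in place of an observer)::

        import megengine as mge
        from megengine.quantization.fake_quant import FakeQuantize
        from megengine.quantization.utils import QuantMode, create_qparams

        fq = FakeQuantize(dtype="qint8")
        # the dtype_meta carried by qparams must be the same object as fq.dtype
        qparams = create_qparams(QuantMode.SYMMERTIC, fq.dtype, scale=mge.tensor(0.1))
        y = fq(mge.tensor([0.05, -0.3, 1.2]), qparams=qparams)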
- """
-
- def fake_quant_forward(self, inp, qparams: QParams = None):
- assert (
- qparams.dtype_meta is self.dtype
- ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
- qparams.dtype_meta, self.dtype
- )
- return fake_quant_tensor(inp, qparams)