# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
from typing import Union

from .. import functional as F
from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
from ..logger import get_logger
from ..module import Module
from ..tensor import Parameter, Tensor
from .utils import (
    LSQParams,
    QParams,
    QParamsModuleMixin,
    QuantMode,
    create_qparams,
    fake_quant_tensor,
    lsq_forward,
    tqt_forward,
)

logger = get_logger(__name__)


class _FakeQuantize(Module):
    def __init__(
        self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
    ):
        super().__init__()
        if isinstance(dtype, str):
            if dtype not in _builtin_quant_dtypes:
                raise ValueError(
                    "unknown dtype: {}, only support {}".format(
                        dtype, _builtin_quant_dtypes.keys()
                    )
                )
            dtype = _builtin_quant_dtypes[dtype]
        if "narrow_range" in kwargs:
            del kwargs["narrow_range"]
            logger.warning(
                "FakeQuantize currently has no narrow_range param "
                "so it is ignored here",
                exc_info=DeprecationWarning,
            )
        self.dtype = dtype
        self.qmin = dtype.qmin
        self.qmax = dtype.qmax
        self.enabled = enable

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False

    def fake_quant_forward(self, inp, qparams: QParams = None):
        raise NotImplementedError

    def normal_forward(self, inp, qparams: QParams = None):
        return inp

    def forward(self, inp, qparams: QParams = None):
        if self.enabled:
            return self.fake_quant_forward(inp, qparams=qparams)
        else:
            return self.normal_forward(inp, qparams=qparams)


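# Note on the toggle above (editor's comment): ``forward`` dispatches on
# ``self.enabled``, so a QAT graph can switch between fake-quantized and plain
# float execution without rebuilding the module. A minimal sketch, using the
# concrete ``FakeQuantize`` subclass defined below:
#
#     fq = FakeQuantize(dtype="qint8")
#     fq.disable()  # forward() now returns the input unchanged
#     fq.enable()   # forward() fake-quantizes again

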
class TQT(_FakeQuantize, QParamsModuleMixin):
    r"""
    TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds
    for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks.

    :param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
        quantization dtype of input.
    :param enable: whether to do ``normal_forward`` or ``fake_quant_forward``.
    """

    def __init__(
        self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
    ):
        super().__init__(dtype, enable, **kwargs)
        self.scale = Parameter(0.0, dtype="float32")

    def fake_quant_forward(self, inp, qparams: QParams = None):
        # when enabled, TQT does a fake-quant forward and finetunes the scale
        return tqt_forward(self.qmin, self.qmax, inp, self.scale)

    def set_qparams(self, qparams: QParams):
        assert (
            qparams.mode == QuantMode.SYMMERTIC
        ), "only symmetric quantization is supported by TQT"
        if qparams.scale is None:
            raise AssertionError("Can not get an initialized scale")
        self.scale[...] = F.log(qparams.scale) / math.log(2)

    def get_qparams(self):
        return create_qparams(QuantMode.SYMMERTIC, self.dtype, scale=2 ** self.scale)


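# A minimal usage sketch (editor's illustration, not part of the library).
# TQT keeps log2 of the scale as its trainable parameter, so ``set_qparams``
# stores log2(scale) and ``get_qparams`` reads back 2 ** self.scale:
#
#     from megengine import Tensor
#
#     tqt = TQT(dtype="qint8")
#     tqt.set_qparams(
#         create_qparams(QuantMode.SYMMERTIC, tqt.dtype, scale=Tensor(2.0))
#     )
#     out = tqt(Tensor([0.3, 1.4, -2.7]))  # fake-quantized onto the qint8 grid

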
class FakeQuantize(_FakeQuantize):
    r"""
    A module to do quant and dequant according to the observer's scale and zero_point.

    :param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
        quantization dtype of input.
    :param enable: whether to do ``normal_forward`` or ``fake_quant_forward``.
    """

    def fake_quant_forward(self, inp, qparams: QParams = None):
        assert (
            qparams.dtype_meta is self.dtype
        ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
            qparams.dtype_meta, self.dtype
        )
        return fake_quant_tensor(inp, qparams)


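# A minimal usage sketch (editor's illustration). The assertion above is an
# identity check, so the qparams should be built from the very
# ``QuantDtypeMeta`` object the module resolved at construction time, e.g. by
# passing ``fq.dtype``:
#
#     from megengine import Tensor
#
#     fq = FakeQuantize(dtype="qint8")
#     qparams = create_qparams(QuantMode.SYMMERTIC, fq.dtype, scale=Tensor(0.01))
#     out = fq(Tensor([0.3, 1.4, -2.7]), qparams=qparams)

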
class LSQ(_FakeQuantize, QParamsModuleMixin):
    r"""
    LSQ: https://arxiv.org/pdf/1902.08153.pdf Learned Step Size Quantization:
    estimating and scaling the task loss gradient at each weight and activation
    layer's quantizer step size.

    :param dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
        quantization dtype of input.
    :param enable: whether to do ``normal_forward`` or ``fake_quant_forward``.
    :param eps: a small value to avoid division by zero. Default: 1e-5
    """

    def __init__(
        self,
        dtype: Union[str, QuantDtypeMeta],
        enable: bool = True,
        eps: float = 1e-5,
        **kwargs
    ):
        super().__init__(dtype=dtype, enable=enable, **kwargs)
        self.eps = Tensor(eps, dtype="float32")
        self.step_size = Parameter(1.0, dtype="float32")

    def set_qparams(self, qparams: LSQParams):
        self.mode = qparams.mode
        if qparams.mode == QuantMode.ASYMMERTIC:
            self.zero_point = qparams.zero_point
        else:
            self.zero_point = Tensor([0.0], dtype="float32")
        if qparams.scale is None:
            raise AssertionError("Can not get an initialized scale")
        init_step_size = qparams.scale
        if init_step_size < self.eps:
            init_step_size = 0
        else:
            init_step_size = init_step_size - self.eps
        self.step_size = Parameter(init_step_size, dtype="float32")

        self.grad_scale = qparams.grad_scale

    def fake_quant_forward(self, inp, qparams: LSQParams = None):
        step_size = F.abs(self.step_size) + self.eps
        return lsq_forward(
            self.qmin, self.qmax, inp, step_size, self.zero_point, self.grad_scale
        )

    def get_qparams(self):
        return LSQParams(
            mode=self.mode,
            dtype_meta=self.dtype,
            scale=F.abs(self.step_size.detach()) + self.eps,
            zero_point=self.zero_point,
            grad_scale=self.grad_scale,
        )
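

# A minimal usage sketch (editor's illustration). The LSQ paper recommends a
# gradient scale of 1 / sqrt(n_elements * qmax); the concrete numbers here are
# assumptions for the example only:
#
#     from megengine import Tensor
#
#     lsq = LSQ(dtype="qint8")
#     inp = Tensor([0.3, 1.4, -2.7])
#     lsq.set_qparams(
#         LSQParams(
#             mode=QuantMode.SYMMERTIC,
#             dtype_meta=lsq.dtype,
#             scale=Tensor(0.01),
#             zero_point=None,
#             grad_scale=Tensor(1.0 / math.sqrt(inp.size * lsq.qmax)),
#         )
#     )
#     out = lsq(inp)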