- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- from abc import abstractmethod
-
- # avoid circular reference
- from ...quantization.fake_quant import FakeQuantize
- from ...quantization.observer import Observer
- from ...quantization.qconfig import QConfig
- from ...quantization.utils import fake_quant_bias
- from ...tensor import Tensor
- from ..module import Module
-
-
class QATModule(Module):
    r"""
    Base class of quantized-float related :class:`~.Module`, basically for QAT and Calibration.

    Use :meth:`from_float_module` to generate an instance from a float :class:`~.Module`.
    Or use :func:`~.quantize.quantize_qat` to do it recursively and automatically.

    Can also be converted to :class:`~.QuantizedModule` for deployment using
    :func:`~.quantize.quantize` further.
    """

    # Class-level switches read by ``set_qconfig``, ``set_fake_quant`` and
    # ``set_observer``: when one is False, the corresponding (weight / activation)
    # observer and fake_quant are neither configured nor enabled/disabled.
    with_weight = True
    with_act = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # All four stay ``None`` until ``set_qconfig`` fills them in.
        self.weight_observer = None  # type: Observer
        self.act_observer = None  # type: Observer

        self.weight_fake_quant = None  # type: FakeQuantize
        self.act_fake_quant = None  # type: FakeQuantize

    def __repr__(self):
        # Prefix the parent representation so QAT modules are easy to spot.
        return "QAT." + super().__repr__()

    def set_qconfig(self, qconfig: QConfig):
        r"""
        Set quantization related configs with ``qconfig``, including
        observer and fake_quant for weight and activation.

        Each ``qconfig`` field is treated as a zero-argument factory; a field
        that is ``None`` yields ``None`` for the matching attribute. Attributes
        on a side whose ``with_act`` / ``with_weight`` flag is False are left
        untouched.
        """

        def safe_call(func):
            # A ``None`` factory means "not configured" rather than an error.
            return func() if func is not None else None

        if self.with_act:
            self.act_observer = safe_call(qconfig.act_observer)
            self.act_fake_quant = safe_call(qconfig.act_fake_quant)
        if self.with_weight:
            self.weight_observer = safe_call(qconfig.weight_observer)
            self.weight_fake_quant = safe_call(qconfig.weight_fake_quant)

    def _enable_exec(self, with_module, func, enable):
        r"""
        Call ``func.enable()`` or ``func.disable()`` depending on ``enable``.

        Does nothing when ``with_module`` is false or ``func`` is ``None``.
        """
        # Explicit ``is None`` test (instead of truthiness) keeps this consistent
        # with ``_apply_fakequant_with_observer`` and never skips a configured
        # object merely because it evaluates as falsy.
        if not with_module or func is None:
            return
        if enable:
            func.enable()
        else:
            func.disable()

    def set_fake_quant(self, enable):
        r"""
        Enable (``enable=True``) or disable the activation and weight fake_quant.
        """
        self._enable_exec(self.with_act, self.act_fake_quant, enable)
        self._enable_exec(self.with_weight, self.weight_fake_quant, enable)

    def set_observer(self, enable):
        r"""
        Enable (``enable=True``) or disable the activation and weight observer.
        """
        self._enable_exec(self.with_act, self.act_observer, enable)
        self._enable_exec(self.with_weight, self.weight_observer, enable)

    def _apply_fakequant_with_observer(
        self, target: Tensor, fake_quant: FakeQuantize, observer: Observer
    ):
        r"""
        Run ``observer`` and then ``fake_quant`` on ``target`` (either may be
        ``None``) and record the resulting qparams on the returned tensor.
        """
        # do observer
        if observer is None:
            oup = target
            qparams = None
        else:
            oup = observer(target)
            qparams = observer.get_qparams()
        # do fake quant
        if fake_quant is not None:
            oup = fake_quant(oup, qparams)
            # use qparams of fake_quant if have.
            if hasattr(fake_quant, "get_qparams"):
                qparams = fake_quant.get_qparams()
        # set to tensor qparams.
        if qparams is not None:
            oup.qparams.update(qparams)
        return oup

    def apply_quant_weight(self, target: Tensor):
        r"""
        Apply weight's observer and fake_quant from ``qconfig`` on ``target``.
        """
        return self._apply_fakequant_with_observer(
            target, self.weight_fake_quant, self.weight_observer
        )

    def apply_quant_activation(self, target: Tensor):
        r"""
        Apply activation's observer and fake_quant from ``qconfig`` on ``target``.
        """
        return self._apply_fakequant_with_observer(
            target, self.act_fake_quant, self.act_observer
        )

    def apply_quant_bias(self, target: Tensor, inp: Tensor, w_qat: Tensor):
        r"""
        Use :func:`~.fake_quant_bias` to process ``target``. Only valid when
        ``act_fake_quant`` and ``weight_fake_quant`` are both enabled.

        When either fake_quant is missing or disabled, ``target`` is returned
        unchanged.
        """
        # bias should have the same dtype as activation, so act_fake_quant can also
        # decide whether to do bias fakequant
        act_fq = self.act_fake_quant
        weight_fq = self.weight_fake_quant
        if (
            act_fq is not None
            and act_fq.enabled
            and weight_fq is not None
            and weight_fq.enabled
        ):
            b_qat = fake_quant_bias(target, inp, w_qat)
        else:
            b_qat = target
        return b_qat

    def _get_method_result(
        self, method: str, fake_quant: FakeQuantize, observer: Observer
    ):
        r"""
        Call the zero-argument method named ``method`` on ``fake_quant`` if it
        has one, otherwise on ``observer``; return ``None`` if neither has it.
        """
        if hasattr(fake_quant, method):
            return getattr(fake_quant, method)()
        elif hasattr(observer, method):
            return getattr(observer, method)()
        return None

    def get_weight_dtype(self):
        r"""
        Get weight's quantization dtype as the method from ``qconfig``.
        """
        return self._get_method_result(
            "get_quantized_dtype", self.weight_fake_quant, self.weight_observer
        )

    def get_activation_dtype(self):
        r"""
        Get activation's quantization dtype as the method from ``qconfig``.
        """
        return self._get_method_result(
            "get_quantized_dtype", self.act_fake_quant, self.act_observer
        )

    def get_weight_qparams(self):
        r"""
        Get weight's quantization parameters.
        """
        return self._get_method_result(
            "get_qparams", self.weight_fake_quant, self.weight_observer
        )

    def get_activation_qparams(self):
        r"""
        Get activation's quantization parameters.
        """
        return self._get_method_result(
            "get_qparams", self.act_fake_quant, self.act_observer
        )

    @classmethod
    @abstractmethod
    def from_float_module(cls, float_module: Module):
        r"""
        Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.
        """
|