|
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- from collections import namedtuple
- from functools import partial
-
- from ..module import Module
- from .fake_quant import TQT, FakeQuantize
- from .observer import (
- ExponentialMovingAverageObserver,
- HistogramObserver,
- MinMaxObserver,
- PassiveObserver,
- SyncExponentialMovingAverageObserver,
- SyncMinMaxObserver,
- )
-
-
# Subclass a namedtuple so configs are immutable, comparable and easy to print;
# ``__slots__ = ()`` keeps instances from growing a per-instance ``__dict__``.
class QConfig(
    namedtuple(
        "QConfig",
        ["weight_observer", "act_observer", "weight_fake_quant", "act_fake_quant"],
    )
):
    r"""
    A config class indicating how to quantize the ``activation`` and ``weight`` of a
    :class:`~.QATModule`. See :meth:`~.QATModule.set_qconfig` for detailed usage.

    :param weight_observer: interface to instantiate an :class:`~.Observer` indicating
        how to collect scales and zero_point of weight.
    :param act_observer: similar to ``weight_observer`` but toward activation.
    :param weight_fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
        how to do fake_quant calculation.
    :param act_fake_quant: similar to ``weight_fake_quant`` but toward activation.
    :raises ValueError: if any field is an already-instantiated :class:`~.Module`
        instead of a class or class generator (e.g. ``functools.partial``).

    Any field may be ``None``; the instance check does not reject it.

    Examples:

    .. code-block::

        # Default EMA QConfig for QAT.
        ema_fakequant_qconfig = QConfig(
            weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
            act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8"),
            weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
            act_fake_quant=partial(FakeQuantize, dtype="qint8"),
        )

    Each parameter is a ``class`` rather than an instance. We recommend using
    ``functools.partial`` to bind initialization parameters of the ``class``, so that
    they need not be provided in :meth:`~.QATModule.set_qconfig`.

    Usually we choose the narrow version of a dtype (like ``qint8_narrow``) for
    weight-related parameters and the normal version for activation-related ones.
    For an expression like ``a * b + c * d``, if all four variables are -128 in dtype
    ``qint8``, the result is ``2^15`` and causes overflow. Weights are commonly
    calculated in this way, so qmin needs to be narrowed to -127.
    """

    __slots__ = ()

    def __new__(cls, weight_observer, act_observer, weight_fake_quant, act_fake_quant):
        # Every field is meant to be a factory (a class or ``partial``), not a
        # built module; reject instances early with an actionable message.
        if isinstance(act_observer, Module) or isinstance(weight_observer, Module):
            raise ValueError(
                "QConfig must not receive observer instance, please pass observer"
                " class generator using `partial(Observer, ...)` instead. Use"
                " partial(MyObserver, x=1) to override arguments to constructor if needed"
            )
        if isinstance(act_fake_quant, Module) or isinstance(weight_fake_quant, Module):
            raise ValueError(
                "QConfig must not receive fake_quant instance, please pass fake_quant"
                " class generator using `partial(FakeQuantize, ...)` instead. Use"
                " partial(MyFakeQuant, x=1) to override arguments to constructor if needed"
            )
        return super().__new__(
            cls, weight_observer, act_observer, weight_fake_quant, act_fake_quant
        )
-
-
# Preset configs. Each one lists the weight-side factories first, then the
# activation-side ones; weights use the narrow 8-bit dtype where applicable.

# Min/max observers on both sides, 8-bit FakeQuantize.
min_max_fakequant_qconfig = QConfig(
    weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
    weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
    act_observer=partial(MinMaxObserver, dtype="qint8"),
    act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)

# Min/max for weights, exponential-moving-average observer for activations.
ema_fakequant_qconfig = QConfig(
    weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
    weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
    act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8"),
    act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)

# Same layout as the EMA preset, but built from the ``Sync*`` observer variants
# (presumably synchronizing statistics across workers - confirm in observer.py).
sync_ema_fakequant_qconfig = QConfig(
    weight_observer=partial(SyncMinMaxObserver, dtype="qint8_narrow"),
    weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
    act_observer=partial(SyncExponentialMovingAverageObserver, dtype="qint8"),
    act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)

# EMA layout at 4 bits; every field uses plain ``qint4`` (no narrow weight dtype).
ema_lowbit_fakequant_qconfig = QConfig(
    weight_observer=partial(MinMaxObserver, dtype="qint4"),
    weight_fake_quant=partial(FakeQuantize, dtype="qint4"),
    act_observer=partial(ExponentialMovingAverageObserver, dtype="qint4"),
    act_fake_quant=partial(FakeQuantize, dtype="qint4"),
)

# Observers only (histogram for activations); no fake quantization.
calibration_qconfig = QConfig(
    weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
    weight_fake_quant=None,
    act_observer=partial(HistogramObserver, dtype="qint8"),
    act_fake_quant=None,
)

# No observers; TQT fake-quant modules on both sides.
tqt_qconfig = QConfig(
    weight_observer=None,
    weight_fake_quant=partial(TQT, dtype="qint8_narrow"),
    act_observer=None,
    act_fake_quant=partial(TQT, dtype="qint8"),
)

# PassiveObserver on both sides, 8-bit FakeQuantize.
passive_qconfig = QConfig(
    weight_observer=partial(PassiveObserver, dtype="qint8_narrow"),
    weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
    act_observer=partial(PassiveObserver, dtype="qint8"),
    act_fake_quant=partial(FakeQuantize, dtype="qint8"),
)

# Alias: refers to the very same object as ``passive_qconfig``.
easyquant_qconfig = passive_qconfig
|