GitOrigin-RevId: 92389341be
release-0.5
@@ -37,15 +37,14 @@ class QATModule(Module): | |||
Set quantization related configs with ``qconfig``, including | |||
observer and fake_quant for weight and activation. | |||
""" | |||
self.weight_observer = qconfig.weight_observer() | |||
self.act_observer = qconfig.act_observer() | |||
if qconfig.fake_quant is None: | |||
self.weight_fake_quant = None | |||
self.act_fake_quant = None | |||
else: | |||
self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype) | |||
self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype) | |||
def safe_call(func): | |||
return func() if func is not None else None | |||
self.weight_observer = safe_call(qconfig.weight_observer) | |||
self.act_observer = safe_call(qconfig.act_observer) | |||
self.weight_fake_quant = safe_call(qconfig.weight_fake_quant) | |||
self.act_fake_quant = safe_call(qconfig.act_fake_quant) | |||
def _apply_fakequant_with_observer( | |||
self, target: Tensor, fake_quant: FakeQuantize, observer: Observer | |||
@@ -19,7 +19,7 @@ from .observer import ObserverMode, Round | |||
class _FakeQuantize(Module): | |||
def __init__(self, dtype: str, enable: bool = True): | |||
def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): | |||
super().__init__() | |||
if not dtype in _metadata_dict.keys(): | |||
raise ValueError( | |||
@@ -28,7 +28,10 @@ class _FakeQuantize(Module): | |||
) | |||
) | |||
self.dtype = dtype | |||
self.qmin = _metadata_dict[dtype].qmin | |||
self.narrow_range = narrow_range | |||
self.qmin = ( | |||
-_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin | |||
) | |||
self.qmax = _metadata_dict[dtype].qmax | |||
self.enabled = enable | |||
@@ -90,12 +93,12 @@ class TQT_Function(Function): | |||
class TQT(_FakeQuantize): | |||
""" | |||
TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds | |||
TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds | |||
for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks | |||
""" | |||
def __init__(self, dtype: str, enable: bool = True): | |||
super().__init__(dtype, enable) | |||
def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): | |||
super().__init__(dtype, narrow_range, enable) | |||
self.scale = Parameter(0.0, dtype=np.float32) | |||
def fake_quant_forward(self, inp, q_dict): | |||
@@ -116,6 +119,11 @@ class TQT(_FakeQuantize): | |||
class FakeQuantize(_FakeQuantize): | |||
r""" | |||
A module to do quant and dequant according to observer's scale and zero_point. | |||
:param dtype: A string indicating the target quantization type of input. | |||
:param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``, | |||
instead of 1 greater. Usually True for weight and False for activation. | |||
:param enable: Whether do ``normal_forward`` or ``fake_quant_forward``. | |||
""" | |||
def fake_quant_forward(self, inp, q_dict): | |||
@@ -31,9 +31,11 @@ class Observer(Module): | |||
A base class for Observer Module. | |||
:param dtype: a string indicating to collect scale and zero_point of which dtype | |||
:param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``, | |||
instead of 1 greater. Usually True for weight and False for activation. | |||
""" | |||
def __init__(self, dtype="qint8"): | |||
def __init__(self, dtype: str, narrow_range: bool = False): | |||
super().__init__() | |||
if dtype not in _metadata_dict.keys(): | |||
raise ValueError( | |||
@@ -42,7 +44,10 @@ class Observer(Module): | |||
) | |||
) | |||
self.dtype = dtype | |||
self.qmin = _metadata_dict[dtype].qmin | |||
self.narrow_range = narrow_range | |||
self.qmin = ( | |||
-_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin | |||
) | |||
self.qmax = _metadata_dict[dtype].qmax | |||
self.enabled = True | |||
@@ -96,8 +101,14 @@ def create_observer_dict(mode): | |||
class MinMaxObserver(Observer): | |||
def __init__(self, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"): | |||
super().__init__(dtype) | |||
def __init__( | |||
self, | |||
mode=ObserverMode.SYMMERTIC, | |||
eps=0.00001, | |||
dtype="qint8", | |||
narrow_range: bool = False, | |||
): | |||
super().__init__(dtype, narrow_range) | |||
self.mode = mode | |||
self.min_val = Buffer(np.finfo(np.float32).max, dtype=np.float32) | |||
self.max_val = Buffer(np.finfo(np.float32).min, dtype=np.float32) | |||
@@ -153,9 +164,14 @@ class MinMaxObserver(Observer): | |||
class ExponentialMovingAverageObserver(MinMaxObserver): | |||
def __init__( | |||
self, momentum=0.9, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8" | |||
self, | |||
momentum=0.9, | |||
mode=ObserverMode.SYMMERTIC, | |||
eps=0.00001, | |||
dtype="qint8", | |||
narrow_range: bool = False, | |||
): | |||
super().__init__(mode, eps, dtype) | |||
super().__init__(mode, eps, dtype, narrow_range) | |||
self.momentum = Buffer(momentum) | |||
self.runtime_momentum = Buffer(0.0) | |||
@@ -188,11 +204,12 @@ class HistogramObserver(MinMaxObserver): | |||
self, | |||
bins=2048, | |||
upsample_rate=128, | |||
dtype="qint8", | |||
mode=ObserverMode.SYMMERTIC, | |||
eps=0.00001, | |||
dtype="qint8", | |||
narrow_range: bool = False, | |||
): | |||
super().__init__(mode, eps, dtype) | |||
super().__init__(mode, eps, dtype, narrow_range) | |||
self.bins = bins | |||
self.upsample_rate = upsample_rate | |||
self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 | |||
@@ -5,6 +5,8 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from functools import partial | |||
from ..module import Module | |||
from .fake_quant import TQT, FakeQuantize | |||
from .observer import ( | |||
@@ -22,9 +24,9 @@ class QConfig: | |||
:param weight_observer: interface to instantiate an :class:`~.Observer` indicating | |||
how to collect scales and zero_point of wegiht. | |||
:param act_observer: similar to ``weight_observer`` but toward activation. | |||
:param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating | |||
how to do fake_quant calculation. can be invoked multi times to get different | |||
instance for each target tensor, for better control on enable and disable. | |||
:param weight_fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating | |||
how to do fake_quant calculation. | |||
:param act_observer: similar to ``weight_fake_quant`` but toward activation. | |||
Examples: | |||
@@ -32,14 +34,24 @@ class QConfig: | |||
# Default EMA QConfig for QAT. | |||
ema_fakequant_qconfig = QConfig( | |||
weight_observer=MinMaxObserver, | |||
act_observer=ExponentialMovingAverageObserver, | |||
fake_quant=FakeQuantize, | |||
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||
act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False), | |||
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||
) | |||
Each parameter is a ``class`` rather than an instance. And we recommand using ``functools.partial`` | |||
to add initialization parameters of the ``class``, so that don't need to provide parameters in | |||
:meth:`~.QATModule.set_qconfig`. | |||
Usually we set ``narrow_range`` of weight related paramters to ``True`` and of activation related | |||
parameters to ``False``. For the result of multiplication and addition as ``a * b + c * d``, if | |||
four variables are all -128 of dtype ``qint8``, then the result will be ``2^15`` and cause overflow. | |||
Weights are commonly calculated in this way, so needed to narrow the range. | |||
""" | |||
def __init__( | |||
self, act_observer, weight_observer, fake_quant, | |||
self, weight_observer, act_observer, weight_fake_quant, act_fake_quant | |||
): | |||
if isinstance(act_observer, Module) or isinstance(weight_observer, Module): | |||
raise ValueError( | |||
@@ -47,30 +59,42 @@ class QConfig: | |||
" class generator using `partial(Observer, ...)` instead. Use" | |||
" partial(MyObserver, x=1) to override arguments to constructor if needed" | |||
) | |||
self.act_observer = act_observer | |||
self.weight_observer = weight_observer | |||
self.fake_quant = fake_quant | |||
self.act_observer = act_observer | |||
self.weight_fake_quant = weight_fake_quant | |||
self.act_fake_quant = act_fake_quant | |||
tqt_quant_qconfig = QConfig( | |||
weight_observer=ExponentialMovingAverageObserver, | |||
act_observer=ExponentialMovingAverageObserver, | |||
fake_quant=TQT, | |||
weight_observer=partial( | |||
ExponentialMovingAverageObserver, dtype="qint8", narrow_range=True | |||
), | |||
act_observer=partial( | |||
ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False | |||
), | |||
weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True), | |||
act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False), | |||
) | |||
# Default QAT QConfigs | |||
min_max_fakequant_qconfig = QConfig( | |||
weight_observer=MinMaxObserver, | |||
act_observer=MinMaxObserver, | |||
fake_quant=FakeQuantize, | |||
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||
act_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=False), | |||
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||
) | |||
ema_fakequant_qconfig = QConfig( | |||
weight_observer=MinMaxObserver, | |||
act_observer=ExponentialMovingAverageObserver, | |||
fake_quant=FakeQuantize, | |||
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||
act_observer=partial( | |||
ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False | |||
), | |||
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||
) | |||
calibration_qconfig = QConfig( | |||
weight_observer=MinMaxObserver, act_observer=HistogramObserver, fake_quant=None, | |||
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||
act_observer=partial(HistogramObserver, dtype="qint8", narrow_range=False), | |||
weight_fake_quant=None, | |||
act_fake_quant=None, | |||
) |