diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py index 6ab0e6e2..4ee9f62b 100644 --- a/python_module/megengine/module/module.py +++ b/python_module/megengine/module/module.py @@ -476,21 +476,17 @@ class QATModule(Module): self.quantizing = self.QATMode.DISABLED self.scale = None - self.inp_observer = None # type: Observer self.weight_observer = None # type: Observer self.act_observer = None # type: Observer self.weight_fake_quant = None # type: FakeQuantize - self.bias_fake_quant = None # type: FakeQuantize self.act_fake_quant = None # type: FakeQuantize def set_qconfig(self, qconfig: "QConfig"): - self.inp_observer = qconfig.inp_observer() self.weight_observer = qconfig.weight_observer() self.act_observer = qconfig.act_observer() self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype) - self.bias_fake_quant = qconfig.bias_fake_quant() self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype) def apply_observer(self, target: Tensor, obs: "Observer"): diff --git a/python_module/megengine/quantization/__init__.py b/python_module/megengine/quantization/__init__.py index 428db50b..46145bd8 100644 --- a/python_module/megengine/quantization/__init__.py +++ b/python_module/megengine/quantization/__init__.py @@ -8,4 +8,11 @@ from .fake_quant import FakeQuantize from .observer import Observer from .qconfig import QConfig, ema_fakequant_qconfig, min_max_fakequant_qconfig -from .quantize import quantize, quantize_qat +from .quantize import ( + disable_fake_quant, + disable_observer, + enable_fake_quant, + enable_observer, + quantize, + quantize_qat, +) diff --git a/python_module/megengine/quantization/qconfig.py b/python_module/megengine/quantization/qconfig.py index 14c34831..9a524300 100644 --- a/python_module/megengine/quantization/qconfig.py +++ b/python_module/megengine/quantization/qconfig.py @@ -15,21 +15,18 @@ from .observer import ExponentialMovingAverageObserver, MinMaxObserver class QConfig: 
""" A config class indicating how to do quantize toward :class:`~.QATModule`'s - ``activation``, ``weight`` and ``bias``. + ``activation`` and ``weight``. The ``fake_quant`` parameter indicates how to do fake_quant calculation. See :meth:`~.QATModule.set_qconfig` for detail usage. - :param inp_observer: interface to instantiate an :class:`~.Observer` indicating - how to collect scales and zero_point of input. - :param weight_observer: similar to ``inp_observer`` but toward weight. - :param act_observer: similar to ``inp_observer`` but toward activation. + :param weight_observer: interface to instantiate an :class:`~.Observer` indicating +- how to collect scales and zero_point of weight. + :param act_observer: similar to ``weight_observer`` but toward activation. :param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating how to do fake_quant calculation. can be invoked multi times to get different instance for each target tensor, for better control on enable and disable. - :param bias_fake_quant: similar to ``fake_quant``, but usually need to set ``dtype`` - in advance, for bias's dtype is unable to be inferred from observer. Examples: @@ -37,21 +34,16 @@ class QConfig: # Default EMA QConfig for QAT. ema_fakequant_qconfig = QConfig( - inp_observer=ExponentialMovingAverageObserver, - weight_observer=ExponentialMovingAverageObserver, + weight_observer=MinMaxObserver, act_observer=ExponentialMovingAverageObserver, fake_quant=FakeQuantize, ) """ def __init__( - self, act_observer, weight_observer, inp_observer, fake_quant, bias_fake_quant, + self, act_observer, weight_observer, fake_quant, ): - if ( - isinstance(act_observer, Module) - or isinstance(weight_observer, Module) - or isinstance(inp_observer, Module) - ): + if isinstance(act_observer, Module) or isinstance(weight_observer, Module): raise ValueError( "QConfig must not receive observer instance, please pass observer" " class generator using `partial(Observer, ...)` instead. 
Use" @@ -59,24 +51,18 @@ class QConfig: ) self.act_observer = act_observer self.weight_observer = weight_observer - self.inp_observer = inp_observer self.fake_quant = fake_quant - self.bias_fake_quant = bias_fake_quant # Default QAT QConfigs min_max_fakequant_qconfig = QConfig( - inp_observer=MinMaxObserver, weight_observer=MinMaxObserver, act_observer=MinMaxObserver, fake_quant=FakeQuantize, - bias_fake_quant=partial(FakeQuantize, dtype="qint32"), ) ema_fakequant_qconfig = QConfig( - inp_observer=ExponentialMovingAverageObserver, weight_observer=MinMaxObserver, act_observer=ExponentialMovingAverageObserver, fake_quant=FakeQuantize, - bias_fake_quant=partial(FakeQuantize, dtype="qint32"), ) diff --git a/python_module/megengine/quantization/quantize.py b/python_module/megengine/quantization/quantize.py index c89ad6dc..1ce5c953 100644 --- a/python_module/megengine/quantization/quantize.py +++ b/python_module/megengine/quantization/quantize.py @@ -64,7 +64,6 @@ def disable_fake_quant(module: Module): if isinstance(mod, QATModule): mod.act_fake_quant.disable() mod.weight_fake_quant.disable() - mod.inp_fake_quant.disable() module.apply(fn) @@ -79,6 +78,7 @@ def disable_observer(module: Module): def fn(mod): if isinstance(mod, QATModule): mod.act_observer.disable() + mod.weight_observer.disable() module.apply(fn) @@ -94,7 +94,6 @@ def enable_fake_quant(module: Module): if isinstance(mod, QATModule): mod.act_fake_quant.enable() mod.weight_fake_quant.enable() - mod.inp_fake_quant.enable() module.apply(fn) @@ -109,5 +108,6 @@ def enable_observer(module: Module): def fn(mod): if isinstance(mod, QATModule): mod.act_observer.enable() + mod.weight_observer.enable() module.apply(fn)