diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py index c0732fc3..10548825 100644 --- a/python_module/megengine/module/module.py +++ b/python_module/megengine/module/module.py @@ -57,6 +57,7 @@ class Module(metaclass=ABCMeta): def __init__(self): self.training = True + self.quantize_diabled = False @abstractmethod def forward(self, inputs): @@ -312,6 +313,16 @@ class Module(metaclass=ABCMeta): """ self.train(False) + def disable_quantize(self, value=True): + r""" + Set ``module``'s ``quantize_diabled`` attribute and return ``module``. + Could be used as a decorator. + """ + def fn(module: Module) -> None: + module.quantize_diabled = value + + self.apply(fn) + def state_dict(self, rst=None, prefix="", keep_var=False): r"""Returns a dictionary containing whole states of the module. """ diff --git a/python_module/megengine/module/qat/module.py b/python_module/megengine/module/qat/module.py index 7dc47996..ba4bea35 100644 --- a/python_module/megengine/module/qat/module.py +++ b/python_module/megengine/module/qat/module.py @@ -26,8 +26,6 @@ class QATModule(Module): def __init__(self): super().__init__() - self.scale = None - self.weight_observer = None # type: Observer self.act_observer = None # type: Observer diff --git a/python_module/megengine/quantization/quantize.py b/python_module/megengine/quantization/quantize.py index 358c6cc5..1d2e6b1f 100644 --- a/python_module/megengine/quantization/quantize.py +++ b/python_module/megengine/quantization/quantize.py @@ -6,7 +6,7 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from copy import deepcopy -from typing import Dict, Tuple +from typing import Callable, Dict, Tuple from .. import module as Float from ..module import Module @@ -48,7 +48,7 @@ def _get_convert_dict() -> Tuple[ _float2qat_dict, _qat2quantized_dict = _get_convert_dict() -def quantize(module: Module, inplace=True): +def quantize(module: Module, inplace: bool = True): r""" Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule` through :meth:`~.Module.apply`. @@ -80,7 +80,9 @@ def quantize(module: Module, inplace=True): def quantize_qat( - module: Module, inplace=True, qconfig: QConfig = ema_fakequant_qconfig + module: Module, + inplace: bool = True, + qconfig: QConfig = ema_fakequant_qconfig, ): r""" Recursively convert float :class:`~.Module` to :class:`~.QATModule` @@ -105,7 +107,7 @@ def quantize_qat( module._flatten(with_key=True, with_parent=True, predicate=is_quantable) ): # only convert top quantable module. - if is_quantable(parent): + if is_quantable(parent) or submodule.quantize_diabled: continue new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule) @@ -136,12 +138,12 @@ def propagate_qconfig(module: QATModule, qconfig: QConfig): def disable_fake_quant(module: Module): r""" - Recursively disable `module` fake quantization in QATModule through :meth:`~.Module.apply` + Recursively disable ``module`` fake quantization in QATModule through :meth:`~.Module.apply` :param module: root module to do disable fake quantization recursively. """ - def fn(mod): + def fn(mod: Module): if isinstance(mod, QATModule): mod.act_fake_quant.disable() mod.weight_fake_quant.disable() @@ -151,12 +153,12 @@ def disable_fake_quant(module: Module): def disable_observer(module: Module): r""" - Recursively disable `module` observer in QATModule through :meth:`~.Module.apply` + Recursively disable ``module`` observer in QATModule through :meth:`~.Module.apply` :param module: root module to do disable observer recursively. """ - def fn(mod): + def fn(mod: Module): if isinstance(mod, QATModule): mod.act_observer.disable() mod.weight_observer.disable() @@ -166,12 +168,12 @@ def disable_observer(module: Module): def enable_fake_quant(module: Module): r""" - Recursively enable `module` fake quantization in QATModule through :meth:`~.Module.apply` + Recursively enable ``module`` fake quantization in QATModule through :meth:`~.Module.apply` :param module: root module to do enable fake quantization recursively. """ - def fn(mod): + def fn(mod: Module): if isinstance(mod, QATModule): mod.act_fake_quant.enable() mod.weight_fake_quant.enable() @@ -181,12 +183,12 @@ def enable_fake_quant(module: Module): def enable_observer(module: Module): r""" - Recursively enable `module` observer in QATModule through :meth:`~.Module.apply` + Recursively enable ``module`` observer in QATModule through :meth:`~.Module.apply` :param module: root module to do enable observer recursively. """ - def fn(mod): + def fn(mod: Module): if isinstance(mod, QATModule): mod.act_observer.enable() mod.weight_observer.enable()