diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py index ecd8b185..b999fb12 100644 --- a/python_module/megengine/module/module.py +++ b/python_module/megengine/module/module.py @@ -60,7 +60,7 @@ class Module(metaclass=ABCMeta): def __init__(self): # runtime attributes self.training = True - self.quantize_diabled = False + self.quantize_disabled = False # hooks self._forward_pre_hooks = OrderedDict() @@ -328,12 +328,12 @@ class Module(metaclass=ABCMeta): def disable_quantize(self, value=True): r""" - Set ``module``'s ``quantize_diabled`` attribute and return ``module``. + Set ``module``'s ``quantize_disabled`` attribute and return ``module``. Could be used as a decorator. """ def fn(module: Module) -> None: - module.quantize_diabled = value + module.quantize_disabled = value self.apply(fn) diff --git a/python_module/megengine/quantization/quantize.py b/python_module/megengine/quantization/quantize.py index 329a0fd5..5dab2ae4 100644 --- a/python_module/megengine/quantization/quantize.py +++ b/python_module/megengine/quantization/quantize.py @@ -5,7 +5,7 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from copy import deepcopy +from copy import copy, deepcopy from typing import Callable, Dict, Tuple from .. import module as Float @@ -49,19 +49,24 @@ def _get_convert_dict() -> Tuple[ _float2qat_dict, _qat2quantized_dict = _get_convert_dict() -def quantize(module: Module, inplace: bool = True): +def quantize(module: Module, inplace: bool = True, mapping: dict = None): r""" Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule` through :meth:`~.Module.apply`. :param module: root module to do convert recursively. :param inplace: whether to convert submodules in-place. + :param mapping: a dict indicating how to convert custom modules from QATModule to + QuantizedModule. Will be combined with internal default convert mapping dict. """ if not inplace: module = deepcopy(module) - qat_modules = tuple(_qat2quantized_dict.keys()) + convert_dict = copy(_qat2quantized_dict) + if mapping is not None: + convert_dict.update(mapping) + qat_modules = tuple(convert_dict.keys()) def is_qat(mod: Module): return isinstance(mod, qat_modules) @@ -70,7 +75,7 @@ def quantize(module: Module, inplace: bool = True): for key, submodule, parent in list( module._flatten(with_key=True, with_parent=True, predicate=is_qat) ): - new_mod = _qat2quantized_dict[type(submodule)].from_qat_module(submodule) + new_mod = convert_dict[type(submodule)].from_qat_module(submodule) if isinstance(parent, Float.Sequential): # cannnot use setattr to be compatible with Sequential's ``__setitem__`` parent[int(key.split(".")[-1])] = new_mod @@ -81,7 +86,10 @@ def quantize(module: Module, inplace: bool = True): def quantize_qat( - module: Module, inplace: bool = True, qconfig: QConfig = ema_fakequant_qconfig, + module: Module, + inplace: bool = True, + qconfig: QConfig = ema_fakequant_qconfig, + mapping: dict = None, ): r""" Recursively convert float :class:`~.Module` to :class:`~.QATModule` @@ -91,12 +99,17 @@ def quantize_qat( :param inplace: whether to convert submodules in-place. :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig. default is ``ema_fakequant_qconfig``. + :param mapping: a dict indicating how to convert custom modules from Module to QATModule. + Will be combined with internal default convert mapping dict. """ if not inplace: module = deepcopy(module) - quantable_modules = tuple(_float2qat_dict.keys()) + convert_dict = copy(_float2qat_dict) + if mapping is not None: + convert_dict.update(mapping) + quantable_modules = tuple(convert_dict.keys()) def is_quantable(mod: Module): return isinstance(mod, quantable_modules) @@ -106,10 +119,10 @@ def quantize_qat( module._flatten(with_key=True, with_parent=True, predicate=is_quantable) ): # only convert top quantable module. - if is_quantable(parent) or submodule.quantize_diabled: + if is_quantable(parent) or submodule.quantize_disabled: continue - new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule) + new_mod = convert_dict[type(submodule)].from_float_module(submodule) if isinstance(parent, Float.Sequential): # cannnot use setattr to be compatible with Sequential's ``__setitem__`` parent[int(key.split(".")[-1])] = new_mod diff --git a/python_module/test/unit/quantization/quantize.py b/python_module/test/unit/quantization/quantize.py index 14e9acb0..236ef9e1 100644 --- a/python_module/test/unit/quantization/quantize.py +++ b/python_module/test/unit/quantization/quantize.py @@ -52,3 +52,29 @@ def test_disable_quantize(): qat_net = quantize_qat(net, inplace=False) assert isinstance(qat_net.conv, Float.ConvBnRelu2d) assert isinstance(qat_net.conv.conv, Float.Conv2d) + + +def test_convert_with_custom_mapping(): + class FloatExample(Float.Module): + def forward(self, x): + return x + + class QATExample(QAT.QATModule): + def forward(self, x): + return x + + @classmethod + def from_float_module(cls, float_module): + return cls() + + class Net(Float.Module): + def __init__(self): + super().__init__() + self.example = FloatExample() + + def forward(self, x): + return self.example(x) + + net = Net() + qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample}) + assert isinstance(qat_net.example, QATExample)