diff --git a/python_module/megengine/module/module.py b/python_module/megengine/module/module.py index 10548825..d91fa94e 100644 --- a/python_module/megengine/module/module.py +++ b/python_module/megengine/module/module.py @@ -318,6 +318,7 @@ class Module(metaclass=ABCMeta): Set ``module``'s ``quantize_diabled`` attribute and return ``module``. Could be used as a decorator. """ + def fn(module: Module) -> None: module.quantize_diabled = value diff --git a/python_module/megengine/quantization/quantize.py b/python_module/megengine/quantization/quantize.py index 1d2e6b1f..36d6cc00 100644 --- a/python_module/megengine/quantization/quantize.py +++ b/python_module/megengine/quantization/quantize.py @@ -80,9 +80,7 @@ def quantize(module: Module, inplace: bool = True): def quantize_qat( - module: Module, - inplace: bool = True, - qconfig: QConfig = ema_fakequant_qconfig, + module: Module, inplace: bool = True, qconfig: QConfig = ema_fakequant_qconfig, ): r""" Recursively convert float :class:`~.Module` to :class:`~.QATModule` diff --git a/python_module/test/unit/quantization/quantize.py b/python_module/test/unit/quantization/quantize.py index 36cb5279..14e9acb0 100644 --- a/python_module/test/unit/quantization/quantize.py +++ b/python_module/test/unit/quantization/quantize.py @@ -7,7 +7,7 @@ # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. from megengine import module as Float from megengine.module import qat as QAT -from megengine.quantization.quantize import _get_quantable_module_names +from megengine.quantization.quantize import _get_quantable_module_names, quantize_qat def test_get_quantable_module_names(): @@ -36,3 +36,19 @@ def test_get_quantable_module_names(): and issubclass(value, Float.Module) and value != Float.Module ) + + +def test_disable_quantize(): + class Net(Float.Module): + def __init__(self): + super().__init__() + self.conv = Float.ConvBnRelu2d(3, 3, 3) + self.conv.disable_quantize() + + def forward(self, x): + return self.conv(x) + + net = Net() + qat_net = quantize_qat(net, inplace=False) + assert isinstance(qat_net.conv, Float.ConvBnRelu2d) + assert isinstance(qat_net.conv.conv, Float.Conv2d)