diff --git a/python_module/megengine/module/qat/module.py b/python_module/megengine/module/qat/module.py index 37315ee6..d5562f6a 100644 --- a/python_module/megengine/module/qat/module.py +++ b/python_module/megengine/module/qat/module.py @@ -49,6 +49,8 @@ class QATModule(Module): def _apply_fakequant_with_observer( self, target: Tensor, fake_quant: FakeQuantize, observer: Observer ): + if observer is None: + return target oup = observer(target) if fake_quant is None: return oup @@ -76,7 +78,7 @@ class QATModule(Module): r""" Get weight's quantization dtype as the method from ``qconfig``. """ - if hasattr(self.act_fake_quant, "get_dtype"): + if hasattr(self.weight_fake_quant, "get_dtype"): return self.weight_fake_quant.get_dtype() else: return self.weight_observer.get_dtype() diff --git a/python_module/megengine/quantization/__init__.py b/python_module/megengine/quantization/__init__.py index 938cec9f..b93e6453 100644 --- a/python_module/megengine/quantization/__init__.py +++ b/python_module/megengine/quantization/__init__.py @@ -5,7 +5,9 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + from .fake_quant import FakeQuantize +from .internal_fake_quant import * from .observer import HistogramObserver, Observer, ObserverMode from .qconfig import ( QConfig, diff --git a/python_module/megengine/quantization/fake_quant.py b/python_module/megengine/quantization/fake_quant.py index 9349f0b3..78577d9b 100644 --- a/python_module/megengine/quantization/fake_quant.py +++ b/python_module/megengine/quantization/fake_quant.py @@ -19,6 +19,15 @@ from .observer import ObserverMode, Round class _FakeQuantize(Module): + r""" + A Basic Fake Quant module. + + :param dtype: A string indicating the target quantization type of input. + :param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``, + instead of 1 greater. Usually True for weight and False for activation. + :param enable: Whether do ``normal_forward`` or ``fake_quant_forward``. + """ + def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): super().__init__() if not dtype in _metadata_dict.keys(): @@ -92,9 +101,9 @@ class TQT_Function(Function): class TQT(_FakeQuantize): - """ + r""" TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds - for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks + for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks. """ def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): @@ -119,11 +128,6 @@ class TQT(_FakeQuantize): class FakeQuantize(_FakeQuantize): r""" A module to do quant and dequant according to observer's scale and zero_point. - - :param dtype: A string indicating the target quantization type of input. - :param narrow_range: Whether the absolute value of ``qmin`` is the same as ``qmax``, - instead of 1 greater. Usually True for weight and False for activation. - :param enable: Whether do ``normal_forward`` or ``fake_quant_forward``. """ def fake_quant_forward(self, inp, q_dict): diff --git a/python_module/megengine/quantization/internal_fake_quant.py b/python_module/megengine/quantization/internal_fake_quant.py new file mode 100644 index 00000000..df15a916 --- /dev/null +++ b/python_module/megengine/quantization/internal_fake_quant.py @@ -0,0 +1,19 @@ +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import copy +import math +from functools import partial + +import numpy as np + +from .. import functional as F +from ..core import Function +from .fake_quant import _FakeQuantize +from .observer import MinMaxObserver +from .qconfig import QConfig + diff --git a/python_module/test/unit/quantization/test_TQT.py b/python_module/test/unit/quantization/test_fake_quant.py similarity index 97% rename from python_module/test/unit/quantization/test_TQT.py rename to python_module/test/unit/quantization/test_fake_quant.py index 11de7e40..0fbd9eb1 100644 --- a/python_module/test/unit/quantization/test_TQT.py +++ b/python_module/test/unit/quantization/test_fake_quant.py @@ -13,6 +13,7 @@ import megengine as mge import megengine._internal as mgb from megengine.core import tensor from megengine.quantization.fake_quant import TQT_Function +from megengine.quantization.internal_fake_quant import * from megengine.test import assertTensorClose @@ -75,3 +76,5 @@ def test_TQT(): a.set_value(a_np) b.set_value(b_np) check_inp(a, b, b, a_np, b_np, b_np) + +