From d1be31279aa9a20e4af00714290fed6e0dbba1d0 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 2 Feb 2021 12:42:57 +0800 Subject: [PATCH] fix(mge/quantization): fix tqt load and convert issue and observer calculate params issue GitOrigin-RevId: f8511f72adbac3869f7bb05f3f1364329798119e --- .../python/megengine/quantization/fake_quant.py | 32 +++++++++------------- .../python/megengine/quantization/observer.py | 18 +++++++----- .../python/megengine/quantization/quantize.py | 5 +++- .../python/test/unit/quantization/test_observer.py | 7 +++-- 4 files changed, 32 insertions(+), 30 deletions(-) diff --git a/imperative/python/megengine/quantization/fake_quant.py b/imperative/python/megengine/quantization/fake_quant.py index c1b8a491..15d584db 100644 --- a/imperative/python/megengine/quantization/fake_quant.py +++ b/imperative/python/megengine/quantization/fake_quant.py @@ -6,12 +6,8 @@ # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. import math -from typing import Iterable - -import numpy as np from .. import functional as F -from ..autodiff import Function from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype from ..module import Module from ..tensor import Parameter, Tensor @@ -72,20 +68,10 @@ class TQT(_FakeQuantize): """ def __init__( - self, - q_dict, - dtype: str, - narrow_range: bool = False, - enable: bool = True, - **kwargs + self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs ): super().__init__(dtype, narrow_range, enable, **kwargs) - assert ( - q_dict["mode"] == QuantMode.SYMMERTIC - ), "only symmetric quantization is supported by TQT" - if "scale" not in q_dict or q_dict["scale"] is None: - raise AssertionError("Can not get an initialized scale") - self.scale = Tensor(F.log(q_dict["scale"]) / math.log(2)) + self.scale = Parameter(0.0, dtype="float32") def fake_quant_forward(self, inp, q_dict=None): # when enable, TQT will do fakequant forward, finetune the scale @@ -93,14 +79,22 @@ class TQT(_FakeQuantize): def get_qparams(self): q_dict = get_qparam_dict(QuantMode.SYMMERTIC) - q_dict["scale"] = 2 ** self.scale + q_dict["scale"] = 2 ** self.scale.detach() return q_dict + def set_qparams(self, q_dict): + assert ( + q_dict["mode"] == QuantMode.SYMMERTIC + ), "only symmetric quantization is supported by TQT" + if "scale" not in q_dict or q_dict["scale"] is None: + raise AssertionError("Can not get an initialized scale") + self.scale._reset(F.log(q_dict["scale"]) / math.log(2)) + def get_dtype(self): q_dict = self.get_qparams() - scale = None if "scale" not in q_dict else q_dict["scale"].numpy()[0] + scale = None if "scale" not in q_dict else q_dict["scale"].numpy() zero_point = ( - None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()[0] + None if "zero_point" not in q_dict else q_dict["zero_point"].numpy() ) return get_quantized_dtype(self.dtype, scale, zero_point) diff --git a/imperative/python/megengine/quantization/observer.py b/imperative/python/megengine/quantization/observer.py index 8da70581..bc4f0cf5 100644 --- a/imperative/python/megengine/quantization/observer.py +++ b/imperative/python/megengine/quantization/observer.py @@ -17,7 +17,7 @@ from ..distributed import WORLD, get_rank, is_distributed from ..functional.distributed import all_reduce_max, all_reduce_min from ..module import Module from ..tensor import Tensor -from .utils import QuantMode, Round, get_qparam_dict +from .utils import QuantMode, get_qparam_dict class Observer(Module): @@ -110,7 +110,7 @@ class MinMaxObserver(Observer): (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit ) # caculate zero_point - q_dict["zero_point"] = self.qmin - Round()((min_val / q_dict["scale"])) + q_dict["zero_point"] = self.qmin - F.round(min_val / q_dict["scale"]) return q_dict @@ -453,12 +453,10 @@ class PassiveObserver(Observer): This class can be set :attr:`scale` derectly. """ - def __init__(self, q_dict, dtype: str, narrow_range: bool = False, **kwargs): + def __init__(self, dtype: str, narrow_range: bool = False, **kwargs): super().__init__(dtype, narrow_range, **kwargs) - self.q_dict = deepcopy(q_dict) - if "scale" not in q_dict or q_dict["scale"] is None: - raise AssertionError("Can not get an initialized scale") - self.orig_scale = q_dict["scale"].numpy() + self.q_dict = None + self.orig_scale = None @property def scale(self): @@ -472,6 +470,12 @@ class PassiveObserver(Observer): def get_qparams(self): return self.q_dict + def set_qparams(self, q_dict): + self.q_dict = deepcopy(q_dict) + if "scale" not in q_dict or q_dict["scale"] is None: + raise AssertionError("Can not get an initialized scale") + self.orig_scale = q_dict["scale"].numpy() + def forward(self, x): r""" Just return input because :attr:`q_dict` is set by :func:`~.apply_easy_quant`. diff --git a/imperative/python/megengine/quantization/quantize.py b/imperative/python/megengine/quantization/quantize.py index 1d11442a..e3a01942 100644 --- a/imperative/python/megengine/quantization/quantize.py +++ b/imperative/python/megengine/quantization/quantize.py @@ -152,7 +152,10 @@ def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True): module = deepcopy(module) def safe_call(func, q_dict): - return func(q_dict=q_dict) if func is not None else None + inst = func() if func is not None else None + if inst is not None and getattr(inst, "set_qparams", None) is not None: + inst.set_qparams(q_dict) + return inst for m in list(module._flatten(predicate=is_qat)): if m.with_weight: diff --git a/imperative/python/test/unit/quantization/test_observer.py b/imperative/python/test/unit/quantization/test_observer.py index ea46eb30..33e4f964 100644 --- a/imperative/python/test/unit/quantization/test_observer.py +++ b/imperative/python/test/unit/quantization/test_observer.py @@ -41,8 +41,8 @@ def test_exponential_moving_average_observer(): m = ExponentialMovingAverageObserver(momentum=t) m(mge.tensor(x1, dtype=np.float32)) m(mge.tensor(x2, dtype=np.float32)) - np.testing.assert_allclose(m.min_val.numpy(), expected_min) - np.testing.assert_allclose(m.max_val.numpy(), expected_max) + np.testing.assert_allclose(m.min_val.numpy(), expected_min, atol=1e-5) + np.testing.assert_allclose(m.max_val.numpy(), expected_max, atol=1e-5) def test_histogram_observer(): @@ -57,7 +57,8 @@ def test_histogram_observer(): def test_passive_observer(): q_dict = {"scale": mge.tensor(1.0)} - m = PassiveObserver(q_dict, "qint8") + m = PassiveObserver("qint8") + m.set_qparams(q_dict) assert m.orig_scale == 1.0 assert m.scale == 1.0 m.scale = 2.0