From fda9599a840928cd4f26f58abd78dfdf0bf134de Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Fri, 5 Jun 2020 02:37:03 +0800
Subject: [PATCH] feat(mge/quant): add TQT quant method

GitOrigin-RevId: 00b1616e73ed34c8c09e2407b8fc7d90230f8cec
---
 python_module/megengine/core/function.py           |   1 +
 python_module/megengine/module/qat/module.py       |  10 +-
 python_module/megengine/quantization/__init__.py   |   1 +
 python_module/megengine/quantization/fake_quant.py | 131 +++++++++++++++++----
 python_module/megengine/quantization/observer.py   |   2 +
 python_module/megengine/quantization/qconfig.py    |  10 +-
 python_module/test/unit/core/test_function.py      |   1 -
 python_module/test/unit/quantization/test_TQT.py   |  77 ++++++++++++
 8 files changed, 203 insertions(+), 30 deletions(-)
 create mode 100644 python_module/test/unit/quantization/test_TQT.py

diff --git a/python_module/megengine/core/function.py b/python_module/megengine/core/function.py
index da37ed3d..1219b21e 100644
--- a/python_module/megengine/core/function.py
+++ b/python_module/megengine/core/function.py
@@ -154,6 +154,7 @@ class Function(metaclass=ABCMeta):
         memo[id(self)] = result
         for k, v in self.__dict__.items():
             setattr(result, k, copy.deepcopy(v, memo))
+        setattr(result, "saved_tensors", tmp)
         self.saved_tensors = tmp
         return result

diff --git a/python_module/megengine/module/qat/module.py b/python_module/megengine/module/qat/module.py
index ba4bea35..9381ef4f 100644
--- a/python_module/megengine/module/qat/module.py
+++ b/python_module/megengine/module/qat/module.py
@@ -77,13 +77,19 @@ class QATModule(Module):
         r"""
         Get weight's quantization dtype as the method from ``qconfig``.
         """
-        return self.weight_observer.get_dtype()
+        if hasattr(self.weight_fake_quant, "get_dtype"):
+            return self.weight_fake_quant.get_dtype()
+        else:
+            return self.weight_observer.get_dtype()

     def get_activation_dtype(self):
         r"""
         Get activation's quantization dtype as the method from ``qconfig``.
         """
-        return self.act_observer.get_dtype()
+        if hasattr(self.act_fake_quant, "get_dtype"):
+            return self.act_fake_quant.get_dtype()
+        else:
+            return self.act_observer.get_dtype()

     @classmethod
     @abstractmethod
diff --git a/python_module/megengine/quantization/__init__.py b/python_module/megengine/quantization/__init__.py
index e3476f4c..938cec9f 100644
--- a/python_module/megengine/quantization/__init__.py
+++ b/python_module/megengine/quantization/__init__.py
@@ -12,4 +12,5 @@ from .qconfig import (
     calibration_qconfig,
     ema_fakequant_qconfig,
     min_max_fakequant_qconfig,
+    tqt_quant_qconfig,
 )
diff --git a/python_module/megengine/quantization/fake_quant.py b/python_module/megengine/quantization/fake_quant.py
index 67663309..7ac8889d 100644
--- a/python_module/megengine/quantization/fake_quant.py
+++ b/python_module/megengine/quantization/fake_quant.py
@@ -5,17 +5,20 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import copy
+import math
+
+import numpy as np
+
 from .. import functional as F
-from .._internal.dtype import _metadata_dict
+from .._internal.dtype import _metadata_dict, get_quantized_dtype
+from ..core import Buffer, Function, Parameter
+from ..jit import sideeffect
 from ..module import Module
 from .observer import ObserverMode, Round


-class FakeQuantize(Module):
-    r"""
-    A module to do quant and dequant according to observer's scale and zero_point.
-    """
-
+class _FakeQuantize(Module):
     def __init__(self, dtype: str, enable: bool = True):
         super().__init__()
         if not dtype in _metadata_dict.keys():
@@ -35,25 +38,103 @@ class FakeQuantize(Module):
     def disable(self):
         self.enabled = False

+    def fake_quant_forward(self, inp, q_dict):
+        return inp
+
+    def normal_forward(self, inp, q_dict):
+        return inp
+
     def forward(self, inp, q_dict):
         if self.enabled:
-            if q_dict["mode"] == ObserverMode.SYMMERTIC:
-                scale = q_dict["scale"]
-                # Quant
-                oup = Round()(inp / scale)
-                # clip
-                oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-                # DeQuant
-                oup = (oup) * scale
-                return oup
-            else:
-                scale = q_dict["scale"]
-                zero_point = q_dict["zero_point"]
-                # Quant
-                oup = Round()(inp / scale) + zero_point
-                # clip
-                oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-                # DeQuant
-                oup = (oup - zero_point) * scale
-                return oup
+            return self.fake_quant_forward(inp, q_dict)
+        else:
+            return self.normal_forward(inp, q_dict)
+
+
+class TQT_Function(Function):
+    def __init__(self, lowerbound, upperbound):
+        super().__init__()
+        self.lowerbound = lowerbound
+        self.upperbound = upperbound
+
+    def forward(self, inp, scale):
+        t = 2 ** scale
+        # t = F.maximum(t, 1e-4)
+        inp_scaled = inp / t
+        inp_clipped = F.maximum(F.minimum(inp_scaled, self.upperbound), self.lowerbound)
+        inp_rounded = F.round(inp_clipped)
+        inp_flq = inp_rounded * t
+        self.save_for_backward(inp_scaled, inp_rounded, t)
+        return inp_flq
+
+    def backward(self, grad_inp_flq):
+        (inp_scaled, inp_rounded, t) = self.saved_tensors
+        mask_clip = (inp_scaled < -0.5 + self.lowerbound) + (
+            inp_scaled > self.upperbound + 0.5
+        )  # mask for accumulating the gradients of |data_scaled|>L
+        mask_quant = F.abs(
+            mask_clip - 1
+        )  # mask for accumulating the gradients with |data_scaled|<=L
+        grad_quant = (
+            grad_inp_flq * mask_quant * (inp_rounded - inp_scaled)
+        )  # gradient within |data_scaled|<=L
+        grad_clip = (
+            grad_inp_flq * mask_clip * inp_rounded
+        )  # gradient with |data_scaled|>L
+        grad_s = grad_clip.sum() + grad_quant.sum()
+        # dL/ds = dL/dt * t * ln(2)
+        grad_s = grad_s * t * math.log(2)
+        grad_inp = grad_inp_flq * mask_quant
+        return grad_inp, grad_s
+
+
+class TQT(_FakeQuantize):
+    """
+    TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds
+    for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks
+    """
+
+    def __init__(self, dtype: str, enable: bool = True):
+        super().__init__(dtype, enable)
+        self.scale = Parameter(0.0, dtype=np.float32)
+
+    def fake_quant_forward(self, inp, q_dict):
+        # when enabled, TQT runs the fake-quant forward and fine-tunes the scale
+        return TQT_Function(self.qmin, self.qmax)(inp, self.scale)
+
+    def normal_forward(self, inp, q_dict):
+        # when disabled, TQT runs a normal forward and initializes the scale weight
+        tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
+        tmp_scale = F.log(tmp_scale / 127) / F.log(2)
+        F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0)
         return inp
+
+    def get_dtype(self):
+        return get_quantized_dtype(self.dtype, 2 ** self.scale.numpy()[0], None)
+
+
+class FakeQuantize(_FakeQuantize):
+    r"""
+    A module to do quant and dequant according to observer's scale and zero_point.
+    """
+
+    def fake_quant_forward(self, inp, q_dict):
+        if q_dict["mode"] == ObserverMode.SYMMERTIC:
+            scale = q_dict["scale"]
+            # Quant
+            oup = Round()(inp / scale)
+            # clip
+            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
+            # DeQuant
+            oup = (oup) * scale
+            return oup
+        else:
+            scale = q_dict["scale"]
+            zero_point = q_dict["zero_point"]
+            # Quant
+            oup = Round()(inp / scale) + zero_point
+            # clip
+            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
+            # DeQuant
+            oup = (oup - zero_point) * scale
+            return oup
diff --git a/python_module/megengine/quantization/observer.py b/python_module/megengine/quantization/observer.py
index 8da83d10..8f89b010 100644
--- a/python_module/megengine/quantization/observer.py
+++ b/python_module/megengine/quantization/observer.py
@@ -107,6 +107,8 @@ class MinMaxObserver(Observer):
         min_val = F.minimum(0.0, inp_min_val)
         max_val = F.maximum(0.0, inp_max_val)
         q_dict = create_observer_dict(self.mode)
+        q_dict["min_val"] = inp_min_val
+        q_dict["max_val"] = inp_max_val
         if self.mode == ObserverMode.SYMMERTIC:
             symmetric_max_vals = F.maximum(-min_val, max_val)
             # use maximun to avoid scale too small at the begin
diff --git a/python_module/megengine/quantization/qconfig.py b/python_module/megengine/quantization/qconfig.py
index 410cefe4..00d82429 100644
--- a/python_module/megengine/quantization/qconfig.py
+++ b/python_module/megengine/quantization/qconfig.py
@@ -1,12 +1,12 @@
 # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 #
 # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 #
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ..module import Module
-from .fake_quant import FakeQuantize
+from .fake_quant import TQT, FakeQuantize
 from .observer import (
     ExponentialMovingAverageObserver,
     HistogramObserver,
@@ -52,6 +52,12 @@ class QConfig:
         self.fake_quant = fake_quant


+tqt_quant_qconfig = QConfig(
+    weight_observer=ExponentialMovingAverageObserver,
+    act_observer=ExponentialMovingAverageObserver,
+    fake_quant=TQT,
+)
+
 # Default QAT QConfigs
 min_max_fakequant_qconfig = QConfig(
     weight_observer=MinMaxObserver,
diff --git a/python_module/test/unit/core/test_function.py b/python_module/test/unit/core/test_function.py
index 8333c5fe..c58978bb 100644
--- a/python_module/test/unit/core/test_function.py
+++ b/python_module/test/unit/core/test_function.py
@@ -96,7 +96,6 @@ def test_deepcopy():
     origin = Sigmoid(0)
     new = copy.deepcopy(Sigmoid(0))
     assert new.param == origin.param
-    assert new.saved_tensors == None


 def test_save_context():
diff --git a/python_module/test/unit/quantization/test_TQT.py b/python_module/test/unit/quantization/test_TQT.py
new file mode 100644
index 00000000..11de7e40
--- /dev/null
+++ b/python_module/test/unit/quantization/test_TQT.py
@@ -0,0 +1,77 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+import pytest
+
+import megengine as mge
+import megengine._internal as mgb
+from megengine.core import tensor
+from megengine.quantization.fake_quant import TQT_Function
+from megengine.test import assertTensorClose
+
+
+class numpy_TQT_Function:
+    def __init__(self, lowerbound, upperbound):
+        super().__init__()
+        self.lowerbound = lowerbound
+        self.upperbound = upperbound
+
+    def forward(self, inp, scale):
+        t = 2 ** scale
+        # t = F.maximum(t, 1e-4)
+        inp_scaled = inp / t
+        inp_clipped = np.maximum(
+            np.minimum(inp_scaled, self.upperbound), self.lowerbound
+        )
+        inp_rounded = np.round(inp_clipped)
+        inp_flq = inp_rounded * t
+        self.saved_tensors = (inp_scaled, inp_rounded, t)
+        return inp_flq
+
+    def backward(self, grad_inp_flq):
+        (inp_scaled, inp_rounded, t) = self.saved_tensors
+        mask_clip = (inp_scaled < -0.5 + self.lowerbound) + (
+            inp_scaled > self.upperbound + 0.5
+        )  # mask for accumulating the gradients of |data_scaled|>L
+        mask_quant = np.abs(
+            mask_clip - 1
+        )  # mask for accumulating the gradients with |data_scaled|<=L
+        grad_quant = (
+            grad_inp_flq * mask_quant * (inp_rounded - inp_scaled)
+        )  # gradient within |data_scaled|<=L
+        grad_clip = (
+            grad_inp_flq * mask_clip * inp_rounded
+        )  # gradient with |data_scaled|>L
+        grad_s = grad_clip.sum() + grad_quant.sum()
+        # dL/ds = dL/dt * t * ln(2)
+        grad_s = grad_s * t * np.log(2)
+        grad_inp = grad_inp_flq * mask_quant
+        return grad_inp, grad_s
+
+
+def test_TQT():
+    f = TQT_Function(-127, 127)
+    nf = numpy_TQT_Function(-127, 127)
+
+    def check_inp(a, b, c, a_np, b_np, c_np):
+        assertTensorClose(
+            f.forward(a, b).numpy(), nf.forward(a_np, b_np).astype("float32")
+        )
+        c1, c2 = f.backward(c)
+        c1_np, c2_np = nf.backward(c_np)
+        assertTensorClose(c1.numpy(), c1_np.astype("float32"))
+        assertTensorClose(c2.numpy(), c2_np.astype("float32"))
+
+    a = tensor()
+    b = tensor()
+    a_np = np.random.random((4, 3)).astype("float32")
+    b_np = np.random.random((1)).astype("float32")
+    a.set_value(a_np)
+    b.set_value(b_np)
+    check_inp(a, b, b, a_np, b_np, b_np)
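
Usage sketch (not part of the patch): a minimal example of how the new TQT fake-quant module is meant to be driven, assuming the APIs touched above (tensor()/set_value() as in test_TQT.py, the enable/disable switch on _FakeQuantize, and the q_dict keys "min_val"/"max_val" that MinMaxObserver now records) behave as shown in this patch. The dtype string "qint8", the hand-built q_dict, and calling the module eagerly outside a QAT network are illustrative assumptions; in a real setup the observer fills q_dict and tqt_quant_qconfig wires TQT in as the fake_quant class.

import numpy as np

from megengine.core import tensor
from megengine.quantization.fake_quant import TQT

# "qint8" is assumed to be a key of _metadata_dict; any supported quantized dtype works.
tqt = TQT(dtype="qint8")

# Hand-built stand-ins for the observer statistics that MinMaxObserver would
# normally provide through q_dict["min_val"] / q_dict["max_val"].
inp, min_val, max_val = tensor(), tensor(), tensor()
inp.set_value(np.random.normal(size=(4, 16)).astype("float32"))
min_val.set_value(np.array([-1.0], dtype="float32"))
max_val.set_value(np.array([1.0], dtype="float32"))
q_dict = {"min_val": min_val, "max_val": max_val}

# Phase 1 -- fake quant disabled: normal_forward() initializes the trainable
# threshold as scale = log2(max_abs / 127).
tqt.disable()
_ = tqt(inp, q_dict)

# Phase 2 -- fake quant enabled: forward() goes through TQT_Function, so the
# power-of-two threshold 2**scale receives gradients and can be fine-tuned.
tqt.enable()
out = tqt(inp, q_dict)
print(tqt.get_dtype())  # quantized dtype rebuilt from the learned 2**scale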