@@ -154,6 +154,7 @@ class Function(metaclass=ABCMeta):
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, copy.deepcopy(v, memo))
        setattr(result, "saved_tensors", tmp)
        self.saved_tensors = tmp
        return result
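Note on this hunk: `result` and `tmp` are created above the shown context, so the following is only a hedged sketch of how the whole `__deepcopy__` plausibly reads. The assumption is that `saved_tensors` is stashed into `tmp` and cleared before the attribute loop so it is not deep-copied; the added `setattr` line then hands those tensors to the copy as well (previously the copy ended up with `None`, which is why the corresponding test assertion is removed further down).

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)             # assumed setup, not shown in the hunk
        tmp = self.saved_tensors              # assumed: stash so the loop below skips it
        self.saved_tensors = None
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            setattr(result, k, copy.deepcopy(v, memo))
        setattr(result, "saved_tensors", tmp)     # new line: the copy shares the saved tensors
        self.saved_tensors = tmp                  # restore them on the original
        return result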
@@ -77,13 +77,19 @@ class QATModule(Module):
        r"""
        Get weight's quantization dtype as the method from ``qconfig``.
        """
        return self.weight_observer.get_dtype()
        if hasattr(self.weight_fake_quant, "get_dtype"):
            return self.weight_fake_quant.get_dtype()
        else:
            return self.weight_observer.get_dtype()

    def get_activation_dtype(self):
        r"""
        Get activation's quantization dtype as the method from ``qconfig``.
        """
        return self.act_observer.get_dtype()
        if hasattr(self.act_fake_quant, "get_dtype"):
            return self.act_fake_quant.get_dtype()
        else:
            return self.act_observer.get_dtype()

    @classmethod
    @abstractmethod
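The effect of this hunk: if the configured fake quantizer exposes `get_dtype` (TQT below does, via its learned scale), its dtype wins over the observer's statistics; a plain FakeQuantize has no `get_dtype`, so the old observer path is kept. A hypothetical illustration, assuming `m` is a QAT module wired up with the new `tqt_quant_qconfig`:

    w_dtype = m.get_weight_dtype()      # TQT defines get_dtype(), so this comes from weight_fake_quant
    a_dtype = m.get_activation_dtype()  # likewise from act_fake_quant
    # with a FakeQuantize-based qconfig both calls fall back to the observers, as before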
@@ -12,4 +12,5 @@ from .qconfig import (
    calibration_qconfig,
    ema_fakequant_qconfig,
    min_max_fakequant_qconfig,
    tqt_quant_qconfig,
)
@@ -5,17 +5,20 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import copy
import math

import numpy as np

from .. import functional as F
from .._internal.dtype import _metadata_dict
from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, Parameter
from ..jit import sideeffect
from ..module import Module
from .observer import ObserverMode, Round


class FakeQuantize(Module):
    r"""
    A module to do quant and dequant according to observer's scale and zero_point.
    """

class _FakeQuantize(Module):
    def __init__(self, dtype: str, enable: bool = True):
        super().__init__()
        if not dtype in _metadata_dict.keys():
@@ -35,25 +38,103 @@ class FakeQuantize(Module):
    def disable(self):
        self.enabled = False

    def fake_quant_forward(self, inp, q_dict):
        return inp

    def normal_foward(self, inp, q_dict):
        return inp

    def forward(self, inp, q_dict):
        if self.enabled:
            if q_dict["mode"] == ObserverMode.SYMMERTIC:
                scale = q_dict["scale"]
                # Quant
                oup = Round()(inp / scale)
                # clip
                oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
                # DeQuant
                oup = (oup) * scale
                return oup
            else:
                scale = q_dict["scale"]
                zero_point = q_dict["zero_point"]
                # Quant
                oup = Round()(inp / scale) + zero_point
                # clip
                oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
                # DeQuant
                oup = (oup - zero_point) * scale
                return oup
            return self.fake_quant_forward(inp, q_dict)
        else:
            return self.normal_foward(inp, q_dict)
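`forward` is now a template method: when the module is enabled it delegates to `fake_quant_forward`, otherwise to `normal_foward`, and subclasses only override those two hooks. A minimal hypothetical subclass, just to illustrate the contract (not part of this diff):

    class NoopFakeQuantize(_FakeQuantize):
        # hypothetical example: passes data through on both paths
        def fake_quant_forward(self, inp, q_dict):
            # runs when self.enabled is True
            return inp

        def normal_foward(self, inp, q_dict):
            # runs when self.enabled is False
            return inp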
class TQT_Function(Function):
    def __init__(self, lowerbound, upperbound):
        super().__init__()
        self.lowerbound = lowerbound
        self.upperbound = upperbound

    def forward(self, inp, scale):
        t = 2 ** scale
        # t = F.maximum(t, 1e-4)
        inp_scaled = inp / t
        inp_clipped = F.maximum(F.minimum(inp_scaled, self.upperbound), self.lowerbound)
        inp_rounded = F.round(inp_clipped)
        inp_flq = inp_rounded * t
        self.save_for_backward(inp_scaled, inp_rounded, t)
        return inp_flq

    def backward(self, grad_inp_flq):
        (inp_scaled, inp_rounded, t) = self.saved_tensors
        mask_clip = (inp_scaled < -0.5 + self.lowerbound) + (
            inp_scaled > self.upperbound + 0.5
        )  # mask for accumulating the gradients of |data_scaled|>L
        mask_quant = F.abs(
            mask_clip - 1
        )  # mask for accumulating the gradients with |data_scaled|<=L
        grad_quant = (
            grad_inp_flq * mask_quant * (inp_rounded - inp_scaled)
        )  # gradient within |data_scaled|<=L
        grad_clip = (
            grad_inp_flq * mask_clip * inp_rounded
        )  # gradient with |data_scaled|>L
        grad_s = grad_clip.sum() + grad_quant.sum()
        # dL/ds = dL/dt * t * ln(2)
        grad_s = grad_s * t * math.log(2)
        grad_inp = grad_inp_flq * mask_quant
        return grad_inp, grad_s
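The scale gradient follows from the chain rule on t = 2 ** scale. Writing q = round(clip(x / t, qmin, qmax)) and x_fq = q * t, the straight-through estimate implemented above is

    d x_fq / d t = q - x / t   for inputs inside the clip range (the grad_quant term)
    d x_fq / d t = q           for clipped inputs, where q is qmin or qmax (the grad_clip term)

and since d t / d s = 2 ** s * ln(2) = t * ln(2), the final step is dL/ds = dL/dt * t * ln(2), which is the `grad_s = grad_s * t * math.log(2)` line. The extra 0.5 of slack in `mask_clip` keeps values that merely round onto the boundary in the quantization region rather than the clipping region.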
class TQT(_FakeQuantize):
    """
    TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds
    for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks
    """

    def __init__(self, dtype: str, enable: bool = True):
        super().__init__(dtype, enable)
        self.scale = Parameter(0.0, dtype=np.float32)

    def fake_quant_forward(self, inp, q_dict):
        # when enabled, TQT does the fake-quant forward and finetunes the scale
        return TQT_Function(self.qmin, self.qmax)(inp, self.scale)

    def normal_foward(self, inp, q_dict):
        # when disabled, TQT does a normal forward and initializes the scale parameter
        tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
        tmp_scale = F.log(tmp_scale / 127) / F.log(2)
        F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0)
        return inp

    def get_dtype(self):
        return get_quantized_dtype(self.dtype, 2 ** self.scale.numpy()[0], None)
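A quick numeric check of the initialization path in `normal_foward` (numbers are illustrative only): if the observer reports min_val = -0.8 and max_val = 1.0, then tmp_scale = log2(1.0 / 127) ≈ -6.99, so the threshold starts at t = 2 ** scale ≈ 1/127 and the largest observed magnitude lands exactly on 127. A rough numpy equivalent, assuming `F.add_update(x, y, alpha=0.0, beta=1.0, bias=0.0)` amounts to overwriting x with y:

    import numpy as np

    min_val, max_val = -0.8, 1.0                       # assumed observer statistics
    tmp_scale = max(abs(min_val), abs(max_val))        # 1.0
    tmp_scale = np.log(tmp_scale / 127) / np.log(2)    # log2(1/127), about -6.99
    t = 2.0 ** tmp_scale                               # about 0.00787 == 1/127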
class FakeQuantize(_FakeQuantize):
    r"""
    A module to do quant and dequant according to observer's scale and zero_point.
    """

    def fake_quant_forward(self, inp, q_dict):
        if q_dict["mode"] == ObserverMode.SYMMERTIC:
            scale = q_dict["scale"]
            # Quant
            oup = Round()(inp / scale)
            # clip
            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
            # DeQuant
            oup = (oup) * scale
            return oup
        else:
            scale = q_dict["scale"]
            zero_point = q_dict["zero_point"]
            # Quant
            oup = Round()(inp / scale) + zero_point
            # clip
            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
            # DeQuant
            oup = (oup - zero_point) * scale
            return oup
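A small worked example of the symmetric branch (illustrative values, not from the diff): with scale = 0.1, qmin = -128 and qmax = 127, an input of 0.537 quantizes to round(0.537 / 0.1) = 5, survives the clip, and dequantizes to 5 * 0.1 = 0.5, so the tensor stays float but carries the rounding error of real int8 quantization; an input of 20.0 would round to 200, clip to 127, and come back as 12.7.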
@@ -107,6 +107,8 @@ class MinMaxObserver(Observer):
        min_val = F.minimum(0.0, inp_min_val)
        max_val = F.maximum(0.0, inp_max_val)
        q_dict = create_observer_dict(self.mode)
        q_dict["min_val"] = inp_min_val
        q_dict["max_val"] = inp_max_val
        if self.mode == ObserverMode.SYMMERTIC:
            symmetric_max_vals = F.maximum(-min_val, max_val)
            # use maximum to avoid the scale being too small at the beginning
@@ -1,12 +1,12 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ..module import Module
from .fake_quant import FakeQuantize
from .fake_quant import TQT, FakeQuantize
from .observer import (
    ExponentialMovingAverageObserver,
    HistogramObserver,
@@ -52,6 +52,12 @@ class QConfig:
        self.fake_quant = fake_quant


tqt_quant_qconfig = QConfig(
    weight_observer=ExponentialMovingAverageObserver,
    act_observer=ExponentialMovingAverageObserver,
    fake_quant=TQT,
)

# Default QAT QConfigs
min_max_fakequant_qconfig = QConfig(
    weight_observer=MinMaxObserver,
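How the new qconfig would be consumed (hedged, since the wiring is outside this diff): a QAT module configured with `tqt_quant_qconfig` builds ExponentialMovingAverageObserver instances plus TQT fake quantizers, so a plausible flow is sketched below; `set_qconfig` and the enable/disable phases are assumptions about the surrounding API, not part of this change.

    from megengine.quantization import tqt_quant_qconfig

    qat_model.set_qconfig(tqt_quant_qconfig)   # assumed QATModule hook that instantiates observers + TQT
    # phase 1: run some batches with fake quant disabled, so TQT.normal_foward seeds scale from min/max
    # phase 2: enable fake quant, so TQT_Function finetunes scale during training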
@@ -96,7 +96,6 @@ def test_deepcopy():
    origin = Sigmoid(0)
    new = copy.deepcopy(Sigmoid(0))
    assert new.param == origin.param
    assert new.saved_tensors == None


def test_save_context():
@@ -0,0 +1,77 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest

import megengine as mge
import megengine._internal as mgb
from megengine.core import tensor
from megengine.quantization.fake_quant import TQT_Function
from megengine.test import assertTensorClose
class numpy_TQT_Function:
    def __init__(self, lowerbound, upperbound):
        super().__init__()
        self.lowerbound = lowerbound
        self.upperbound = upperbound

    def forward(self, inp, scale):
        t = 2 ** scale
        # t = F.maximum(t, 1e-4)
        inp_scaled = inp / t
        inp_clipped = np.maximum(
            np.minimum(inp_scaled, self.upperbound), self.lowerbound
        )
        inp_rounded = np.round(inp_clipped)
        inp_flq = inp_rounded * t
        self.saved_tensors = (inp_scaled, inp_rounded, t)
        return inp_flq

    def backward(self, grad_inp_flq):
        (inp_scaled, inp_rounded, t) = self.saved_tensors
        mask_clip = (inp_scaled < -0.5 + self.lowerbound) + (
            inp_scaled > self.upperbound + 0.5
        )  # mask for accumulating the gradients of |data_scaled|>L
        mask_quant = np.abs(
            mask_clip - 1
        )  # mask for accumulating the gradients with |data_scaled|<=L
        grad_quant = (
            grad_inp_flq * mask_quant * (inp_rounded - inp_scaled)
        )  # gradient within |data_scaled|<=L
        grad_clip = (
            grad_inp_flq * mask_clip * inp_rounded
        )  # gradient with |data_scaled|>L
        grad_s = grad_clip.sum() + grad_quant.sum()
        # dL/ds = dL/dt * t * ln(2)
        grad_s = grad_s * t * np.log(2)
        grad_inp = grad_inp_flq * mask_quant
        return grad_inp, grad_s
def test_TQT():
    f = TQT_Function(-127, 127)
    nf = numpy_TQT_Function(-127, 127)

    def check_inp(a, b, c, a_np, b_np, c_np):
        assertTensorClose(
            f.forward(a, b).numpy(), nf.forward(a_np, b_np).astype("float32")
        )
        c1, c2 = f.backward(c)
        c1_np, c2_np = nf.backward(c_np)
        assertTensorClose(c1.numpy(), c1_np.astype("float32"))
        assertTensorClose(c2.numpy(), c2_np.astype("float32"))

    a = tensor()
    b = tensor()
    a_np = np.random.random((4, 3)).astype("float32")
    b_np = np.random.random((1)).astype("float32")
    a.set_value(a_np)
    b.set_value(b_np)
    check_inp(a, b, b, a_np, b_np, b_np)