@@ -154,6 +154,7 @@ class Function(metaclass=ABCMeta):
         memo[id(self)] = result
         for k, v in self.__dict__.items():
             setattr(result, k, copy.deepcopy(v, memo))
+        setattr(result, "saved_tensors", tmp)
         self.saved_tensors = tmp
         return result
@@ -77,13 +77,19 @@ class QATModule(Module):
         r"""
         Get weight's quantization dtype as the method from ``qconfig``.
         """
-        return self.weight_observer.get_dtype()
+        if hasattr(self.weight_fake_quant, "get_dtype"):
+            return self.weight_fake_quant.get_dtype()
+        else:
+            return self.weight_observer.get_dtype()
     def get_activation_dtype(self):
         r"""
         Get activation's quantization dtype as the method from ``qconfig``.
         """
-        return self.act_observer.get_dtype()
+        if hasattr(self.act_fake_quant, "get_dtype"):
+            return self.act_fake_quant.get_dtype()
+        else:
+            return self.act_observer.get_dtype()

     @classmethod
     @abstractmethod
@@ -12,4 +12,5 @@ from .qconfig import (
     calibration_qconfig,
     ema_fakequant_qconfig,
     min_max_fakequant_qconfig,
+    tqt_quant_qconfig,
 )
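With this re-export, the new qconfig becomes available from the package root next to the existing defaults. A minimal illustration, using only names visible in this diff:

    # choose the TQT-based qconfig instead of the plain min-max default
    from megengine.quantization import min_max_fakequant_qconfig, tqt_quant_qconfig

    qconfig = tqt_quant_qconfig  # or min_max_fakequant_qconfig for ordinary QAT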
@@ -5,17 +5,20 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import copy
+import math
+import numpy as np
 from .. import functional as F
-from .._internal.dtype import _metadata_dict
+from .._internal.dtype import _metadata_dict, get_quantized_dtype
+from ..core import Buffer, Function, Parameter
+from ..jit import sideeffect
 from ..module import Module
 from .observer import ObserverMode, Round

-class FakeQuantize(Module):
-    r"""
-    A module to do quant and dequant according to observer's scale and zero_point.
-    """
+class _FakeQuantize(Module):
     def __init__(self, dtype: str, enable: bool = True):
         super().__init__()
         if not dtype in _metadata_dict.keys():
@@ -35,25 +38,103 @@ class FakeQuantize(Module):
     def disable(self):
         self.enabled = False

+    def fake_quant_forward(self, inp, q_dict):
+        return inp
+
+    def normal_forward(self, inp, q_dict):
+        return inp
+
     def forward(self, inp, q_dict):
         if self.enabled:
-            if q_dict["mode"] == ObserverMode.SYMMERTIC:
-                scale = q_dict["scale"]
-                # Quant
-                oup = Round()(inp / scale)
-                # clip
-                oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-                # DeQuant
-                oup = (oup) * scale
-                return oup
-            else:
-                scale = q_dict["scale"]
-                zero_point = q_dict["zero_point"]
-                # Quant
-                oup = Round()(inp / scale) + zero_point
-                # clip
-                oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-                # DeQuant
-                oup = (oup - zero_point) * scale
-                return oup
+            return self.fake_quant_forward(inp, q_dict)
+        else:
+            return self.normal_forward(inp, q_dict)
+
+
+class TQT_Function(Function):
+    def __init__(self, lowerbound, upperbound):
+        super().__init__()
+        self.lowerbound = lowerbound
+        self.upperbound = upperbound
+
+    def forward(self, inp, scale):
+        t = 2 ** scale
+        # t = F.maximum(t, 1e-4)
+        inp_scaled = inp / t
+        inp_clipped = F.maximum(F.minimum(inp_scaled, self.upperbound), self.lowerbound)
+        inp_rounded = F.round(inp_clipped)
+        inp_flq = inp_rounded * t
+        self.save_for_backward(inp_scaled, inp_rounded, t)
+        return inp_flq
+
+    def backward(self, grad_inp_flq):
+        (inp_scaled, inp_rounded, t) = self.saved_tensors
+        mask_clip = (inp_scaled < -0.5 + self.lowerbound) + (
+            inp_scaled > self.upperbound + 0.5
+        )  # mask for accumulating the gradients of |data_scaled|>L
+        mask_quant = F.abs(
+            mask_clip - 1
+        )  # mask for accumulating the gradients with |data_scaled|<=L
+        grad_quant = (
+            grad_inp_flq * mask_quant * (inp_rounded - inp_scaled)
+        )  # gradient within |data_scaled|<=L
+        grad_clip = (
+            grad_inp_flq * mask_clip * inp_rounded
+        )  # gradient with |data_scaled|>L
+        grad_s = grad_clip.sum() + grad_quant.sum()
+        # dL/ds = dL/dt * t * ln(2)
+        grad_s = grad_s * t * math.log(2)
+        grad_inp = grad_inp_flq * mask_quant
+        return grad_inp, grad_s
+
+
+class TQT(_FakeQuantize):
+    """
+    TQT: "Trained Quantization Thresholds for Accurate and Efficient
+    Fixed-Point Inference of Deep Neural Networks",
+    https://arxiv.org/abs/1903.08066
+    """
+
+    def __init__(self, dtype: str, enable: bool = True):
+        super().__init__(dtype, enable)
+        self.scale = Parameter(0.0, dtype=np.float32)
+
+    def fake_quant_forward(self, inp, q_dict):
+        # when enabled, TQT does the fake-quant forward and finetunes the scale
+        return TQT_Function(self.qmin, self.qmax)(inp, self.scale)
+
+    def normal_forward(self, inp, q_dict):
+        # when disabled, TQT does a normal forward and initializes the scale parameter
+        tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
+        tmp_scale = F.log(tmp_scale / 127) / F.log(2)
+        F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0)
         return inp
+
+    def get_dtype(self):
+        return get_quantized_dtype(self.dtype, 2 ** self.scale.numpy()[0], None)
+
+
+class FakeQuantize(_FakeQuantize):
+    r"""
+    A module to do quant and dequant according to observer's scale and zero_point.
+    """
+
+    def fake_quant_forward(self, inp, q_dict):
+        if q_dict["mode"] == ObserverMode.SYMMERTIC:
+            scale = q_dict["scale"]
+            # Quant
+            oup = Round()(inp / scale)
+            # clip
+            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
+            # DeQuant
+            oup = oup * scale
+            return oup
+        else:
+            scale = q_dict["scale"]
+            zero_point = q_dict["zero_point"]
+            # Quant
+            oup = Round()(inp / scale) + zero_point
+            # clip
+            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
+            # DeQuant
+            oup = (oup - zero_point) * scale
+            return oup
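To make the gradient logic in TQT_Function easier to follow, here is the math it implements, restated from the code and its inline dL/ds comment (a paraphrase of this diff, not an excerpt from the paper). With step size t = 2**s and clip range [qmin, qmax], widened by +-0.5 when deciding which elements count as clipped, the forward and backward passes are:

    \hat{x} \;=\; t \cdot \mathrm{round}\!\left(\mathrm{clip}\!\left(\tfrac{x}{t},\; q_{\min},\; q_{\max}\right)\right),
    \qquad t = 2^{s}

    \frac{\partial L}{\partial x_i} \;=\; \frac{\partial L}{\partial \hat{x}_i}\,\mathbf{1}\!\left\{\tfrac{x_i}{t}\ \text{in range}\right\},
    \qquad
    \frac{\partial L}{\partial s} \;=\; t \ln 2 \;\sum_i \frac{\partial L}{\partial \hat{x}_i} \cdot
    \begin{cases}
    \mathrm{round}(x_i/t) - x_i/t, & x_i/t\ \text{in range} \\
    \mathrm{round}\!\big(\mathrm{clip}(x_i/t)\big), & \text{otherwise (clipped)}
    \end{cases}

In-range elements pass the incoming gradient straight through to the input and contribute the rounding error to the scale gradient; clipped elements contribute the clipped integer value; the factor t ln 2 comes from the chain rule through t = 2**s.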
@@ -107,6 +107,8 @@ class MinMaxObserver(Observer):
         min_val = F.minimum(0.0, inp_min_val)
         max_val = F.maximum(0.0, inp_max_val)
         q_dict = create_observer_dict(self.mode)
+        q_dict["min_val"] = inp_min_val
+        q_dict["max_val"] = inp_max_val
         if self.mode == ObserverMode.SYMMERTIC:
             symmetric_max_vals = F.maximum(-min_val, max_val)
             # use maximun to avoid scale too small at the begin
@@ -1,12 +1,12 @@
 # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 #
 # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 #
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ..module import Module
-from .fake_quant import FakeQuantize
+from .fake_quant import TQT, FakeQuantize
 from .observer import (
     ExponentialMovingAverageObserver,
     HistogramObserver,
@@ -52,6 +52,12 @@ class QConfig:
         self.fake_quant = fake_quant


+tqt_quant_qconfig = QConfig(
+    weight_observer=ExponentialMovingAverageObserver,
+    act_observer=ExponentialMovingAverageObserver,
+    fake_quant=TQT,
+)
+
 # Default QAT QConfigs
 min_max_fakequant_qconfig = QConfig(
     weight_observer=MinMaxObserver,
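For context, a hypothetical sketch of how tqt_quant_qconfig would be consumed in a QAT flow. The toy Net and the quantize_qat entry point (and its keyword argument) are assumptions for illustration only; this diff itself only defines the qconfig:

    import megengine.functional as F
    from megengine.module import Conv2d, Module
    from megengine.quantization import tqt_quant_qconfig
    from megengine.quantization.quantize import quantize_qat  # assumed entry point, not part of this diff


    class Net(Module):  # hypothetical toy model
        def __init__(self):
            super().__init__()
            self.conv = Conv2d(3, 8, 3)

        def forward(self, x):
            return F.relu(self.conv(x))


    net = Net()
    # swap float modules for their QAT counterparts, attaching the EMA observers
    # and TQT fake-quant specified by the qconfig (hypothetical call)
    qat_net = quantize_qat(net, qconfig=tqt_quant_qconfig)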
@@ -96,7 +96,6 @@ def test_deepcopy():
     origin = Sigmoid(0)
     new = copy.deepcopy(Sigmoid(0))
     assert new.param == origin.param
-    assert new.saved_tensors == None


 def test_save_context():
@@ -0,0 +1,77 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import numpy as np
+import pytest
+
+import megengine as mge
+import megengine._internal as mgb
+from megengine.core import tensor
+from megengine.quantization.fake_quant import TQT_Function
+from megengine.test import assertTensorClose
+
+
+class numpy_TQT_Function:
+    def __init__(self, lowerbound, upperbound):
+        super().__init__()
+        self.lowerbound = lowerbound
+        self.upperbound = upperbound
+
+    def forward(self, inp, scale):
+        t = 2 ** scale
+        # t = F.maximum(t, 1e-4)
+        inp_scaled = inp / t
+        inp_clipped = np.maximum(
+            np.minimum(inp_scaled, self.upperbound), self.lowerbound
+        )
+        inp_rounded = np.round(inp_clipped)
+        inp_flq = inp_rounded * t
+        self.saved_tensors = (inp_scaled, inp_rounded, t)
+        return inp_flq
+
+    def backward(self, grad_inp_flq):
+        (inp_scaled, inp_rounded, t) = self.saved_tensors
+        mask_clip = (inp_scaled < -0.5 + self.lowerbound) + (
+            inp_scaled > self.upperbound + 0.5
+        )  # mask for accumulating the gradients of |data_scaled|>L
+        mask_quant = np.abs(
+            mask_clip.astype("float32") - 1
+        )  # mask for accumulating the gradients with |data_scaled|<=L
+        grad_quant = (
+            grad_inp_flq * mask_quant * (inp_rounded - inp_scaled)
+        )  # gradient within |data_scaled|<=L
+        grad_clip = (
+            grad_inp_flq * mask_clip * inp_rounded
+        )  # gradient with |data_scaled|>L
+        grad_s = grad_clip.sum() + grad_quant.sum()
+        # dL/ds = dL/dt * t * ln(2)
+        grad_s = grad_s * t * np.log(2)
+        grad_inp = grad_inp_flq * mask_quant
+        return grad_inp, grad_s
+
+
+def test_TQT():
+    f = TQT_Function(-127, 127)
+    nf = numpy_TQT_Function(-127, 127)
+
+    def check_inp(a, b, c, a_np, b_np, c_np):
+        assertTensorClose(
+            f.forward(a, b).numpy(), nf.forward(a_np, b_np).astype("float32")
+        )
+        c1, c2 = f.backward(c)
+        c1_np, c2_np = nf.backward(c_np)
+        assertTensorClose(c1.numpy(), c1_np.astype("float32"))
+        assertTensorClose(c2.numpy(), c2_np.astype("float32"))
+
+    a = tensor()
+    b = tensor()
+    a_np = np.random.random((4, 3)).astype("float32")
+    b_np = np.random.random((1)).astype("float32")
+    a.set_value(a_np)
+    b.set_value(b_np)
+    check_inp(a, b, b, a_np, b_np, b_np)