GitOrigin-RevId: a5e953b3fa

@@ -138,6 +138,7 @@ class Tensor:
    def __init__(self, val=None, *, requires_grad=None):
        self._reset(val, requires_grad=requires_grad)
        self.q_dict = {"mode": None, "scale": None, "zero_point": None}

    def _reset(self, val=None, *, requires_grad=None):
        self.__sym_override = None

@@ -677,9 +678,9 @@ class Tensor:
    def __deepcopy__(self, memo):
        """
        Since Tensor have __getstate__ and __setstate__ method,
        deepcopy only process the that and ignore the attribute of Parameter.
        So we need to add __deepcopy__ method to deepcopy correct attribute.
        The default deepcopy will ignore other attributes except those defined in
        the __getstate__ and __setstate__ methods.
        So we need to add a __deepcopy__ method to deepcopy the correct attributes.
        """
        assert (self.__val is not None) and (
            self.__sym is None

@@ -6,6 +6,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ... import functional as F
from ...quantization.utils import fake_quant_bias
from .. import conv as Float
from .module import QATModule

@@ -18,7 +19,8 @@ class Conv2d(Float.Conv2d, QATModule):
    def calc_conv_qat(self, inp):
        w_qat = self.apply_quant_weight(self.weight)
        conv = self.calc_conv(inp, w_qat, self.bias)
        b_qat = fake_quant_bias(self.bias, inp, w_qat)
        conv = self.calc_conv(inp, w_qat, b_qat)
        return conv

    @classmethod

@@ -7,6 +7,7 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ...core import ones, zeros
from ...functional import add_update, relu, sqrt, sum, zero_grad
from ...quantization.utils import fake_quant_bias
from .. import conv_bn as Float
from .module import QATModule

@@ -132,7 +133,8 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
        b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
        w_qat = self.apply_quant_weight(w_fold)
        conv = self.conv.calc_conv(inp, w_qat, b_fold)
        b_qat = fake_quant_bias(b_fold, inp, w_qat)
        conv = self.conv.calc_conv(inp, w_qat, b_qat)
        if not (self.training and approx):
            return conv

@@ -5,6 +5,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ...quantization.utils import fake_quant_bias
from .. import linear as Float
from .module import QATModule

@@ -23,7 +24,8 @@ class Linear(Float.Linear, QATModule):
    def forward(self, x):
        w_qat = self.apply_quant_weight(self.weight)
        return self.apply_quant_activation(self._calc_linear(x, w_qat, self.bias),)
        b_qat = fake_quant_bias(self.bias, x, w_qat)
        return self.apply_quant_activation(self._calc_linear(x, w_qat, b_qat))

    @classmethod
    def from_float_module(cls, float_module: Float.Linear):

@@ -73,11 +73,16 @@ class QATModule(Module):
        if observer is None:
            return target
        oup = observer(target)
        if fake_quant is None:
            return oup
        else:
            q_dict = observer.get_qparams()
            return fake_quant(oup, q_dict)
        q_dict = observer.get_qparams()
        # do fake quant
        if fake_quant is not None:
            oup = fake_quant(oup, q_dict)
            # use the qparams of fake_quant if it has them.
            if hasattr(fake_quant, "get_qparams"):
                q_dict = fake_quant.get_qparams()
        # set the tensor's qparams.
        oup.q_dict.update(q_dict)
        return oup
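
Taken together with the conv/linear hunks above, the intended QAT forward pass looks roughly like the sketch below. This is an editor's paraphrase for illustration, not code from the diff; qat_forward_sketch is a hypothetical helper and module stands for a QAT conv-like module exposing weight, bias, calc_conv, apply_quant_weight and apply_quant_activation.

    def qat_forward_sketch(module, inp):
        # weight path: observe and fake-quantize the float weight
        w_qat = module.apply_quant_weight(module.weight)
        # bias path: fake-quantize with scale = input_scale * weight_scale (qint32 range)
        b_qat = fake_quant_bias(module.bias, inp, w_qat)
        # compute in float on the fake-quantized operands, then quantize the activation
        out = module.calc_conv(inp, w_qat, b_qat)
        return module.apply_quant_activation(out)
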
    def apply_quant_weight(self, target: Tensor):
        r"""

@@ -15,8 +15,7 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, Parameter
from ..jit import sideeffect
from ..module import Module
from .observer import Round
from .utils import QuantMode, get_qparam_dict
from .utils import QuantMode, Round, fake_quant_tensor, get_qparam_dict

class _FakeQuantize(Module):

@@ -143,22 +142,4 @@ class FakeQuantize(_FakeQuantize):
    """

    def fake_quant_forward(self, inp, q_dict):
        if q_dict["mode"] == QuantMode.SYMMERTIC:
            scale = q_dict["scale"]
            # Quant
            oup = Round()(inp / scale)
            # clip
            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
            # DeQuant
            oup = (oup) * scale
            return oup
        else:
            scale = q_dict["scale"]
            zero_point = q_dict["zero_point"]
            # Quant
            oup = Round()(inp / scale) + zero_point
            # clip
            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
            # DeQuant
            oup = (oup - zero_point) * scale
            return oup
        return fake_quant_tensor(inp, self.qmin, self.qmax, q_dict)

@@ -13,18 +13,10 @@ import numpy as np
from .. import functional as F
from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, tensor
from ..core import Buffer
from ..jit import sideeffect
from ..module import Module
from .utils import QuantMode, get_qparam_dict

class Round(Function):
    def forward(self, x):
        return x.round()

    def backward(self, output_grads):
        return output_grads
from .utils import QuantMode, Round, get_qparam_dict

class Observer(Module):

@@ -8,6 +8,24 @@
from enum import Enum
from functools import partial, update_wrapper, wraps
from typing import Dict

from .. import functional as F
from .._internal.dtype import _metadata_dict
from ..core import Function, Tensor

class Round(Function):
    """
    The functional round has no gradient and cannot be used for quantization-aware training.
    We use Function and STE (Straight-Through Estimator) to implement the backward propagation.
    """

    def forward(self, x):
        return x.round()

    def backward(self, output_grads):
        return output_grads
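
To illustrate the STE behaviour described in the docstring, here is a small self-contained NumPy sketch (editor's example, not part of the diff; it mimics what Round does without MegEngine's autodiff machinery):

    import numpy as np

    def ste_round_forward(x):
        # forward: real rounding, exactly like Round.forward
        return np.round(x)

    def ste_round_backward(grad_output):
        # backward: straight-through, the gradient of round is treated as identity
        return grad_output

    x = np.array([0.4, 1.6])
    y = ste_round_forward(x)                   # -> [0., 2.]
    gx = ste_round_backward(np.ones_like(x))   # -> [1., 1.], the incoming gradient unchanged
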

def register_method_to_class(cls):

@@ -25,6 +43,9 @@ def register_method_to_class(cls):
class QuantMode(Enum):
    """Quantization mode enumerate class.
    """

    SYMMERTIC = 1
    ASYMMERTIC = 2
    TQT = 3

@@ -41,5 +62,55 @@ qparam_dict = {
}

def get_qparam_dict(mode):
def get_qparam_dict(mode: QuantMode):
"""Return the quantization parameters dictory according to the mode. | |||
""" | |||
return qparam_dict.get(mode, None) | |||

def fake_quant_tensor(inp: Tensor, qmin: int, qmax: int, q_dict: Dict) -> Tensor:
    """Apply fake quantization to the inp tensor.

    :param inp: the input tensor which needs to be fake-quantized.
    :param qmin: the lower bound of the quantized integer range.
    :param qmax: the upper bound of the quantized integer range.
    :param q_dict: the quantization parameter dict.
    """
scale = q_dict["scale"] | |||
zero_point = 0 | |||
if q_dict["mode"] == QuantMode.ASYMMERTIC: | |||
zero_point = q_dict["zero_point"] | |||
# Quant | |||
oup = Round()(inp / scale) + zero_point | |||
# Clip | |||
oup = F.minimum(F.maximum(oup, qmin), qmax) | |||
# Dequant | |||
oup = (oup - zero_point) * scale | |||
return oup | |||
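
To make the quantize-clip-dequantize steps above concrete, here is a NumPy rendering of the same arithmetic (editor's sketch; the scale and the int8-like range are invented for illustration, not taken from the diff):

    import numpy as np

    scale, zero_point = 0.5, 0                 # symmetric mode keeps zero_point at 0
    qmin, qmax = -128, 127                     # e.g. an int8 range

    inp = np.array([1.3, 100.0, 0.3])
    q = np.round(inp / scale) + zero_point     # quant:   [3., 200., 1.]
    q = np.clip(q, qmin, qmax)                 # clip:    [3., 127., 1.]
    out = (q - zero_point) * scale             # dequant: [1.5, 63.5, 0.5]
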

def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
    """Apply fake quantization to bias, using a scale computed from the input
    and weight tensors; the quantized dtype is set to qint32 as well.

    :param bias: the bias tensor which needs to be fake-quantized.
    :param inp: the input tensor which contains the quantization parameters.
    :param w_qat: the weight tensor which contains the quantization parameters.

    .. warning::
        Only works for the symmetric quantization method now.
    """
    b_qat = bias
    if hasattr(inp, "q_dict") and b_qat is not None:
        if inp.q_dict["scale"] is not None and w_qat.q_dict["scale"] is not None:
            # use the same mode as the weight.
            b_dict = get_qparam_dict(w_qat.q_dict["mode"])
            b_dict["scale"] = inp.q_dict["scale"] * w_qat.q_dict["scale"]
            # TODO: add zero_point for ASYMMERTIC mode.
            qmax = _metadata_dict["qint32"].qmax
            qmin = _metadata_dict["qint32"].qmin
            b_qat = fake_quant_tensor(b_qat, qmin, qmax, b_dict)
    return b_qat
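
Why the bias scale is input_scale * weight_scale: in an integer convolution or matmul the accumulator holds sums of (q_inp * s_inp) * (q_w * s_w) terms, so a bias added in that accumulator must live on the s_inp * s_w grid and is stored as qint32. A small NumPy sketch of the effect (editor's example with invented scales, not code from the diff):

    import numpy as np

    s_inp, s_w = 0.02, 0.001
    s_bias = s_inp * s_w                       # the scale fake_quant_bias derives: 2e-05

    bias = np.array([0.007314])
    q_bias = np.clip(np.round(bias / s_bias), -(2 ** 31), 2 ** 31 - 1)  # int32-like range
    b_qat = q_bias * s_bias                    # -> [0.00732], bias snapped to the accumulator grid
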