GitOrigin-RevId: a5e953b3fa
@@ -138,6 +138,7 @@ class Tensor:
     def __init__(self, val=None, *, requires_grad=None):
         self._reset(val, requires_grad=requires_grad)
+        self.q_dict = {"mode": None, "scale": None, "zero_point": None}

     def _reset(self, val=None, *, requires_grad=None):
         self.__sym_override = None
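
The `q_dict` added to `Tensor` gives every tensor a slot for its quantization parameters; the QAT changes below fill it in and `fake_quant_bias` reads it. A minimal sketch of the intended contents (values are made up for illustration; in real use they come from an observer or fake-quant module):

# Assumed example values only; "SYMMERTIC" stands in for the QuantMode member used later in this diff.
q_dict = {
    "mode": "SYMMERTIC",
    "scale": 0.02,        # float step size: real_value ~= scale * integer_code
    "zero_point": None,   # only used by the asymmetric mode
}
print(q_dict["scale"] * 64)  # 1.28 -- the dequantized value of integer code 64
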
@@ -677,9 +678,9 @@ class Tensor:
     def __deepcopy__(self, memo):
         """
-        Since Tensor have __getstate__ and __setstate__ method,
-        deepcopy only process the that and ignore the attribute of Parameter.
-        So we need to add __deepcopy__ method to deepcopy correct attribute.
+        The default deepcopy ignores every attribute except those handled by the
+        __getstate__ and __setstate__ methods.
+        So we add a __deepcopy__ method to deep-copy the correct attributes.
         """
         assert (self.__val is not None) and (
             self.__sym is None
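
The docstring rewrite describes standard Python behaviour: when a class defines `__getstate__`/`__setstate__` but no `__deepcopy__`, `copy.deepcopy` falls back to the pickle protocol and silently drops any attribute that is not part of the pickled state. A small self-contained illustration (the `Box` class is hypothetical and unrelated to Tensor internals):

import copy

class Box:
    def __init__(self, payload):
        self.payload = payload
        self.cache = {"hits": 0}  # not included in the pickled state

    def __getstate__(self):
        return {"payload": self.payload}

    def __setstate__(self, state):
        self.payload = state["payload"]

b = copy.deepcopy(Box(3))
print(hasattr(b, "cache"))  # False: deepcopy only replayed __getstate__/__setstate__
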
@@ -6,6 +6,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ... import functional as F
+from ...quantization.utils import fake_quant_bias
 from .. import conv as Float
 from .module import QATModule
@@ -18,7 +19,8 @@ class Conv2d(Float.Conv2d, QATModule):
     def calc_conv_qat(self, inp):
         w_qat = self.apply_quant_weight(self.weight)
-        conv = self.calc_conv(inp, w_qat, self.bias)
+        b_qat = fake_quant_bias(self.bias, inp, w_qat)
+        conv = self.calc_conv(inp, w_qat, b_qat)
         return conv

     @classmethod
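
The reason the bias needs its own fake-quantization step: in integer inference the convolution accumulates products of the quantized input and the quantized weight, so the int32 accumulator, and therefore the bias added into it, lives at scale `scale_input * scale_weight`. A rough sketch of that bookkeeping with assumed per-tensor symmetric scales (plain Python, no MegEngine calls):

s_x, s_w = 0.02, 0.005     # assumed input and weight scales
q_x, q_w = 50, -12         # integer codes: x = s_x * q_x, w = s_w * q_w
acc = q_x * q_w            # one int32 accumulator term; its scale is s_x * s_w
assert abs(acc * (s_x * s_w) - (s_x * q_x) * (s_w * q_w)) < 1e-12
# hence the bias is quantized with scale s_x * s_w before being added:
q_b = round(0.1234 / (s_x * s_w))  # 1234
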
@@ -7,6 +7,7 @@
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ...core import ones, zeros
 from ...functional import add_update, relu, sqrt, sum, zero_grad
+from ...quantization.utils import fake_quant_bias
 from .. import conv_bn as Float
 from .module import QATModule
@@ -132,7 +133,8 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
         b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
         w_qat = self.apply_quant_weight(w_fold)
-        conv = self.conv.calc_conv(inp, w_qat, b_fold)
+        b_qat = fake_quant_bias(b_fold, inp, w_qat)
+        conv = self.conv.calc_conv(inp, w_qat, b_qat)

         if not (self.training and approx):
             return conv
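
For reference, the `w_fold`/`b_fold` lines implement standard BatchNorm folding: BN(conv(x, w) + b) equals conv(x, gamma * istd * w) + (beta + gamma * (b - mean) * istd) with istd = 1 / sqrt(var + eps). A quick NumPy check of that identity, using an elementwise multiply as a stand-in for the real convolution (illustration only):

import numpy as np

rng = np.random.default_rng(0)
x, w, b = rng.normal(size=8), 0.7, 0.3               # toy "conv": y = w * x + b
gamma, beta, mean, var, eps = 1.5, -0.2, 0.1, 0.9, 1e-5
istd = 1.0 / np.sqrt(var + eps)

y_bn = gamma * ((w * x + b) - mean) * istd + beta    # conv followed by BN
w_fold = gamma * istd * w
b_fold = beta + gamma * (b - mean) * istd            # same formula as in the diff
assert np.allclose(y_bn, w_fold * x + b_fold)        # folded conv gives the same output
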
@@ -5,6 +5,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from ...quantization.utils import fake_quant_bias
 from .. import linear as Float
 from .module import QATModule
@@ -23,7 +24,8 @@ class Linear(Float.Linear, QATModule):
     def forward(self, x):
         w_qat = self.apply_quant_weight(self.weight)
-        return self.apply_quant_activation(self._calc_linear(x, w_qat, self.bias),)
+        b_qat = fake_quant_bias(self.bias, x, w_qat)
+        return self.apply_quant_activation(self._calc_linear(x, w_qat, b_qat))

     @classmethod
     def from_float_module(cls, float_module: Float.Linear):
@@ -73,11 +73,16 @@ class QATModule(Module):
         if observer is None:
             return target
         oup = observer(target)
-        if fake_quant is None:
-            return oup
-        else:
-            q_dict = observer.get_qparams()
-            return fake_quant(oup, q_dict)
+        q_dict = observer.get_qparams()
+        # do fake quant
+        if fake_quant is not None:
+            oup = fake_quant(oup, q_dict)
+            # use the qparams of fake_quant if it has them.
+            if hasattr(fake_quant, "get_qparams"):
+                q_dict = fake_quant.get_qparams()
+        # attach the qparams to the output tensor.
+        oup.q_dict.update(q_dict)
+        return oup

     def apply_quant_weight(self, target: Tensor):
         r"""
@@ -15,8 +15,7 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype
 from ..core import Buffer, Function, Parameter
 from ..jit import sideeffect
 from ..module import Module
-from .observer import Round
-from .utils import QuantMode, get_qparam_dict
+from .utils import QuantMode, Round, fake_quant_tensor, get_qparam_dict


 class _FakeQuantize(Module):
@@ -143,22 +142,4 @@ class FakeQuantize(_FakeQuantize):
     """

     def fake_quant_forward(self, inp, q_dict):
-        if q_dict["mode"] == QuantMode.SYMMERTIC:
-            scale = q_dict["scale"]
-            # Quant
-            oup = Round()(inp / scale)
-            # clip
-            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-            # DeQuant
-            oup = (oup) * scale
-            return oup
-        else:
-            scale = q_dict["scale"]
-            zero_point = q_dict["zero_point"]
-            # Quant
-            oup = Round()(inp / scale) + zero_point
-            # clip
-            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-            # DeQuant
-            oup = (oup - zero_point) * scale
-            return oup
+        return fake_quant_tensor(inp, self.qmin, self.qmax, q_dict)
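
Numerically, both the removed branches and the new `fake_quant_tensor` helper do round, clip, dequantize. A NumPy sketch of the symmetric case (values and the int8-style range are assumed; this mirrors the arithmetic above but is not the MegEngine implementation):

import numpy as np

def fake_quant(x, scale, qmin=-128, qmax=127):
    q = np.clip(np.round(x / scale), qmin, qmax)  # quantize, then clip to the integer range
    return q * scale                              # dequantize back to float

x = np.array([0.37, -5.0, 0.004])
print(fake_quant(x, scale=0.1))  # approximately [0.4, -5.0, 0.0]
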
@@ -13,18 +13,10 @@ import numpy as np
 from .. import functional as F
 from .._internal.dtype import _metadata_dict, get_quantized_dtype
-from ..core import Buffer, Function, tensor
+from ..core import Buffer
 from ..jit import sideeffect
 from ..module import Module
-from .utils import QuantMode, get_qparam_dict
-
-
-class Round(Function):
-    def forward(self, x):
-        return x.round()
-
-    def backward(self, output_grads):
-        return output_grads
+from .utils import QuantMode, Round, get_qparam_dict


 class Observer(Module):
@@ -8,6 +8,24 @@
 from enum import Enum
 from functools import partial, update_wrapper, wraps
+from typing import Dict
+
+from .. import functional as F
+from .._internal.dtype import _metadata_dict
+from ..core import Function, Tensor
+
+
+class Round(Function):
+    """
+    The functional round has no gradient and cannot be used for quantization-aware training.
+    We use Function and STE (Straight-Through Estimator) to implement backward propagation.
+    """
+
+    def forward(self, x):
+        return x.round()
+
+    def backward(self, output_grads):
+        return output_grads


 def register_method_to_class(cls):
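
The relocated `Round` Function exists for the straight-through estimator: rounding has a zero gradient almost everywhere, so quantization-aware training would stall, and STE instead treats the op as the identity in the backward pass. A framework-agnostic sketch of the same trick (NumPy forward, hand-written backward; this is not MegEngine's Function API):

import numpy as np

def round_ste_forward(x):
    return np.round(x)

def round_ste_backward(grad_output):
    # straight-through: pass the incoming gradient unchanged
    # instead of the true derivative, which is 0 almost everywhere
    return grad_output

x = np.array([0.2, 1.7])
y = round_ste_forward(x)                   # [0., 2.]
gx = round_ste_backward(np.ones_like(x))   # [1., 1.] -- upstream layers keep receiving gradient
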
@@ -25,6 +43,9 @@ def register_method_to_class(cls):
 class QuantMode(Enum):
+    """Quantization mode enumeration class.
+    """
+
     SYMMERTIC = 1
     ASYMMERTIC = 2
     TQT = 3
@@ -41,5 +62,55 @@ qparam_dict = {
 }


-def get_qparam_dict(mode):
+def get_qparam_dict(mode: QuantMode):
+    """Return the quantization parameter dictionary according to the mode.
+    """
     return qparam_dict.get(mode, None)
+
+
+def fake_quant_tensor(inp: Tensor, qmin: int, qmax: int, q_dict: Dict) -> Tensor:
+    """Apply fake quantization to the inp tensor.
+
+    :param inp: the input tensor to be fake-quantized.
+    :param qmin: the lower bound of the quantized integer range.
+    :param qmax: the upper bound of the quantized integer range.
+    :param q_dict: the quantization parameter dict.
+    """
+    scale = q_dict["scale"]
+    zero_point = 0
+    if q_dict["mode"] == QuantMode.ASYMMERTIC:
+        zero_point = q_dict["zero_point"]
+    # Quant
+    oup = Round()(inp / scale) + zero_point
+    # Clip
+    oup = F.minimum(F.maximum(oup, qmin), qmax)
+    # Dequant
+    oup = (oup - zero_point) * scale
+    return oup
+
+
+def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
+    """Apply fake quantization to the bias, using the scale derived from the input
+    tensor and the weight tensor; the quantized dtype is qint32.
+
+    :param bias: the bias tensor to be fake-quantized.
+    :param inp: the input tensor which carries the input quantization parameters.
+    :param w_qat: the weight tensor which carries the weight quantization parameters.
+
+    .. warning::
+        Only works for the symmetric quantization method for now.
+    """
+    b_qat = bias
+    if hasattr(inp, "q_dict") and b_qat is not None:
+        if inp.q_dict["scale"] is not None and w_qat.q_dict["scale"] is not None:
+            # use the same mode as the weight.
+            b_dict = get_qparam_dict(w_qat.q_dict["mode"])
+            b_dict["scale"] = inp.q_dict["scale"] * w_qat.q_dict["scale"]
+            # TODO: add zero_point for ASYMMERTIC mode.
+            qmax = _metadata_dict["qint32"].qmax
+            qmin = _metadata_dict["qint32"].qmin
+            b_qat = fake_quant_tensor(b_qat, qmin, qmax, b_dict)
+    return b_qat
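
To make the new helper concrete: with an input scale of 0.02 and a weight scale of 0.005, `fake_quant_bias` fake-quantizes the bias with scale 1e-4 over the qint32 range, so the rounding error stays tiny compared with int8-style weight and activation quantization. A NumPy re-derivation of that arithmetic (scales are assumed; the real qint32 bounds come from `_metadata_dict`):

import numpy as np

s_inp, s_w = 0.02, 0.005
s_bias = s_inp * s_w                   # bias scale, exactly as fake_quant_bias computes it
qmin, qmax = -(2 ** 31), 2 ** 31 - 1   # assumed qint32-style bounds

bias = np.array([0.12345, -3.2, 0.0])
q = np.clip(np.round(bias / s_bias), qmin, qmax)
b_qat = q * s_bias                     # fake-quantized bias that is fed to the conv
print(np.max(np.abs(b_qat - bias)))    # max error is about s_bias / 2 = 5e-05
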