
feat(mge/quantization): add bias fakequant support

GitOrigin-RevId: a5e953b3fa
tags/v1.0.0-rc1
Megvii Engine Team (Xinran Xu)
parent commit 555ecea9bc
8 changed files with 99 additions and 43 deletions
  1. python_module/megengine/core/tensor.py (+4, -3)
  2. python_module/megengine/module/qat/conv.py (+3, -1)
  3. python_module/megengine/module/qat/conv_bn.py (+3, -1)
  4. python_module/megengine/module/qat/linear.py (+3, -1)
  5. python_module/megengine/module/qat/module.py (+10, -5)
  6. python_module/megengine/quantization/fake_quant.py (+2, -21)
  7. python_module/megengine/quantization/observer.py (+2, -10)
  8. python_module/megengine/quantization/utils.py (+72, -1)

python_module/megengine/core/tensor.py (+4, -3)

@@ -138,6 +138,7 @@ class Tensor:

     def __init__(self, val=None, *, requires_grad=None):
         self._reset(val, requires_grad=requires_grad)
+        self.q_dict = {"mode": None, "scale": None, "zero_point": None}

     def _reset(self, val=None, *, requires_grad=None):
         self.__sym_override = None

@@ -677,9 +678,9 @@ class Tensor:

     def __deepcopy__(self, memo):
         """
-        Since Tensor have __getstate__ and __setstate__ method,
-        deepcopy only process the that and ignore the attribute of Parameter.
-        So we need to add __deepcopy__ method to deepcopy correct attribute.
+        The default deepcopy ignores all attributes except those handled by
+        the __getstate__ and __setstate__ methods, so we add a __deepcopy__
+        method to copy the remaining attributes correctly.
         """
         assert (self.__val is not None) and (
             self.__sym is None

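The new q_dict attribute is what lets quantization parameters travel with a tensor between modules. A minimal standalone sketch of the idea (plain Python stand-in, not the MegEngine Tensor API; the values are illustrative):

    class TensorLike:
        """Stand-in for a Tensor: carries data plus quantization metadata."""

        def __init__(self, data):
            self.data = data
            # same default metadata as the attribute added in this commit
            self.q_dict = {"mode": None, "scale": None, "zero_point": None}

    act = TensorLike([0.5, -1.2])
    # an observer/fake-quant pair would stamp the tensor like this:
    act.q_dict.update({"mode": "symmetric", "scale": 0.02})
    # later, fake_quant_bias reads act.q_dict["scale"] to derive the bias scale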

python_module/megengine/module/qat/conv.py (+3, -1)

@@ -6,6 +6,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ... import functional as F
+from ...quantization.utils import fake_quant_bias
 from .. import conv as Float
 from .module import QATModule

@@ -18,7 +19,8 @@ class Conv2d(Float.Conv2d, QATModule):

     def calc_conv_qat(self, inp):
         w_qat = self.apply_quant_weight(self.weight)
-        conv = self.calc_conv(inp, w_qat, self.bias)
+        b_qat = fake_quant_bias(self.bias, inp, w_qat)
+        conv = self.calc_conv(inp, w_qat, b_qat)
         return conv

     @classmethod

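The key point in calc_conv_qat is that the bias is fake-quantized with the product of the input scale and the weight scale, into the qint32 range, so the accumulated conv output and the bias share a scale. A NumPy sketch of that arithmetic (the scales are made-up illustrative values, and the helper is a stand-in, not the MegEngine API):

    import numpy as np

    def fq(x, scale, qmin, qmax):
        # round -> clip -> dequantize (symmetric, zero_point = 0)
        return np.clip(np.round(x / scale), qmin, qmax) * scale

    inp_scale, w_scale = 0.02, 0.005   # hypothetical observer outputs
    w = np.random.randn(8, 3, 3, 3).astype("float32")
    b = np.random.randn(8).astype("float32")

    w_qat = fq(w, w_scale, -128, 127)                              # qint8 weight
    b_qat = fq(b, inp_scale * w_scale, -(2 ** 31), 2 ** 31 - 1)    # qint32 bias
    # the conv then runs on (inp, w_qat, b_qat) instead of the raw bias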

python_module/megengine/module/qat/conv_bn.py (+3, -1)

@@ -7,6 +7,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ...core import ones, zeros
 from ...functional import add_update, relu, sqrt, sum, zero_grad
+from ...quantization.utils import fake_quant_bias
 from .. import conv_bn as Float
 from .module import QATModule

@@ -132,7 +133,8 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):

         b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd

         w_qat = self.apply_quant_weight(w_fold)
-        conv = self.conv.calc_conv(inp, w_qat, b_fold)
+        b_qat = fake_quant_bias(b_fold, inp, w_qat)
+        conv = self.conv.calc_conv(inp, w_qat, b_qat)
         if not (self.training and approx):
             return conv


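For the fused conv-BN module, what gets fake-quantized is the folded bias b_fold, not the raw conv bias. A NumPy sketch of the folding step that precedes the fake_quant_bias call (toy shapes and statistics, for illustration only):

    import numpy as np

    oc, eps = 4, 1e-5
    gamma, beta = np.ones(oc), np.zeros(oc)        # BN affine parameters
    bn_mean, bn_var = np.zeros(oc), np.ones(oc)    # BN running statistics
    w = np.random.randn(oc, 3, 3, 3)
    conv_bias = np.random.randn(oc)

    bn_istd = 1.0 / np.sqrt(bn_var + eps)
    w_fold = w * (gamma * bn_istd)[:, None, None, None]   # per-channel scaling
    b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
    # w_fold goes through weight fake-quant, b_fold through fake_quant_bias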


python_module/megengine/module/qat/linear.py (+3, -1)

@@ -5,6 +5,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from ...quantization.utils import fake_quant_bias
 from .. import linear as Float
 from .module import QATModule

@@ -23,7 +24,8 @@ class Linear(Float.Linear, QATModule):

     def forward(self, x):
         w_qat = self.apply_quant_weight(self.weight)
-        return self.apply_quant_activation(self._calc_linear(x, w_qat, self.bias),)
+        b_qat = fake_quant_bias(self.bias, x, w_qat)
+        return self.apply_quant_activation(self._calc_linear(x, w_qat, b_qat))

     @classmethod
     def from_float_module(cls, float_module: Float.Linear):


python_module/megengine/module/qat/module.py (+10, -5)

@@ -73,11 +73,16 @@ class QATModule(Module):

         if observer is None:
             return target
         oup = observer(target)
-        if fake_quant is None:
-            return oup
-        else:
-            q_dict = observer.get_qparams()
-            return fake_quant(oup, q_dict)
+        q_dict = observer.get_qparams()
+        # do fake quant
+        if fake_quant is not None:
+            oup = fake_quant(oup, q_dict)
+            # use the qparams of fake_quant if it provides them.
+            if hasattr(fake_quant, "get_qparams"):
+                q_dict = fake_quant.get_qparams()
+        # stamp the qparams onto the output tensor.
+        oup.q_dict.update(q_dict)
+        return oup

     def apply_quant_weight(self, target: Tensor):
         r"""


python_module/megengine/quantization/fake_quant.py (+2, -21)

@@ -15,8 +15,7 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype
 from ..core import Buffer, Function, Parameter
 from ..jit import sideeffect
 from ..module import Module
-from .observer import Round
-from .utils import QuantMode, get_qparam_dict
+from .utils import QuantMode, Round, fake_quant_tensor, get_qparam_dict


 class _FakeQuantize(Module):

@@ -143,22 +142,4 @@ class FakeQuantize(_FakeQuantize):
     """

     def fake_quant_forward(self, inp, q_dict):
-        if q_dict["mode"] == QuantMode.SYMMERTIC:
-            scale = q_dict["scale"]
-            # Quant
-            oup = Round()(inp / scale)
-            # clip
-            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-            # DeQuant
-            oup = (oup) * scale
-            return oup
-        else:
-            scale = q_dict["scale"]
-            zero_point = q_dict["zero_point"]
-            # Quant
-            oup = Round()(inp / scale) + zero_point
-            # clip
-            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-            # DeQuant
-            oup = (oup - zero_point) * scale
-            return oup
+        return fake_quant_tensor(inp, self.qmin, self.qmax, q_dict)

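fake_quant_forward now delegates to the shared fake_quant_tensor helper. Worked through on concrete numbers (symmetric mode, scale 0.1, qint8 range), the quant/clip/dequant round trip looks like this NumPy sketch:

    import numpy as np

    scale, qmin, qmax = 0.1, -128, 127
    x = np.array([0.26, 99.0])

    q = np.clip(np.round(x / scale), qmin, qmax)   # [3., 127.] -- 99.0 saturates
    out = q * scale                                # [0.3, 12.7]
    print(out)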
python_module/megengine/quantization/observer.py (+2, -10)

@@ -13,18 +13,10 @@ import numpy as np

 from .. import functional as F
 from .._internal.dtype import _metadata_dict, get_quantized_dtype
-from ..core import Buffer, Function, tensor
+from ..core import Buffer
 from ..jit import sideeffect
 from ..module import Module
-from .utils import QuantMode, get_qparam_dict
-
-
-class Round(Function):
-    def forward(self, x):
-        return x.round()
-
-    def backward(self, output_grads):
-        return output_grads
+from .utils import QuantMode, Round, get_qparam_dict


 class Observer(Module):

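Round is moved out of observer.py so fake_quant.py and utils.py can share it. Its custom backward exists because rounding has zero gradient almost everywhere, which would stall quantization-aware training; the straight-through estimator simply forwards the incoming gradient, as in this framework-agnostic sketch:

    import numpy as np

    def round_forward(x):
        return np.round(x)          # gradient is 0 a.e. -- useless for training

    def round_backward_ste(output_grads):
        return output_grads         # STE: pretend round(x) == x in backward

    g = round_backward_ste(np.array([0.5, -1.0]))
    print(g)                        # gradients pass through unchanged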

python_module/megengine/quantization/utils.py (+72, -1)

@@ -8,6 +8,24 @@

 from enum import Enum
 from functools import partial, update_wrapper, wraps
+from typing import Dict
+
+from .. import functional as F
+from .._internal.dtype import _metadata_dict
+from ..core import Function, Tensor
+
+
+class Round(Function):
+    """
+    The functional round has no gradient and cannot be used for
+    quantization-aware training, so we use Function with STE
+    (Straight-Through Estimator) to implement backpropagation.
+    """
+
+    def forward(self, x):
+        return x.round()
+
+    def backward(self, output_grads):
+        return output_grads


 def register_method_to_class(cls):

@@ -25,6 +43,9 @@ def register_method_to_class(cls):


 class QuantMode(Enum):
+    """Quantization mode enumeration class.
+    """
+
     SYMMERTIC = 1
     ASYMMERTIC = 2
     TQT = 3
@@ -41,5 +62,55 @@ qparam_dict = {
 }


-def get_qparam_dict(mode):
+def get_qparam_dict(mode: QuantMode):
+    """Return the quantization parameter dictionary according to the mode.
+    """
     return qparam_dict.get(mode, None)
+
+
+def fake_quant_tensor(inp: Tensor, qmin: int, qmax: int, q_dict: Dict) -> Tensor:
+    """Apply fake quantization to the inp tensor.
+
+    :param inp: the input tensor to be fake-quantized.
+    :param qmin: the lower bound of the integer range.
+    :param qmax: the upper bound of the integer range.
+    :param q_dict: the quantization parameter dict.
+
+    """
+    scale = q_dict["scale"]
+    zero_point = 0
+    if q_dict["mode"] == QuantMode.ASYMMERTIC:
+        zero_point = q_dict["zero_point"]
+    # Quant
+    oup = Round()(inp / scale) + zero_point
+    # Clip
+    oup = F.minimum(F.maximum(oup, qmin), qmax)
+    # Dequant
+    oup = (oup - zero_point) * scale
+    return oup
+
+
+def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
+    """Apply fake quantization to the bias, using a scale derived from the
+    input tensor and the weight tensor; the quantized dtype is set to qint32.
+
+    :param bias: the bias tensor to be fake-quantized.
+    :param inp: the input tensor whose q_dict carries the quantization parameters.
+    :param w_qat: the weight tensor whose q_dict carries the quantization parameters.
+
+    .. warning::
+        Only works for the symmetric quantization method for now.
+
+    """
+    b_qat = bias
+    if hasattr(inp, "q_dict") and b_qat is not None:
+        if inp.q_dict["scale"] is not None and w_qat.q_dict["scale"] is not None:
+            # use the same mode as the weight.
+            b_dict = get_qparam_dict(w_qat.q_dict["mode"])
+            b_dict["scale"] = inp.q_dict["scale"] * w_qat.q_dict["scale"]
+            # TODO: add zero_point for ASYMMERTIC mode.
+            qmax = _metadata_dict["qint32"].qmax
+            qmin = _metadata_dict["qint32"].qmin
+            b_qat = fake_quant_tensor(b_qat, qmin, qmax, b_dict)
+
+    return b_qat

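Putting the pieces together: a NumPy sketch of what fake_quant_bias computes when both the input and the weight carry symmetric qparams (the scale values and the simplified q_dict dicts are illustrative stand-ins, not the MegEngine API):

    import numpy as np

    def fake_quant(x, scale, qmin, qmax, zero_point=0):
        q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)
        return (q - zero_point) * scale

    inp_q = {"mode": "symmetric", "scale": 0.02}   # stamped by apply_quant_activation
    w_q = {"mode": "symmetric", "scale": 0.005}    # stamped by apply_quant_weight

    bias = np.array([0.123456, -0.5])
    b_scale = inp_q["scale"] * w_q["scale"]        # 1e-4, the derived bias scale
    b_qat = fake_quant(bias, b_scale, -(2 ** 31), 2 ** 31 - 1)
    print(b_qat)                                   # approx. [0.1235, -0.5]

Because the qint32 range is so wide, the bias usually survives nearly unchanged; the fake quantization mainly snaps it onto the grid implied by the input and weight scales, matching what inference-time integer arithmetic will do.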