
feat(mge/quantization): add bias fakequant support

GitOrigin-RevId: a5e953b3fa
tags/v1.0.0-rc1
Megvii Engine Team (Xinran Xu) committed 4 years ago
parent commit 555ecea9bc
8 changed files with 99 additions and 43 deletions
  1. +4   -3   python_module/megengine/core/tensor.py
  2. +3   -1   python_module/megengine/module/qat/conv.py
  3. +3   -1   python_module/megengine/module/qat/conv_bn.py
  4. +3   -1   python_module/megengine/module/qat/linear.py
  5. +10  -5   python_module/megengine/module/qat/module.py
  6. +2   -21  python_module/megengine/quantization/fake_quant.py
  7. +2   -10  python_module/megengine/quantization/observer.py
  8. +72  -1   python_module/megengine/quantization/utils.py

+4 -3   python_module/megengine/core/tensor.py

@@ -138,6 +138,7 @@ class Tensor:
 
     def __init__(self, val=None, *, requires_grad=None):
         self._reset(val, requires_grad=requires_grad)
+        self.q_dict = {"mode": None, "scale": None, "zero_point": None}
 
     def _reset(self, val=None, *, requires_grad=None):
         self.__sym_override = None
@@ -677,9 +678,9 @@ class Tensor:
 
     def __deepcopy__(self, memo):
         """
-        Since Tensor have __getstate__ and __setstate__ method,
-        deepcopy only process the that and ignore the attribute of Parameter.
-        So we need to add __deepcopy__ method to deepcopy correct attribute.
+        The default deepcopy ignores any attribute not covered by the
+        __getstate__ and __setstate__ methods, so we add a __deepcopy__
+        method to copy the remaining attributes correctly.
         """
         assert (self.__val is not None) and (
             self.__sym is None

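Note: the q_dict attribute added in Tensor.__init__ above is what lets quantization parameters travel with a tensor through the QAT graph. A minimal sketch of the idea, using a plain-Python stand-in rather than the real MegEngine Tensor:

# Sketch only: a tensor-like object carrying its quantization metadata,
# mirroring the q_dict initialised in Tensor.__init__ above.
class TensorLike:
    def __init__(self, value):
        self.value = value
        # mode/scale/zero_point start unset; an observer fills them in later.
        self.q_dict = {"mode": None, "scale": None, "zero_point": None}

x = TensorLike([1.0, 2.0])
x.q_dict.update({"mode": "SYMMERTIC", "scale": 0.02})
print(x.q_dict["scale"])  # 0.02 -- downstream code such as fake_quant_bias reads this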

+3 -1   python_module/megengine/module/qat/conv.py

@@ -6,6 +6,7 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ... import functional as F
+from ...quantization.utils import fake_quant_bias
 from .. import conv as Float
 from .module import QATModule

@@ -18,7 +19,8 @@ class Conv2d(Float.Conv2d, QATModule):
 
     def calc_conv_qat(self, inp):
         w_qat = self.apply_quant_weight(self.weight)
-        conv = self.calc_conv(inp, w_qat, self.bias)
+        b_qat = fake_quant_bias(self.bias, inp, w_qat)
+        conv = self.calc_conv(inp, w_qat, b_qat)
         return conv
 
     @classmethod


+3 -1   python_module/megengine/module/qat/conv_bn.py

@@ -7,6 +7,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ...core import ones, zeros
 from ...functional import add_update, relu, sqrt, sum, zero_grad
+from ...quantization.utils import fake_quant_bias
 from .. import conv_bn as Float
 from .module import QATModule

@@ -132,7 +133,8 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
         b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
 
         w_qat = self.apply_quant_weight(w_fold)
-        conv = self.conv.calc_conv(inp, w_qat, b_fold)
+        b_qat = fake_quant_bias(b_fold, inp, w_qat)
+        conv = self.conv.calc_conv(inp, w_qat, b_qat)
         if not (self.training and approx):
             return conv



+3 -1   python_module/megengine/module/qat/linear.py

@@ -5,6 +5,7 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from ...quantization.utils import fake_quant_bias
 from .. import linear as Float
 from .module import QATModule

@@ -23,7 +24,8 @@ class Linear(Float.Linear, QATModule):
 
     def forward(self, x):
         w_qat = self.apply_quant_weight(self.weight)
-        return self.apply_quant_activation(self._calc_linear(x, w_qat, self.bias),)
+        b_qat = fake_quant_bias(self.bias, x, w_qat)
+        return self.apply_quant_activation(self._calc_linear(x, w_qat, b_qat))
 
     @classmethod
     def from_float_module(cls, float_module: Float.Linear):


+10 -5   python_module/megengine/module/qat/module.py

@@ -73,11 +73,16 @@ class QATModule(Module):
         if observer is None:
            return target
         oup = observer(target)
-        if fake_quant is None:
-            return oup
-        else:
-            q_dict = observer.get_qparams()
-            return fake_quant(oup, q_dict)
+        q_dict = observer.get_qparams()
+        # do fake quant
+        if fake_quant is not None:
+            oup = fake_quant(oup, q_dict)
+            # prefer the qparams of fake_quant if it provides them.
+            if hasattr(fake_quant, "get_qparams"):
+                q_dict = fake_quant.get_qparams()
+        # record the qparams on the output tensor.
+        oup.q_dict.update(q_dict)
+        return oup
 
     def apply_quant_weight(self, target: Tensor):
         r"""

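Note: the hasattr(fake_quant, "get_qparams") branch above lets a fake-quant module that learns its own parameters (e.g. TQT) take precedence over the observer. A distilled sketch of that precedence rule, with stub classes standing in for the real modules:

# Stubs standing in for an Observer and a TQT-style fake quantizer.
class ObserverStub:
    def get_qparams(self):
        return {"mode": "SYMMERTIC", "scale": 0.02}   # measured scale


class TQTStub:
    def get_qparams(self):
        return {"mode": "TQT", "scale": 0.017}        # learned scale


observer, fake_quant = ObserverStub(), TQTStub()
q_dict = observer.get_qparams()
if hasattr(fake_quant, "get_qparams"):
    q_dict = fake_quant.get_qparams()  # learned qparams win over observed ones
print(q_dict)  # {'mode': 'TQT', 'scale': 0.017}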

+2 -21   python_module/megengine/quantization/fake_quant.py

@@ -15,8 +15,7 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype
 from ..core import Buffer, Function, Parameter
 from ..jit import sideeffect
 from ..module import Module
-from .observer import Round
-from .utils import QuantMode, get_qparam_dict
+from .utils import QuantMode, Round, fake_quant_tensor, get_qparam_dict
 
 
 class _FakeQuantize(Module):
@@ -143,22 +142,4 @@ class FakeQuantize(_FakeQuantize):
     """
 
     def fake_quant_forward(self, inp, q_dict):
-        if q_dict["mode"] == QuantMode.SYMMERTIC:
-            scale = q_dict["scale"]
-            # Quant
-            oup = Round()(inp / scale)
-            # clip
-            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-            # DeQuant
-            oup = (oup) * scale
-            return oup
-        else:
-            scale = q_dict["scale"]
-            zero_point = q_dict["zero_point"]
-            # Quant
-            oup = Round()(inp / scale) + zero_point
-            # clip
-            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
-            # DeQuant
-            oup = (oup - zero_point) * scale
-            return oup
+        return fake_quant_tensor(inp, self.qmin, self.qmax, q_dict)

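Note: the two removed branches differed only in zero_point, so the symmetric case is just zero_point = 0 in the shared fake_quant_tensor routine. A quick numpy check of the unified quant-clip-dequant arithmetic (scales and ranges are made up for illustration):

import numpy as np

def fq(x, scale, zero_point, qmin, qmax):
    # quant -> clip -> dequant, the same three steps as fake_quant_tensor
    q = np.clip(np.round(x / scale) + zero_point, qmin, qmax)
    return (q - zero_point) * scale

x = np.array([-1.3, 0.2, 2.7])
print(fq(x, 0.02, 0, -128, 127))  # symmetric qint8:   [-1.3   0.2   2.54]
print(fq(x, 0.02, 128, 0, 255))   # asymmetric quint8: [-1.3   0.2   2.54]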
+2 -10   python_module/megengine/quantization/observer.py

@@ -13,18 +13,10 @@ import numpy as np
 
 from .. import functional as F
 from .._internal.dtype import _metadata_dict, get_quantized_dtype
-from ..core import Buffer, Function, tensor
+from ..core import Buffer
 from ..jit import sideeffect
 from ..module import Module
-from .utils import QuantMode, get_qparam_dict
-
-
-class Round(Function):
-    def forward(self, x):
-        return x.round()
-
-    def backward(self, output_grads):
-        return output_grads
+from .utils import QuantMode, Round, get_qparam_dict
 
 
 class Observer(Module):


+72 -1   python_module/megengine/quantization/utils.py

@@ -8,6 +8,24 @@
 
 from enum import Enum
 from functools import partial, update_wrapper, wraps
+from typing import Dict
+
+from .. import functional as F
+from .._internal.dtype import _metadata_dict
+from ..core import Function, Tensor
+
+
+class Round(Function):
+    """
+    The functional round has no gradient and cannot be used for quantization-aware training.
+    We use Function and STE (Straight-Through Estimator) to implement backpropagation.
+    """
+
+    def forward(self, x):
+        return x.round()
+
+    def backward(self, output_grads):
+        return output_grads
 
 
 def register_method_to_class(cls):
@@ -25,6 +43,9 @@ def register_method_to_class(cls):
 
 
 class QuantMode(Enum):
+    """Quantization mode enumeration class.
+    """
+
     SYMMERTIC = 1
     ASYMMERTIC = 2
     TQT = 3
@@ -41,5 +62,55 @@ qparam_dict = {
 }
 
 
-def get_qparam_dict(mode):
+def get_qparam_dict(mode: QuantMode):
+    """Return the quantization parameter dictionary according to the mode.
+    """
     return qparam_dict.get(mode, None)
+
+
+def fake_quant_tensor(inp: Tensor, qmin: int, qmax: int, q_dict: Dict) -> Tensor:
+    """Apply fake quantization to the inp tensor.
+
+    :param inp: the input tensor to be fake-quantized.
+    :param qmin: the lower bound of the quantized integer range.
+    :param qmax: the upper bound of the quantized integer range.
+    :param q_dict: the quantization parameter dict.
+
+    """
+    scale = q_dict["scale"]
+    zero_point = 0
+    if q_dict["mode"] == QuantMode.ASYMMERTIC:
+        zero_point = q_dict["zero_point"]
+    # Quant
+    oup = Round()(inp / scale) + zero_point
+    # Clip
+    oup = F.minimum(F.maximum(oup, qmin), qmax)
+    # Dequant
+    oup = (oup - zero_point) * scale
+    return oup
+
+
+def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
+    """Apply fake quantization to bias, using a scale derived from the input
+    and weight tensors; the quantized dtype is likewise set to qint32.
+
+    :param bias: the bias tensor to be fake-quantized.
+    :param inp: the input tensor whose q_dict carries the input scale.
+    :param w_qat: the weight tensor whose q_dict carries the weight scale.
+
+    .. warning::
+        Only works for the symmetric quantization method now.
+
+    """
+    b_qat = bias
+    if hasattr(inp, "q_dict") and b_qat is not None:
+        if inp.q_dict["scale"] is not None and w_qat.q_dict["scale"] is not None:
+            # use the same mode as the weight.
+            b_dict = get_qparam_dict(w_qat.q_dict["mode"])
+            b_dict["scale"] = inp.q_dict["scale"] * w_qat.q_dict["scale"]
+            # TODO: add zero_point for ASYMMERTIC mode.
+            qmax = _metadata_dict["qint32"].qmax
+            qmin = _metadata_dict["qint32"].qmin
+            b_qat = fake_quant_tensor(b_qat, qmin, qmax, b_dict)
+
+    return b_qat

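Note on the scale choice: a quantized conv or linear accumulates round(x / s_x) * round(w / s_w) in an int32 accumulator, so the accumulator's implicit scale is s_x * s_w; the bias has to live on that same grid, which is why fake_quant_bias multiplies the input and weight scales and clamps to the qint32 range. A small numpy sanity check with hypothetical scales:

import numpy as np

s_x, s_w = 0.05, 0.01                 # input and weight scales (hypothetical)
s_b = s_x * s_w                       # bias shares the accumulator's scale
qmin, qmax = -(2 ** 31), 2 ** 31 - 1  # qint32 range

b = np.array([0.1234, -0.0567])
b_fq = np.clip(np.round(b / s_b), qmin, qmax) * s_b  # quant -> clip -> dequant
print(b_fq)  # [ 0.1235 -0.0565] -- error is negligible at 32-bit precision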