Browse Source

feat(mge/quantization): add calibration support

GitOrigin-RevId: f16fbba2b7
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
7c4f1a3851
8 changed files with 45 additions and 43 deletions
  1. +8
    -9
      python_module/megengine/module/module.py
  2. +3
    -8
      python_module/megengine/module/quantized/concat.py
  3. +4
    -7
      python_module/megengine/module/quantized/conv_bn_relu.py
  4. +3
    -8
      python_module/megengine/module/quantized/elemwise.py
  5. +4
    -10
      python_module/megengine/module/quantized/quant_dequant.py
  6. +1
    -0
      python_module/megengine/quantization/__init__.py
  7. +1
    -1
      python_module/megengine/quantization/observer.py
  8. +21
    -0
      python_module/megengine/quantization/quantize.py

+ 8
- 9
python_module/megengine/module/module.py View File

@@ -496,8 +496,11 @@ class QATModule(Module):
self, target: Tensor, fq: "FakeQuantize", obs: "Observer"
):
oup = self.apply_observer(target, obs)
scale, zero_point = obs.get_qparams()
return fq(oup, scale, zero_point)
if self.quantizing == self.QATMode.CALIBRATION:
return oup
else:
scale, zero_point = obs.get_qparams()
return fq(oup, scale, zero_point)

def set_qat_mode(self, mode: QATMode):
r"""
@@ -524,11 +527,7 @@ class QATModule(Module):
"""

def __call__(self, *args, **kwargs):
if self.quantizing == self.QATMode.QAT:
return self.forward_qat(*args, **kwargs)
elif self.quantizing == self.QATMode.CALIBRATION:
# TODO implement the CALIBRATION
assert False
return None
else:
if self.quantizing == self.QATMode.DISABLED:
return self.forward(*args, **kwargs)
else:
return self.forward_qat(*args, **kwargs)

+ 3
- 8
python_module/megengine/module/quantized/concat.py View File

@@ -20,11 +20,9 @@ class Concat(Module):
A :class:`~.Module` to do quantized concat, inference only.
"""

def __init__(self):
def __init__(self, dtype=None):
super().__init__()
self.scale = 1.0
self.zero_point = 0.0
self.output_dtype = mgb.dtype.qint8(self.scale)
self.output_dtype = dtype

def forward(self, inps: Iterable[Tensor], axis: int = 0):
if self.training:
@@ -39,7 +37,4 @@ def to_quantized(float_module):
Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
implemented here to avoid circular import.
"""
qmod = Concat()
qmod.output_dtype = float_module.act_observer.get_dtype()
qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
return qmod
return Concat(float_module.act_observer.get_dtype())

+ 4
- 7
python_module/megengine/module/quantized/conv_bn_relu.py View File

@@ -34,6 +34,7 @@ class _ConvBnActivation2d(Conv2d):
groups: int = 1,
conv_mode: str = "CROSS_CORRELATION",
compute_mode: str = "DEFAULT",
dtype=None,
):
super().__init__(
in_channels,
@@ -47,11 +48,7 @@ class _ConvBnActivation2d(Conv2d):
conv_mode,
compute_mode,
)
self.scale = 1.0
self.zero_point = 0.0
self.output_dtype = mgb.dtype.qint8(self.scale)
self.weight = self.weight.astype(self.output_dtype)
self.bias = self.bias.astype(mgb.dtype.qint32(self.scale))
self.output_dtype = dtype

def calc_conv_quantized(self, inp, nonlinear_mode="IDENTITY"):
inp_scale = mgb.dtype.get_scale(inp.dtype)
@@ -87,6 +84,7 @@ class ConvBnRelu2d(_ConvBnActivation2d):


def to_quantized(quantized_class, float_module):
output_dtype = float_module.act_observer.get_dtype()
qconv = quantized_class(
float_module.conv.in_channels,
float_module.conv.out_channels,
@@ -95,15 +93,14 @@ def to_quantized(quantized_class, float_module):
float_module.conv.padding,
float_module.conv.dilation,
float_module.conv.groups,
dtype=output_dtype,
)
w_fold, b_fold = float_module.fold_weight_bias(
float_module.bn.running_mean, float_module.bn.running_var
)
weight = w_fold.astype(float_module.weight_observer.get_dtype())
qconv.output_dtype = float_module.act_observer.get_dtype()
qconv.weight = Parameter(weight.numpy())
qconv.bias = Parameter(b_fold.numpy())
qconv.scale, qconv.zero_point = float_module.act_observer.get_qparams()

return qconv



+ 3
- 8
python_module/megengine/module/quantized/elemwise.py View File

@@ -34,12 +34,10 @@ class Elemwise(Module):

_elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode

def __init__(self, method):
def __init__(self, method, dtype=None):
super().__init__()
self.method = self._elemwise_multi_type_mode.convert("Q" + method)
self.scale = 1.0
self.zero_point = 0.0
self.output_dtype = mgb.dtype.qint8(self.scale)
self.output_dtype = dtype

def forward(self, *inps):
if self.training:
@@ -53,7 +51,4 @@ def to_quantized(float_module):
Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
implemented here to avoid circular import.
"""
qmod = Elemwise(float_module.method.name)
qmod.output_dtype = float_module.act_observer.get_dtype()
qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
return qmod
return Elemwise(float_module.method.name, float_module.act_observer.get_dtype())

+ 4
- 10
python_module/megengine/module/quantized/quant_dequant.py View File

@@ -16,11 +16,9 @@ class QuantStub(Module):
A helper quantize operation on input and inference only.
"""

def __init__(self):
def __init__(self, dtype=None):
super().__init__()
self.scale = 1.0
self.zero_point = 0.0
self.output_dtype = mgb.dtype.qint8(self.scale)
self.output_dtype = dtype

def forward(self, inp):
if self.training:
@@ -45,10 +43,7 @@ def to_quantized(float_module):
Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
implemented here to avoid circular import.
"""
qmod = QuantStub()
qmod.output_dtype = float_module.act_observer.get_dtype()
qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
return qmod
return QuantStub(float_module.act_observer.get_dtype())


@register_method_to_class(Float.DequantStub)
@@ -57,5 +52,4 @@ def to_quantized(float_module):
Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
implemented here to avoid circular import.
"""
qmod = DequantStub()
return qmod
return DequantStub()

+ 1
- 0
python_module/megengine/quantization/__init__.py View File

@@ -14,5 +14,6 @@ from .quantize import (
enable_fake_quant,
enable_observer,
quantize,
quantize_calibration,
quantize_qat,
)

+ 1
- 1
python_module/megengine/quantization/observer.py View File

@@ -11,7 +11,7 @@ import numpy as np

from .. import functional as F
from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, ones, tensor, zeros
from ..core import Buffer, Function, tensor
from ..module import Module




+ 21
- 0
python_module/megengine/quantization/quantize.py View File

@@ -34,6 +34,8 @@ def quantize(module: Module, inplace=True):
else:
setattr(parent, key.split(".")[-1], submodule.to_quantized())

return module


def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
r"""
@@ -53,6 +55,25 @@ def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
module.apply(fn)


def quantize_calibration(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
r"""
Recursively convert `module` to `calibration` mode through :meth:`~.Module.apply`
and set qconfig relatively.

:param module: root module to do convert recursively.
:param qconfig: a instance of :class:`~.QConfig` to be set as submodules' qconfig.
default is :any:`~.qconfig.ema_fakequant_qconfig`.
"""

def fn(mod: Module):
if isinstance(mod, QATModule):
mod.set_qat_mode(QATModule.QATMode.CALIBRATION)
mod.set_qconfig(qconfig)

module.apply(fn)
enable_observer(module)


def disable_fake_quant(module: Module):
r"""
Recursively disable `module` fake quantization in QATModule through :meth:`~.Module.apply`


Loading…
Cancel
Save