Browse Source

refactor(mge/quantization): refactor qconfig, remove inp_observer and bias_fakequant

GitOrigin-RevId: e57f9edd12
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
38a95744fb
4 changed files with 17 additions and 28 deletions
  1. +0
    -4
      python_module/megengine/module/module.py
  2. +8
    -1
      python_module/megengine/quantization/__init__.py
  3. +7
    -21
      python_module/megengine/quantization/qconfig.py
  4. +2
    -2
      python_module/megengine/quantization/quantize.py

+ 0
- 4
python_module/megengine/module/module.py View File

@@ -476,21 +476,17 @@ class QATModule(Module):
self.quantizing = self.QATMode.DISABLED
self.scale = None

self.inp_observer = None # type: Observer
self.weight_observer = None # type: Observer
self.act_observer = None # type: Observer

self.weight_fake_quant = None # type: FakeQuantize
self.bias_fake_quant = None # type: FakeQuantize
self.act_fake_quant = None # type: FakeQuantize

def set_qconfig(self, qconfig: "QConfig"):
self.inp_observer = qconfig.inp_observer()
self.weight_observer = qconfig.weight_observer()
self.act_observer = qconfig.act_observer()

self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype)
self.bias_fake_quant = qconfig.bias_fake_quant()
self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype)

def apply_observer(self, target: Tensor, obs: "Observer"):


+ 8
- 1
python_module/megengine/quantization/__init__.py View File

@@ -8,4 +8,11 @@
from .fake_quant import FakeQuantize
from .observer import Observer
from .qconfig import QConfig, ema_fakequant_qconfig, min_max_fakequant_qconfig
from .quantize import quantize, quantize_qat
from .quantize import (
disable_fake_quant,
disable_observer,
enable_fake_quant,
enable_observer,
quantize,
quantize_qat,
)

+ 7
- 21
python_module/megengine/quantization/qconfig.py View File

@@ -15,21 +15,18 @@ from .observer import ExponentialMovingAverageObserver, MinMaxObserver
class QConfig:
"""
A config class indicating how to quantize :class:`~.QATModule`'s
``activation``, ``weight`` and ``bias``.
``activation`` and ``weight``.

And the ``fake_quant`` parameter indicates how to do the fake_quant calculation.

See :meth:`~.QATModule.set_qconfig` for detail usage.

:param inp_observer: interface to instantiate an :class:`~.Observer` indicating
how to collect scales and zero_point of input.
:param weight_observer: similar to ``inp_observer`` but toward weight.
:param act_observer: similar to ``inp_observer`` but toward activation.
:param weight_observer: interface to instantiate an :class:`~.Observer` indicating
- how to collect scales and zero_point of weight.
:param act_observer: similar to ``weight_observer`` but toward activation.
:param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
how to do fake_quant calculation. Can be invoked multiple times to get a different
instance for each target tensor, for better control over enabling and disabling.
:param bias_fake_quant: similar to ``fake_quant``, but usually need to set ``dtype``
in advance, for bias's dtype is unable to be inferred from observer.

Examples:

@@ -37,21 +34,16 @@ class QConfig:

# Default EMA QConfig for QAT.
ema_fakequant_qconfig = QConfig(
inp_observer=ExponentialMovingAverageObserver,
weight_observer=ExponentialMovingAverageObserver,
weight_observer=MinMaxObserver,
act_observer=ExponentialMovingAverageObserver,
fake_quant=FakeQuantize,
)
"""

def __init__(
self, act_observer, weight_observer, inp_observer, fake_quant, bias_fake_quant,
self, act_observer, weight_observer, fake_quant,
):
if (
isinstance(act_observer, Module)
or isinstance(weight_observer, Module)
or isinstance(inp_observer, Module)
):
if isinstance(act_observer, Module) or isinstance(weight_observer, Module):
raise ValueError(
"QConfig must not receive observer instance, please pass observer"
" class generator using `partial(Observer, ...)` instead. Use"
@@ -59,24 +51,18 @@ class QConfig:
)
self.act_observer = act_observer
self.weight_observer = weight_observer
self.inp_observer = inp_observer
self.fake_quant = fake_quant
self.bias_fake_quant = bias_fake_quant


# Default QAT QConfigs
min_max_fakequant_qconfig = QConfig(
inp_observer=MinMaxObserver,
weight_observer=MinMaxObserver,
act_observer=MinMaxObserver,
fake_quant=FakeQuantize,
bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
)

ema_fakequant_qconfig = QConfig(
inp_observer=ExponentialMovingAverageObserver,
weight_observer=MinMaxObserver,
act_observer=ExponentialMovingAverageObserver,
fake_quant=FakeQuantize,
bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
)

+ 2
- 2
python_module/megengine/quantization/quantize.py View File

@@ -64,7 +64,6 @@ def disable_fake_quant(module: Module):
if isinstance(mod, QATModule):
mod.act_fake_quant.disable()
mod.weight_fake_quant.disable()
mod.inp_fake_quant.disable()

module.apply(fn)

@@ -79,6 +78,7 @@ def disable_observer(module: Module):
def fn(mod):
if isinstance(mod, QATModule):
mod.act_observer.disable()
mod.weight_observer.disable()

module.apply(fn)

@@ -94,7 +94,6 @@ def enable_fake_quant(module: Module):
if isinstance(mod, QATModule):
mod.act_fake_quant.enable()
mod.weight_fake_quant.enable()
mod.inp_fake_quant.enable()

module.apply(fn)

@@ -109,5 +108,6 @@ def enable_observer(module: Module):
def fn(mod):
if isinstance(mod, QATModule):
mod.act_observer.enable()
mod.weight_observer.enable()

module.apply(fn)

Loading…
Cancel
Save