Browse Source

refactor(mge/quantization): refactor qconfig, remove inp_observer and bias_fakequant

GitOrigin-RevId: e57f9edd12
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
38a95744fb
4 changed files with 17 additions and 28 deletions
  1. +0
    -4
      python_module/megengine/module/module.py
  2. +8
    -1
      python_module/megengine/quantization/__init__.py
  3. +7
    -21
      python_module/megengine/quantization/qconfig.py
  4. +2
    -2
      python_module/megengine/quantization/quantize.py

+ 0
- 4
python_module/megengine/module/module.py View File

@@ -476,21 +476,17 @@ class QATModule(Module):
self.quantizing = self.QATMode.DISABLED self.quantizing = self.QATMode.DISABLED
self.scale = None self.scale = None


self.inp_observer = None # type: Observer
self.weight_observer = None # type: Observer self.weight_observer = None # type: Observer
self.act_observer = None # type: Observer self.act_observer = None # type: Observer


self.weight_fake_quant = None # type: FakeQuantize self.weight_fake_quant = None # type: FakeQuantize
self.bias_fake_quant = None # type: FakeQuantize
self.act_fake_quant = None # type: FakeQuantize self.act_fake_quant = None # type: FakeQuantize


def set_qconfig(self, qconfig: "QConfig"): def set_qconfig(self, qconfig: "QConfig"):
self.inp_observer = qconfig.inp_observer()
self.weight_observer = qconfig.weight_observer() self.weight_observer = qconfig.weight_observer()
self.act_observer = qconfig.act_observer() self.act_observer = qconfig.act_observer()


self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype) self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype)
self.bias_fake_quant = qconfig.bias_fake_quant()
self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype) self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype)


def apply_observer(self, target: Tensor, obs: "Observer"): def apply_observer(self, target: Tensor, obs: "Observer"):


+ 8
- 1
python_module/megengine/quantization/__init__.py View File

@@ -8,4 +8,11 @@
from .fake_quant import FakeQuantize from .fake_quant import FakeQuantize
from .observer import Observer from .observer import Observer
from .qconfig import QConfig, ema_fakequant_qconfig, min_max_fakequant_qconfig from .qconfig import QConfig, ema_fakequant_qconfig, min_max_fakequant_qconfig
from .quantize import quantize, quantize_qat
from .quantize import (
disable_fake_quant,
disable_observer,
enable_fake_quant,
enable_observer,
quantize,
quantize_qat,
)

+ 7
- 21
python_module/megengine/quantization/qconfig.py View File

@@ -15,21 +15,18 @@ from .observer import ExponentialMovingAverageObserver, MinMaxObserver
class QConfig: class QConfig:
""" """
A config class indicating how to do quantize toward :class:`~.QATModule`'s A config class indicating how to do quantize toward :class:`~.QATModule`'s
``activation``, ``weight`` and ``bias``.
``activation`` and ``weight``.


And ``fake_quant`` parameter to indicate And ``fake_quant`` parameter to indicate


See :meth:`~.QATModule.set_qconfig` for detail usage. See :meth:`~.QATModule.set_qconfig` for detail usage.


:param inp_observer: interface to instantiate an :class:`~.Observer` indicating
how to collect scales and zero_point of input.
:param weight_observer: similar to ``inp_observer`` but toward weight.
:param act_observer: similar to ``inp_observer`` but toward activation.
:param weight_observer: interface to instantiate an :class:`~.Observer` indicating
- how to collect scales and zero_point of wegiht.
:param act_observer: similar to ``weight_observer`` but toward activation.
:param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating :param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
how to do fake_quant calculation. can be invoked multi times to get different how to do fake_quant calculation. can be invoked multi times to get different
instance for each target tensor, for better control on enable and disable. instance for each target tensor, for better control on enable and disable.
:param bias_fake_quant: similar to ``fake_quant``, but usually need to set ``dtype``
in advance, for bias's dtype is unable to be inferred from observer.


Examples: Examples:


@@ -37,21 +34,16 @@ class QConfig:


# Default EMA QConfig for QAT. # Default EMA QConfig for QAT.
ema_fakequant_qconfig = QConfig( ema_fakequant_qconfig = QConfig(
inp_observer=ExponentialMovingAverageObserver,
weight_observer=ExponentialMovingAverageObserver,
weight_observer=MinMaxObserver,
act_observer=ExponentialMovingAverageObserver, act_observer=ExponentialMovingAverageObserver,
fake_quant=FakeQuantize, fake_quant=FakeQuantize,
) )
""" """


def __init__( def __init__(
self, act_observer, weight_observer, inp_observer, fake_quant, bias_fake_quant,
self, act_observer, weight_observer, fake_quant,
): ):
if (
isinstance(act_observer, Module)
or isinstance(weight_observer, Module)
or isinstance(inp_observer, Module)
):
if isinstance(act_observer, Module) or isinstance(weight_observer, Module):
raise ValueError( raise ValueError(
"QConfig must not receive observer instance, please pass observer" "QConfig must not receive observer instance, please pass observer"
" class generator using `partial(Observer, ...)` instead. Use" " class generator using `partial(Observer, ...)` instead. Use"
@@ -59,24 +51,18 @@ class QConfig:
) )
self.act_observer = act_observer self.act_observer = act_observer
self.weight_observer = weight_observer self.weight_observer = weight_observer
self.inp_observer = inp_observer
self.fake_quant = fake_quant self.fake_quant = fake_quant
self.bias_fake_quant = bias_fake_quant




# Default QAT QConfigs # Default QAT QConfigs
min_max_fakequant_qconfig = QConfig( min_max_fakequant_qconfig = QConfig(
inp_observer=MinMaxObserver,
weight_observer=MinMaxObserver, weight_observer=MinMaxObserver,
act_observer=MinMaxObserver, act_observer=MinMaxObserver,
fake_quant=FakeQuantize, fake_quant=FakeQuantize,
bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
) )


ema_fakequant_qconfig = QConfig( ema_fakequant_qconfig = QConfig(
inp_observer=ExponentialMovingAverageObserver,
weight_observer=MinMaxObserver, weight_observer=MinMaxObserver,
act_observer=ExponentialMovingAverageObserver, act_observer=ExponentialMovingAverageObserver,
fake_quant=FakeQuantize, fake_quant=FakeQuantize,
bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
) )

+ 2
- 2
python_module/megengine/quantization/quantize.py View File

@@ -64,7 +64,6 @@ def disable_fake_quant(module: Module):
if isinstance(mod, QATModule): if isinstance(mod, QATModule):
mod.act_fake_quant.disable() mod.act_fake_quant.disable()
mod.weight_fake_quant.disable() mod.weight_fake_quant.disable()
mod.inp_fake_quant.disable()


module.apply(fn) module.apply(fn)


@@ -79,6 +78,7 @@ def disable_observer(module: Module):
def fn(mod): def fn(mod):
if isinstance(mod, QATModule): if isinstance(mod, QATModule):
mod.act_observer.disable() mod.act_observer.disable()
mod.weight_observer.disable()


module.apply(fn) module.apply(fn)


@@ -94,7 +94,6 @@ def enable_fake_quant(module: Module):
if isinstance(mod, QATModule): if isinstance(mod, QATModule):
mod.act_fake_quant.enable() mod.act_fake_quant.enable()
mod.weight_fake_quant.enable() mod.weight_fake_quant.enable()
mod.inp_fake_quant.enable()


module.apply(fn) module.apply(fn)


@@ -109,5 +108,6 @@ def enable_observer(module: Module):
def fn(mod): def fn(mod):
if isinstance(mod, QATModule): if isinstance(mod, QATModule):
mod.act_observer.enable() mod.act_observer.enable()
mod.weight_observer.enable()


module.apply(fn) module.apply(fn)

Loading…
Cancel
Save