Browse Source

fix(mge/quant): fix TQT epoch scale change bug

GitOrigin-RevId: 6e39de9cec
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
35c712767d
6 changed files with 39 additions and 63 deletions
  1. +1
    -3
      python_module/megengine/module/qat/elemwise.py
  2. +15
    -23
      python_module/megengine/module/qat/module.py
  3. +3
    -7
      python_module/megengine/module/qat/quant_dequant.py
  4. +5
    -4
      python_module/megengine/quantization/fake_quant.py
  5. +1
    -0
      python_module/megengine/quantization/observer.py
  6. +14
    -26
      python_module/megengine/quantization/quantize.py

+ 1
- 3
python_module/megengine/module/qat/elemwise.py View File

@@ -17,9 +17,7 @@ class Elemwise(Float.Elemwise, QATModule):
:param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail.
"""

def __init__(self, method):
super().__init__(method)
self.with_weight = False
with_weight = False

def forward(self, *inps):
return self.apply_quant_activation(super().forward(*inps))


+ 15
- 23
python_module/megengine/module/qat/module.py View File

@@ -23,6 +23,9 @@ class QATModule(Module):
:func:`~.quantize.quantize` further.
"""

with_weight = True
with_act = True

def __init__(self):
super().__init__()

@@ -32,9 +35,6 @@ class QATModule(Module):
self.weight_fake_quant = None # type: FakeQuantize
self.act_fake_quant = None # type: FakeQuantize

self.with_weight = True
self.with_act = True

def set_qconfig(self, qconfig: QConfig):
r"""
Set quantization related configs with ``qconfig``, including
@@ -51,29 +51,21 @@ class QATModule(Module):
self.weight_observer = safe_call(qconfig.weight_observer)
self.weight_fake_quant = safe_call(qconfig.weight_fake_quant)

def _enable_exec(self, with_module, func, enable):
if not with_module:
return
if enable:
func.enable()
else:
func.disable()

def set_fake_quant(self, enable):
if self.with_act:
if enable:
self.act_fake_quant.enable()
else:
self.act_fake_quant.disable()
if self.with_weight:
if enable:
self.weight_fake_quant.enable()
else:
self.weight_fake_quant.disable()
self._enable_exec(self.with_act, self.act_fake_quant, enable)
self._enable_exec(self.with_weight, self.weight_fake_quant, enable)

def set_observer(self, enable):
if self.with_act:
if enable:
self.act_observer.enable()
else:
self.act_observer.disable()
if self.with_weight:
if enable:
self.weight_observer.enable()
else:
self.weight_observer.disable()
self._enable_exec(self.with_act, self.act_observer, enable)
self._enable_exec(self.with_weight, self.weight_observer, enable)

def _apply_fakequant_with_observer(
self, target: Tensor, fake_quant: FakeQuantize, observer: Observer


+ 3
- 7
python_module/megengine/module/qat/quant_dequant.py View File

@@ -15,9 +15,7 @@ class QuantStub(Float.QuantStub, QATModule):
input after converted to :class:`~.QuantizedModule`.
"""

def __init__(self):
super().__init__()
self.with_weight = False
with_weight = False

def forward(self, inp):
return self.apply_quant_activation(inp)
@@ -37,10 +35,8 @@ class DequantStub(Float.DequantStub, QATModule):
input after converted to :class:`~.QuantizedModule`.
"""

def __init__(self):
super().__init__()
self.with_weight = False
self.with_act = False
with_weight = False
with_act = False

def forward(self, inp):
return inp


+ 5
- 4
python_module/megengine/quantization/fake_quant.py View File

@@ -116,10 +116,11 @@ class TQT(_FakeQuantize):
return TQT_Function(self.qmin, self.qmax)(inp, self.scale)

def normal_foward(self, inp, q_dict):
# when disable, TQT will do normal forward, initialize scale weight
tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
tmp_scale = F.log(tmp_scale / 127) / F.log(2)
F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0)
if q_dict["enable_observer"]:
# when disable, TQT will do normal forward, initialize scale weight
tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
tmp_scale = F.log(tmp_scale / 127) / F.log(2)
F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0)
return inp

def get_qparams(self):


+ 1
- 0
python_module/megengine/quantization/observer.py View File

@@ -102,6 +102,7 @@ class MinMaxObserver(Observer):
q_dict = get_qparam_dict(self.mode)
q_dict["min_val"] = inp_min_val
q_dict["max_val"] = inp_max_val
q_dict["enable_observer"] = self.enable
if self.mode == QuantMode.SYMMERTIC:
symmetric_max_vals = F.maximum(-min_val, max_val)
# use maximun to avoid scale too small at the begin


+ 14
- 26
python_module/megengine/quantization/quantize.py View File

@@ -14,6 +14,7 @@ from ..module import qat as QAT
from ..module import quantized as Quantized
from ..module.qat import QATModule
from ..module.quantized import QuantizedModule
from .fake_quant import TQT
from .qconfig import QConfig, ema_fakequant_qconfig


@@ -119,6 +120,14 @@ def quantize_qat(
return module


def _propagate(module: Module, func_str: str, *args, **kargs):
def fn(mod: Module):
if isinstance(mod, QATModule):
getattr(mod, func_str)(*args, **kargs)

module.apply(fn)


def propagate_qconfig(module: QATModule, qconfig: QConfig):
r"""
Recursively set ``module``'s qconfig through :meth:`~.Module.apply`.
@@ -126,12 +135,7 @@ def propagate_qconfig(module: QATModule, qconfig: QConfig):
:param module: root module to traverse recursively.
:param qconfig: a instance of :class:`~.QConfig` to be set as submodules' qconfig.
"""

def fn(mod: Module):
if isinstance(mod, QATModule):
mod.set_qconfig(qconfig)

module.apply(fn)
_propagate(module, "set_qconfig", qconfig)


def disable_fake_quant(module: Module):
@@ -141,11 +145,7 @@ def disable_fake_quant(module: Module):
:param module: root module to do disable fake quantization recursively.
"""

def fn(mod: Module):
if isinstance(mod, QATModule):
mod.set_fake_quant(False)

module.apply(fn)
_propagate(module, "set_fake_quant", False)


def disable_observer(module: Module):
@@ -155,11 +155,7 @@ def disable_observer(module: Module):
:param module: root module to do disable observer recursively.
"""

def fn(mod: Module):
if isinstance(mod, QATModule):
self.set_observer(False)

module.apply(fn)
_propagate(module, "set_observer", False)


def enable_fake_quant(module: Module):
@@ -169,11 +165,7 @@ def enable_fake_quant(module: Module):
:param module: root module to do enable fake quantization recursively.
"""

def fn(mod: Module):
if isinstance(mod, QATModule):
mod.set_fake_quant(True)

module.apply(fn)
_propagate(module, "set_fake_quant", True)


def enable_observer(module: Module):
@@ -183,8 +175,4 @@ def enable_observer(module: Module):
:param module: root module to do enable observer recursively.
"""

def fn(mod: Module):
if isinstance(mod, QATModule):
mod.set_observer(True)

module.apply(fn)
_propagate(module, "set_observer", False)

Loading…
Cancel
Save