Browse Source

fix(mge/quant): fix TQT epoch scale change bug

GitOrigin-RevId: 6e39de9cec
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
35c712767d
6 changed files with 39 additions and 63 deletions
  1. +1
    -3
      python_module/megengine/module/qat/elemwise.py
  2. +15
    -23
      python_module/megengine/module/qat/module.py
  3. +3
    -7
      python_module/megengine/module/qat/quant_dequant.py
  4. +5
    -4
      python_module/megengine/quantization/fake_quant.py
  5. +1
    -0
      python_module/megengine/quantization/observer.py
  6. +14
    -26
      python_module/megengine/quantization/quantize.py

+ 1
- 3
python_module/megengine/module/qat/elemwise.py View File

@@ -17,9 +17,7 @@ class Elemwise(Float.Elemwise, QATModule):
:param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail. :param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail.
""" """


def __init__(self, method):
super().__init__(method)
self.with_weight = False
with_weight = False


def forward(self, *inps): def forward(self, *inps):
return self.apply_quant_activation(super().forward(*inps)) return self.apply_quant_activation(super().forward(*inps))


+ 15
- 23
python_module/megengine/module/qat/module.py View File

@@ -23,6 +23,9 @@ class QATModule(Module):
:func:`~.quantize.quantize` further. :func:`~.quantize.quantize` further.
""" """


with_weight = True
with_act = True

def __init__(self): def __init__(self):
super().__init__() super().__init__()


@@ -32,9 +35,6 @@ class QATModule(Module):
self.weight_fake_quant = None # type: FakeQuantize self.weight_fake_quant = None # type: FakeQuantize
self.act_fake_quant = None # type: FakeQuantize self.act_fake_quant = None # type: FakeQuantize


self.with_weight = True
self.with_act = True

def set_qconfig(self, qconfig: QConfig): def set_qconfig(self, qconfig: QConfig):
r""" r"""
Set quantization related configs with ``qconfig``, including Set quantization related configs with ``qconfig``, including
@@ -51,29 +51,21 @@ class QATModule(Module):
self.weight_observer = safe_call(qconfig.weight_observer) self.weight_observer = safe_call(qconfig.weight_observer)
self.weight_fake_quant = safe_call(qconfig.weight_fake_quant) self.weight_fake_quant = safe_call(qconfig.weight_fake_quant)


def _enable_exec(self, with_module, func, enable):
if not with_module:
return
if enable:
func.enable()
else:
func.disable()

def set_fake_quant(self, enable): def set_fake_quant(self, enable):
if self.with_act:
if enable:
self.act_fake_quant.enable()
else:
self.act_fake_quant.disable()
if self.with_weight:
if enable:
self.weight_fake_quant.enable()
else:
self.weight_fake_quant.disable()
self._enable_exec(self.with_act, self.act_fake_quant, enable)
self._enable_exec(self.with_weight, self.weight_fake_quant, enable)


def set_observer(self, enable): def set_observer(self, enable):
if self.with_act:
if enable:
self.act_observer.enable()
else:
self.act_observer.disable()
if self.with_weight:
if enable:
self.weight_observer.enable()
else:
self.weight_observer.disable()
self._enable_exec(self.with_act, self.act_observer, enable)
self._enable_exec(self.with_weight, self.weight_observer, enable)


def _apply_fakequant_with_observer( def _apply_fakequant_with_observer(
self, target: Tensor, fake_quant: FakeQuantize, observer: Observer self, target: Tensor, fake_quant: FakeQuantize, observer: Observer


+ 3
- 7
python_module/megengine/module/qat/quant_dequant.py View File

@@ -15,9 +15,7 @@ class QuantStub(Float.QuantStub, QATModule):
input after being converted to :class:`~.QuantizedModule`. input after being converted to :class:`~.QuantizedModule`.
""" """


def __init__(self):
super().__init__()
self.with_weight = False
with_weight = False


def forward(self, inp): def forward(self, inp):
return self.apply_quant_activation(inp) return self.apply_quant_activation(inp)
@@ -37,10 +35,8 @@ class DequantStub(Float.DequantStub, QATModule):
input after being converted to :class:`~.QuantizedModule`. input after being converted to :class:`~.QuantizedModule`.
""" """


def __init__(self):
super().__init__()
self.with_weight = False
self.with_act = False
with_weight = False
with_act = False


def forward(self, inp): def forward(self, inp):
return inp return inp


+ 5
- 4
python_module/megengine/quantization/fake_quant.py View File

@@ -116,10 +116,11 @@ class TQT(_FakeQuantize):
return TQT_Function(self.qmin, self.qmax)(inp, self.scale) return TQT_Function(self.qmin, self.qmax)(inp, self.scale)


def normal_foward(self, inp, q_dict): def normal_foward(self, inp, q_dict):
# when disable, TQT will do normal forward, initialize scale weight
tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
tmp_scale = F.log(tmp_scale / 127) / F.log(2)
F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0)
if q_dict["enable_observer"]:
# when disabled, TQT does a normal forward pass and initializes the scale weight
tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
tmp_scale = F.log(tmp_scale / 127) / F.log(2)
F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0)
return inp return inp


def get_qparams(self): def get_qparams(self):


+ 1
- 0
python_module/megengine/quantization/observer.py View File

@@ -102,6 +102,7 @@ class MinMaxObserver(Observer):
q_dict = get_qparam_dict(self.mode) q_dict = get_qparam_dict(self.mode)
q_dict["min_val"] = inp_min_val q_dict["min_val"] = inp_min_val
q_dict["max_val"] = inp_max_val q_dict["max_val"] = inp_max_val
q_dict["enable_observer"] = self.enable
if self.mode == QuantMode.SYMMERTIC: if self.mode == QuantMode.SYMMERTIC:
symmetric_max_vals = F.maximum(-min_val, max_val) symmetric_max_vals = F.maximum(-min_val, max_val)
# use maximum to avoid the scale being too small at the beginning # use maximum to avoid the scale being too small at the beginning


+ 14
- 26
python_module/megengine/quantization/quantize.py View File

@@ -14,6 +14,7 @@ from ..module import qat as QAT
from ..module import quantized as Quantized from ..module import quantized as Quantized
from ..module.qat import QATModule from ..module.qat import QATModule
from ..module.quantized import QuantizedModule from ..module.quantized import QuantizedModule
from .fake_quant import TQT
from .qconfig import QConfig, ema_fakequant_qconfig from .qconfig import QConfig, ema_fakequant_qconfig




@@ -119,6 +120,14 @@ def quantize_qat(
return module return module




def _propagate(module: Module, func_str: str, *args, **kargs):
    """Call the method named ``func_str`` on every :class:`QATModule` visited by ``module.apply``.

    :param module: root module whose ``apply`` drives the traversal.
    :param func_str: name of the method looked up on each :class:`QATModule`.
    :param args: positional arguments forwarded to that method.
    :param kargs: keyword arguments forwarded to that method.
    """

    def _call_on_qat(mod: Module):
        # anything that is not a QATModule is left untouched
        if not isinstance(mod, QATModule):
            return
        method = getattr(mod, func_str)
        method(*args, **kargs)

    module.apply(_call_on_qat)


def propagate_qconfig(module: QATModule, qconfig: QConfig): def propagate_qconfig(module: QATModule, qconfig: QConfig):
r""" r"""
Recursively set ``module``'s qconfig through :meth:`~.Module.apply`. Recursively set ``module``'s qconfig through :meth:`~.Module.apply`.
@@ -126,12 +135,7 @@ def propagate_qconfig(module: QATModule, qconfig: QConfig):
:param module: root module to traverse recursively. :param module: root module to traverse recursively.
:param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig. :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
""" """

def fn(mod: Module):
if isinstance(mod, QATModule):
mod.set_qconfig(qconfig)

module.apply(fn)
_propagate(module, "set_qconfig", qconfig)




def disable_fake_quant(module: Module): def disable_fake_quant(module: Module):
@@ -141,11 +145,7 @@ def disable_fake_quant(module: Module):
:param module: root module on which to recursively disable fake quantization. :param module: root module on which to recursively disable fake quantization.
""" """


def fn(mod: Module):
if isinstance(mod, QATModule):
mod.set_fake_quant(False)

module.apply(fn)
_propagate(module, "set_fake_quant", False)




def disable_observer(module: Module): def disable_observer(module: Module):
@@ -155,11 +155,7 @@ def disable_observer(module: Module):
:param module: root module on which to recursively disable observers. :param module: root module on which to recursively disable observers.
""" """


def fn(mod: Module):
if isinstance(mod, QATModule):
self.set_observer(False)

module.apply(fn)
_propagate(module, "set_observer", False)




def enable_fake_quant(module: Module): def enable_fake_quant(module: Module):
@@ -169,11 +165,7 @@ def enable_fake_quant(module: Module):
:param module: root module on which to recursively enable fake quantization. :param module: root module on which to recursively enable fake quantization.
""" """


def fn(mod: Module):
if isinstance(mod, QATModule):
mod.set_fake_quant(True)

module.apply(fn)
_propagate(module, "set_fake_quant", True)




def enable_observer(module: Module): def enable_observer(module: Module):
@@ -183,8 +175,4 @@ def enable_observer(module: Module):
:param module: root module on which to recursively enable observers. :param module: root module on which to recursively enable observers.
""" """


def fn(mod: Module):
if isinstance(mod, QATModule):
mod.set_observer(True)

module.apply(fn)
_propagate(module, "set_observer", False)

Loading…
Cancel
Save