Browse Source

fix(mge/quant): fix zero grad warn in TQT train

GitOrigin-RevId: a6545ee366
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
e6e41242c7
4 changed files with 50 additions and 12 deletions
  1. +4
    -0
      python_module/megengine/module/qat/elemwise.py
  2. +33
    -4
      python_module/megengine/module/qat/module.py
  3. +9
    -0
      python_module/megengine/module/qat/quant_dequant.py
  4. +4
    -8
      python_module/megengine/quantization/quantize.py

+ 4
- 0
python_module/megengine/module/qat/elemwise.py View File

@@ -17,6 +17,10 @@ class Elemwise(Float.Elemwise, QATModule):
:param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail.
"""

def __init__(self, method):
    """
    Initialise the QAT elemwise module for ``method``.

    :param method: the elemwise method, forwarded unchanged to the parent
        constructor.
    """
    super().__init__(method)
    # Flag that this module has no weight side; ``set_qconfig`` and the
    # ``set_fake_quant`` / ``set_observer`` toggles check ``with_weight``.
    self.with_weight = False

def forward(self, *inps):
    """
    Run the float elemwise forward and pass its result through
    ``apply_quant_activation``.

    :param inps: positional inputs forwarded to the parent ``forward``.
    :return: whatever ``apply_quant_activation`` returns for the float output.
    """
    float_out = super().forward(*inps)
    return self.apply_quant_activation(float_out)



+ 33
- 4
python_module/megengine/module/qat/module.py View File

@@ -32,6 +32,9 @@ class QATModule(Module):
self.weight_fake_quant = None # type: FakeQuantize
self.act_fake_quant = None # type: FakeQuantize

self.with_weight = True
self.with_act = True

def set_qconfig(self, qconfig: QConfig):
r"""
Set quantization related configs with ``qconfig``, including
@@ -41,10 +44,36 @@ class QATModule(Module):
def safe_call(func):
return func() if func is not None else None

self.weight_observer = safe_call(qconfig.weight_observer)
self.act_observer = safe_call(qconfig.act_observer)
self.weight_fake_quant = safe_call(qconfig.weight_fake_quant)
self.act_fake_quant = safe_call(qconfig.act_fake_quant)
if self.with_act:
self.act_observer = safe_call(qconfig.act_observer)
self.act_fake_quant = safe_call(qconfig.act_fake_quant)
if self.with_weight:
self.weight_observer = safe_call(qconfig.weight_observer)
self.weight_fake_quant = safe_call(qconfig.weight_fake_quant)

def set_fake_quant(self, enable):
    r"""
    Enable or disable the fake-quantizers owned by this module.

    Only the sides flagged by ``with_act`` / ``with_weight`` are touched
    (activation first, then weight). A flagged side whose fake-quantizer is
    still ``None`` is skipped instead of raising ``AttributeError``.

    :param enable: truthy to call ``enable()`` on each fake-quantizer,
        falsy to call ``disable()``.
    """
    fake_quants = []
    if self.with_act:
        fake_quants.append(self.act_fake_quant)
    if self.with_weight:
        fake_quants.append(self.weight_fake_quant)

    for fake_quant in fake_quants:
        # The attribute stays ``None`` until ``set_qconfig`` has run, and
        # also when the qconfig supplies no factory for it.
        if fake_quant is None:
            continue
        if enable:
            fake_quant.enable()
        else:
            fake_quant.disable()

def set_observer(self, enable):
    r"""
    Enable or disable the observers owned by this module.

    Only the sides flagged by ``with_act`` / ``with_weight`` are touched
    (activation first, then weight). A flagged side whose observer is still
    ``None`` is skipped instead of raising ``AttributeError``.

    :param enable: truthy to call ``enable()`` on each observer,
        falsy to call ``disable()``.
    """
    observers = []
    if self.with_act:
        observers.append(self.act_observer)
    if self.with_weight:
        observers.append(self.weight_observer)

    for observer in observers:
        # The attribute stays ``None`` until ``set_qconfig`` has run, and
        # also when the qconfig supplies no factory for it.
        if observer is None:
            continue
        if enable:
            observer.enable()
        else:
            observer.disable()

def _apply_fakequant_with_observer(
self, target: Tensor, fake_quant: FakeQuantize, observer: Observer


+ 9
- 0
python_module/megengine/module/qat/quant_dequant.py View File

@@ -15,6 +15,10 @@ class QuantStub(Float.QuantStub, QATModule):
input after converted to :class:`~.QuantizedModule`.
"""

def __init__(self):
    """Initialise the stub and mark it as having no weight side."""
    super().__init__()
    # ``with_weight`` is consulted by ``set_qconfig`` and the
    # ``set_fake_quant`` / ``set_observer`` toggles.
    self.with_weight = False

def forward(self, inp):
    """
    Pass ``inp`` through ``apply_quant_activation`` and return the result.

    :param inp: the input handed to this stub.
    """
    quantized = self.apply_quant_activation(inp)
    return quantized

@@ -33,6 +37,11 @@ class DequantStub(Float.DequantStub, QATModule):
input after converted to :class:`~.QuantizedModule`.
"""

def __init__(self):
    """Initialise the stub and mark it as having neither weight nor activation side."""
    super().__init__()
    # Both flags are consulted by ``set_qconfig`` and the
    # ``set_fake_quant`` / ``set_observer`` toggles; with both off, those
    # calls leave this module untouched.
    self.with_act = False
    self.with_weight = False

def forward(self, inp):
    """
    Return ``inp`` unchanged.

    :param inp: the input handed to this stub.
    """
    return inp



+ 4
- 8
python_module/megengine/quantization/quantize.py View File

@@ -143,8 +143,7 @@ def disable_fake_quant(module: Module):

def fn(mod: Module):
if isinstance(mod, QATModule):
mod.act_fake_quant.disable()
mod.weight_fake_quant.disable()
mod.set_fake_quant(False)

module.apply(fn)

@@ -158,8 +157,7 @@ def disable_observer(module: Module):

def fn(mod: Module):
if isinstance(mod, QATModule):
mod.act_observer.disable()
mod.weight_observer.disable()
mod.set_observer(False)

module.apply(fn)

@@ -173,8 +171,7 @@ def enable_fake_quant(module: Module):

def fn(mod: Module):
if isinstance(mod, QATModule):
mod.act_fake_quant.enable()
mod.weight_fake_quant.enable()
mod.set_fake_quant(True)

module.apply(fn)

@@ -188,7 +185,6 @@ def enable_observer(module: Module):

def fn(mod: Module):
if isinstance(mod, QATModule):
mod.act_observer.enable()
mod.weight_observer.enable()
mod.set_observer(True)

module.apply(fn)

Loading…
Cancel
Save