From e6e41242c7ab52ca7644e9af6f99060c7727fe50 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 3 Aug 2020 15:20:12 +0800 Subject: [PATCH] fix(mge/quant): fix zero grad warn in TQT train GitOrigin-RevId: a6545ee3660816b4bf1cba6364d4ad2026a82833 --- python_module/megengine/module/qat/elemwise.py | 4 +++ python_module/megengine/module/qat/module.py | 37 +++++++++++++++++++--- .../megengine/module/qat/quant_dequant.py | 9 ++++++ python_module/megengine/quantization/quantize.py | 12 +++---- 4 files changed, 50 insertions(+), 12 deletions(-) diff --git a/python_module/megengine/module/qat/elemwise.py b/python_module/megengine/module/qat/elemwise.py index 37e03e81..3385e774 100644 --- a/python_module/megengine/module/qat/elemwise.py +++ b/python_module/megengine/module/qat/elemwise.py @@ -17,6 +17,10 @@ class Elemwise(Float.Elemwise, QATModule): :param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail. """ + def __init__(self, method): + super().__init__(method) + self.with_weight = False + def forward(self, *inps): return self.apply_quant_activation(super().forward(*inps)) diff --git a/python_module/megengine/module/qat/module.py b/python_module/megengine/module/qat/module.py index e56cb54b..931747c4 100644 --- a/python_module/megengine/module/qat/module.py +++ b/python_module/megengine/module/qat/module.py @@ -32,6 +32,9 @@ class QATModule(Module): self.weight_fake_quant = None # type: FakeQuantize self.act_fake_quant = None # type: FakeQuantize + self.with_weight = True + self.with_act = True + def set_qconfig(self, qconfig: QConfig): r""" Set quantization related configs with ``qconfig``, including @@ -41,10 +44,36 @@ class QATModule(Module): def safe_call(func): return func() if func is not None else None - self.weight_observer = safe_call(qconfig.weight_observer) - self.act_observer = safe_call(qconfig.act_observer) - self.weight_fake_quant = safe_call(qconfig.weight_fake_quant) - self.act_fake_quant = safe_call(qconfig.act_fake_quant) + if self.with_act: + self.act_observer = safe_call(qconfig.act_observer) + self.act_fake_quant = safe_call(qconfig.act_fake_quant) + if self.with_weight: + self.weight_observer = safe_call(qconfig.weight_observer) + self.weight_fake_quant = safe_call(qconfig.weight_fake_quant) + + def set_fake_quant(self, enable): + if self.with_act: + if enable: + self.act_fake_quant.enable() + else: + self.act_fake_quant.disable() + if self.with_weight: + if enable: + self.weight_fake_quant.enable() + else: + self.weight_fake_quant.disable() + + def set_observer(self, enable): + if self.with_act: + if enable: + self.act_observer.enable() + else: + self.act_observer.disable() + if self.with_weight: + if enable: + self.weight_observer.enable() + else: + self.weight_observer.disable() def _apply_fakequant_with_observer( self, target: Tensor, fake_quant: FakeQuantize, observer: Observer diff --git a/python_module/megengine/module/qat/quant_dequant.py b/python_module/megengine/module/qat/quant_dequant.py index 84ebdf92..fb4018ff 100644 --- a/python_module/megengine/module/qat/quant_dequant.py +++ b/python_module/megengine/module/qat/quant_dequant.py @@ -15,6 +15,10 @@ class QuantStub(Float.QuantStub, QATModule): input after converted to :class:`~.QuantizedModule`. """ + def __init__(self): + super().__init__() + self.with_weight = False + def forward(self, inp): return self.apply_quant_activation(inp) @@ -33,6 +37,11 @@ class DequantStub(Float.DequantStub, QATModule): input after converted to :class:`~.QuantizedModule`. """ + def __init__(self): + super().__init__() + self.with_weight = False + self.with_act = False + def forward(self, inp): return inp diff --git a/python_module/megengine/quantization/quantize.py b/python_module/megengine/quantization/quantize.py index 36d6cc00..e36f5344 100644 --- a/python_module/megengine/quantization/quantize.py +++ b/python_module/megengine/quantization/quantize.py @@ -143,8 +143,7 @@ def disable_fake_quant(module: Module): def fn(mod: Module): if isinstance(mod, QATModule): - mod.act_fake_quant.disable() - mod.weight_fake_quant.disable() + mod.set_fake_quant(False) module.apply(fn) @@ -158,8 +157,7 @@ def disable_observer(module: Module): def fn(mod: Module): if isinstance(mod, QATModule): - mod.act_observer.disable() - mod.weight_observer.disable() + self.set_observer(False) module.apply(fn) @@ -173,8 +171,7 @@ def enable_fake_quant(module: Module): def fn(mod: Module): if isinstance(mod, QATModule): - mod.act_fake_quant.enable() - mod.weight_fake_quant.enable() + mod.set_fake_quant(True) module.apply(fn) @@ -188,7 +185,6 @@ def enable_observer(module: Module): def fn(mod: Module): if isinstance(mod, QATModule): - mod.act_observer.enable() - mod.weight_observer.enable() + mod.set_observer(True) module.apply(fn)