GitOrigin-RevId: 6e39de9cec
tags/v1.0.0-rc1
@@ -17,9 +17,7 @@ class Elemwise(Float.Elemwise, QATModule): | |||
:param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail. | |||
""" | |||
def __init__(self, method): | |||
super().__init__(method) | |||
self.with_weight = False | |||
with_weight = False | |||
def forward(self, *inps): | |||
return self.apply_quant_activation(super().forward(*inps)) | |||
@@ -23,6 +23,9 @@ class QATModule(Module): | |||
:func:`~.quantize.quantize` further. | |||
""" | |||
with_weight = True | |||
with_act = True | |||
def __init__(self): | |||
super().__init__() | |||
@@ -32,9 +35,6 @@ class QATModule(Module): | |||
self.weight_fake_quant = None # type: FakeQuantize | |||
self.act_fake_quant = None # type: FakeQuantize | |||
self.with_weight = True | |||
self.with_act = True | |||
def set_qconfig(self, qconfig: QConfig): | |||
r""" | |||
Set quantization related configs with ``qconfig``, including | |||
@@ -51,29 +51,21 @@ class QATModule(Module): | |||
self.weight_observer = safe_call(qconfig.weight_observer) | |||
self.weight_fake_quant = safe_call(qconfig.weight_fake_quant) | |||
def _enable_exec(self, with_module, func, enable): | |||
if not with_module: | |||
return | |||
if enable: | |||
func.enable() | |||
else: | |||
func.disable() | |||
def set_fake_quant(self, enable): | |||
if self.with_act: | |||
if enable: | |||
self.act_fake_quant.enable() | |||
else: | |||
self.act_fake_quant.disable() | |||
if self.with_weight: | |||
if enable: | |||
self.weight_fake_quant.enable() | |||
else: | |||
self.weight_fake_quant.disable() | |||
self._enable_exec(self.with_act, self.act_fake_quant, enable) | |||
self._enable_exec(self.with_weight, self.weight_fake_quant, enable) | |||
def set_observer(self, enable): | |||
if self.with_act: | |||
if enable: | |||
self.act_observer.enable() | |||
else: | |||
self.act_observer.disable() | |||
if self.with_weight: | |||
if enable: | |||
self.weight_observer.enable() | |||
else: | |||
self.weight_observer.disable() | |||
self._enable_exec(self.with_act, self.act_observer, enable) | |||
self._enable_exec(self.with_weight, self.weight_observer, enable) | |||
def _apply_fakequant_with_observer( | |||
self, target: Tensor, fake_quant: FakeQuantize, observer: Observer | |||
@@ -15,9 +15,7 @@ class QuantStub(Float.QuantStub, QATModule): | |||
input after converted to :class:`~.QuantizedModule`. | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
self.with_weight = False | |||
with_weight = False | |||
def forward(self, inp): | |||
return self.apply_quant_activation(inp) | |||
@@ -37,10 +35,8 @@ class DequantStub(Float.DequantStub, QATModule): | |||
input after converted to :class:`~.QuantizedModule`. | |||
""" | |||
def __init__(self): | |||
super().__init__() | |||
self.with_weight = False | |||
self.with_act = False | |||
with_weight = False | |||
with_act = False | |||
def forward(self, inp): | |||
return inp | |||
@@ -116,10 +116,11 @@ class TQT(_FakeQuantize): | |||
return TQT_Function(self.qmin, self.qmax)(inp, self.scale) | |||
def normal_foward(self, inp, q_dict): | |||
# when disable, TQT will do normal forward, initialize scale weight | |||
tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"])) | |||
tmp_scale = F.log(tmp_scale / 127) / F.log(2) | |||
F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0) | |||
if q_dict["enable_observer"]: | |||
# when disable, TQT will do normal forward, initialize scale weight | |||
tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"])) | |||
tmp_scale = F.log(tmp_scale / 127) / F.log(2) | |||
F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0) | |||
return inp | |||
def get_qparams(self): | |||
@@ -102,6 +102,7 @@ class MinMaxObserver(Observer): | |||
q_dict = get_qparam_dict(self.mode) | |||
q_dict["min_val"] = inp_min_val | |||
q_dict["max_val"] = inp_max_val | |||
q_dict["enable_observer"] = self.enable | |||
if self.mode == QuantMode.SYMMERTIC: | |||
symmetric_max_vals = F.maximum(-min_val, max_val) | |||
# use maximun to avoid scale too small at the begin | |||
@@ -14,6 +14,7 @@ from ..module import qat as QAT | |||
from ..module import quantized as Quantized | |||
from ..module.qat import QATModule | |||
from ..module.quantized import QuantizedModule | |||
from .fake_quant import TQT | |||
from .qconfig import QConfig, ema_fakequant_qconfig | |||
@@ -119,6 +120,14 @@ def quantize_qat( | |||
return module | |||
def _propagate(module: Module, func_str: str, *args, **kargs): | |||
def fn(mod: Module): | |||
if isinstance(mod, QATModule): | |||
getattr(mod, func_str)(*args, **kargs) | |||
module.apply(fn) | |||
def propagate_qconfig(module: QATModule, qconfig: QConfig): | |||
r""" | |||
Recursively set ``module``'s qconfig through :meth:`~.Module.apply`. | |||
@@ -126,12 +135,7 @@ def propagate_qconfig(module: QATModule, qconfig: QConfig): | |||
:param module: root module to traverse recursively. | |||
:param qconfig: a instance of :class:`~.QConfig` to be set as submodules' qconfig. | |||
""" | |||
def fn(mod: Module): | |||
if isinstance(mod, QATModule): | |||
mod.set_qconfig(qconfig) | |||
module.apply(fn) | |||
_propagate(module, "set_qconfig", qconfig) | |||
def disable_fake_quant(module: Module): | |||
@@ -141,11 +145,7 @@ def disable_fake_quant(module: Module): | |||
:param module: root module to do disable fake quantization recursively. | |||
""" | |||
def fn(mod: Module): | |||
if isinstance(mod, QATModule): | |||
mod.set_fake_quant(False) | |||
module.apply(fn) | |||
_propagate(module, "set_fake_quant", False) | |||
def disable_observer(module: Module): | |||
@@ -155,11 +155,7 @@ def disable_observer(module: Module): | |||
:param module: root module to do disable observer recursively. | |||
""" | |||
def fn(mod: Module): | |||
if isinstance(mod, QATModule): | |||
self.set_observer(False) | |||
module.apply(fn) | |||
_propagate(module, "set_observer", False) | |||
def enable_fake_quant(module: Module): | |||
@@ -169,11 +165,7 @@ def enable_fake_quant(module: Module): | |||
:param module: root module to do enable fake quantization recursively. | |||
""" | |||
def fn(mod: Module): | |||
if isinstance(mod, QATModule): | |||
mod.set_fake_quant(True) | |||
module.apply(fn) | |||
_propagate(module, "set_fake_quant", True) | |||
def enable_observer(module: Module): | |||
@@ -183,8 +175,4 @@ def enable_observer(module: Module): | |||
:param module: root module to do enable observer recursively. | |||
""" | |||
def fn(mod: Module): | |||
if isinstance(mod, QATModule): | |||
mod.set_observer(True) | |||
module.apply(fn) | |||
_propagate(module, "set_observer", False) |