GitOrigin-RevId: 6e39de9cec
tags/v1.0.0-rc1
@@ -17,9 +17,7 @@ class Elemwise(Float.Elemwise, QATModule): | |||||
:param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail. | :param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail. | ||||
""" | """ | ||||
def __init__(self, method): | |||||
super().__init__(method) | |||||
self.with_weight = False | |||||
with_weight = False | |||||
def forward(self, *inps): | def forward(self, *inps): | ||||
return self.apply_quant_activation(super().forward(*inps)) | return self.apply_quant_activation(super().forward(*inps)) | ||||
@@ -23,6 +23,9 @@ class QATModule(Module): | |||||
:func:`~.quantize.quantize` further. | :func:`~.quantize.quantize` further. | ||||
""" | """ | ||||
with_weight = True | |||||
with_act = True | |||||
def __init__(self): | def __init__(self): | ||||
super().__init__() | super().__init__() | ||||
@@ -32,9 +35,6 @@ class QATModule(Module): | |||||
self.weight_fake_quant = None # type: FakeQuantize | self.weight_fake_quant = None # type: FakeQuantize | ||||
self.act_fake_quant = None # type: FakeQuantize | self.act_fake_quant = None # type: FakeQuantize | ||||
self.with_weight = True | |||||
self.with_act = True | |||||
def set_qconfig(self, qconfig: QConfig): | def set_qconfig(self, qconfig: QConfig): | ||||
r""" | r""" | ||||
Set quantization related configs with ``qconfig``, including | Set quantization related configs with ``qconfig``, including | ||||
@@ -51,29 +51,21 @@ class QATModule(Module): | |||||
self.weight_observer = safe_call(qconfig.weight_observer) | self.weight_observer = safe_call(qconfig.weight_observer) | ||||
self.weight_fake_quant = safe_call(qconfig.weight_fake_quant) | self.weight_fake_quant = safe_call(qconfig.weight_fake_quant) | ||||
def _enable_exec(self, with_module, func, enable): | |||||
if not with_module: | |||||
return | |||||
if enable: | |||||
func.enable() | |||||
else: | |||||
func.disable() | |||||
def set_fake_quant(self, enable): | def set_fake_quant(self, enable): | ||||
if self.with_act: | |||||
if enable: | |||||
self.act_fake_quant.enable() | |||||
else: | |||||
self.act_fake_quant.disable() | |||||
if self.with_weight: | |||||
if enable: | |||||
self.weight_fake_quant.enable() | |||||
else: | |||||
self.weight_fake_quant.disable() | |||||
self._enable_exec(self.with_act, self.act_fake_quant, enable) | |||||
self._enable_exec(self.with_weight, self.weight_fake_quant, enable) | |||||
def set_observer(self, enable): | def set_observer(self, enable): | ||||
if self.with_act: | |||||
if enable: | |||||
self.act_observer.enable() | |||||
else: | |||||
self.act_observer.disable() | |||||
if self.with_weight: | |||||
if enable: | |||||
self.weight_observer.enable() | |||||
else: | |||||
self.weight_observer.disable() | |||||
self._enable_exec(self.with_act, self.act_observer, enable) | |||||
self._enable_exec(self.with_weight, self.weight_observer, enable) | |||||
def _apply_fakequant_with_observer( | def _apply_fakequant_with_observer( | ||||
self, target: Tensor, fake_quant: FakeQuantize, observer: Observer | self, target: Tensor, fake_quant: FakeQuantize, observer: Observer | ||||
@@ -15,9 +15,7 @@ class QuantStub(Float.QuantStub, QATModule): | |||||
input after converted to :class:`~.QuantizedModule`. | input after converted to :class:`~.QuantizedModule`. | ||||
""" | """ | ||||
def __init__(self): | |||||
super().__init__() | |||||
self.with_weight = False | |||||
with_weight = False | |||||
def forward(self, inp): | def forward(self, inp): | ||||
return self.apply_quant_activation(inp) | return self.apply_quant_activation(inp) | ||||
@@ -37,10 +35,8 @@ class DequantStub(Float.DequantStub, QATModule): | |||||
input after converted to :class:`~.QuantizedModule`. | input after converted to :class:`~.QuantizedModule`. | ||||
""" | """ | ||||
def __init__(self): | |||||
super().__init__() | |||||
self.with_weight = False | |||||
self.with_act = False | |||||
with_weight = False | |||||
with_act = False | |||||
def forward(self, inp): | def forward(self, inp): | ||||
return inp | return inp | ||||
@@ -116,10 +116,11 @@ class TQT(_FakeQuantize): | |||||
return TQT_Function(self.qmin, self.qmax)(inp, self.scale) | return TQT_Function(self.qmin, self.qmax)(inp, self.scale) | ||||
def normal_foward(self, inp, q_dict): | def normal_foward(self, inp, q_dict): | ||||
# when disabled, TQT does a normal forward pass and initializes the scale weight | |||||
tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"])) | |||||
tmp_scale = F.log(tmp_scale / 127) / F.log(2) | |||||
F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0) | |||||
if q_dict["enable_observer"]: | |||||
# when disabled, TQT does a normal forward pass and initializes the scale weight | |||||
tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"])) | |||||
tmp_scale = F.log(tmp_scale / 127) / F.log(2) | |||||
F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0) | |||||
return inp | return inp | ||||
def get_qparams(self): | def get_qparams(self): | ||||
@@ -102,6 +102,7 @@ class MinMaxObserver(Observer): | |||||
q_dict = get_qparam_dict(self.mode) | q_dict = get_qparam_dict(self.mode) | ||||
q_dict["min_val"] = inp_min_val | q_dict["min_val"] = inp_min_val | ||||
q_dict["max_val"] = inp_max_val | q_dict["max_val"] = inp_max_val | ||||
q_dict["enable_observer"] = self.enable | |||||
if self.mode == QuantMode.SYMMERTIC: | if self.mode == QuantMode.SYMMERTIC: | ||||
symmetric_max_vals = F.maximum(-min_val, max_val) | symmetric_max_vals = F.maximum(-min_val, max_val) | ||||
# use maximum to avoid the scale being too small at the beginning | # use maximum to avoid the scale being too small at the beginning | ||||
@@ -14,6 +14,7 @@ from ..module import qat as QAT | |||||
from ..module import quantized as Quantized | from ..module import quantized as Quantized | ||||
from ..module.qat import QATModule | from ..module.qat import QATModule | ||||
from ..module.quantized import QuantizedModule | from ..module.quantized import QuantizedModule | ||||
from .fake_quant import TQT | |||||
from .qconfig import QConfig, ema_fakequant_qconfig | from .qconfig import QConfig, ema_fakequant_qconfig | ||||
@@ -119,6 +120,14 @@ def quantize_qat( | |||||
return module | return module | ||||
def _propagate(module: Module, func_str: str, *args, **kargs): | |||||
def fn(mod: Module): | |||||
if isinstance(mod, QATModule): | |||||
getattr(mod, func_str)(*args, **kargs) | |||||
module.apply(fn) | |||||
def propagate_qconfig(module: QATModule, qconfig: QConfig): | def propagate_qconfig(module: QATModule, qconfig: QConfig): | ||||
r""" | r""" | ||||
Recursively set ``module``'s qconfig through :meth:`~.Module.apply`. | Recursively set ``module``'s qconfig through :meth:`~.Module.apply`. | ||||
@@ -126,12 +135,7 @@ def propagate_qconfig(module: QATModule, qconfig: QConfig): | |||||
:param module: root module to traverse recursively. | :param module: root module to traverse recursively. | ||||
:param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig. | :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig. | ||||
""" | """ | ||||
def fn(mod: Module): | |||||
if isinstance(mod, QATModule): | |||||
mod.set_qconfig(qconfig) | |||||
module.apply(fn) | |||||
_propagate(module, "set_qconfig", qconfig) | |||||
def disable_fake_quant(module: Module): | def disable_fake_quant(module: Module): | ||||
@@ -141,11 +145,7 @@ def disable_fake_quant(module: Module): | |||||
:param module: root module on which to disable fake quantization recursively. | :param module: root module on which to disable fake quantization recursively. | ||||
""" | """ | ||||
def fn(mod: Module): | |||||
if isinstance(mod, QATModule): | |||||
mod.set_fake_quant(False) | |||||
module.apply(fn) | |||||
_propagate(module, "set_fake_quant", False) | |||||
def disable_observer(module: Module): | def disable_observer(module: Module): | ||||
@@ -155,11 +155,7 @@ def disable_observer(module: Module): | |||||
:param module: root module on which to disable observers recursively. | :param module: root module on which to disable observers recursively. | ||||
""" | """ | ||||
def fn(mod: Module): | |||||
if isinstance(mod, QATModule): | |||||
self.set_observer(False) | |||||
module.apply(fn) | |||||
_propagate(module, "set_observer", False) | |||||
def enable_fake_quant(module: Module): | def enable_fake_quant(module: Module): | ||||
@@ -169,11 +165,7 @@ def enable_fake_quant(module: Module): | |||||
:param module: root module on which to enable fake quantization recursively. | :param module: root module on which to enable fake quantization recursively. | ||||
""" | """ | ||||
def fn(mod: Module): | |||||
if isinstance(mod, QATModule): | |||||
mod.set_fake_quant(True) | |||||
module.apply(fn) | |||||
_propagate(module, "set_fake_quant", True) | |||||
def enable_observer(module: Module): | def enable_observer(module: Module): | ||||
@@ -183,8 +175,4 @@ def enable_observer(module: Module): | |||||
:param module: root module on which to enable observers recursively. | :param module: root module on which to enable observers recursively. | ||||
""" | """ | ||||
def fn(mod: Module): | |||||
if isinstance(mod, QATModule): | |||||
mod.set_observer(True) | |||||
module.apply(fn) | |||||
_propagate(module, "set_observer", True) |