GitOrigin-RevId: a6545ee366
tags/v1.0.0-rc1
@@ -17,6 +17,10 @@ class Elemwise(Float.Elemwise, QATModule): | |||||
:param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail. | :param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail. | ||||
""" | """ | ||||
def __init__(self, method): | |||||
super().__init__(method) | |||||
self.with_weight = False | |||||
def forward(self, *inps): | def forward(self, *inps): | ||||
return self.apply_quant_activation(super().forward(*inps)) | return self.apply_quant_activation(super().forward(*inps)) | ||||
@@ -32,6 +32,9 @@ class QATModule(Module): | |||||
self.weight_fake_quant = None # type: FakeQuantize | self.weight_fake_quant = None # type: FakeQuantize | ||||
self.act_fake_quant = None # type: FakeQuantize | self.act_fake_quant = None # type: FakeQuantize | ||||
self.with_weight = True | |||||
self.with_act = True | |||||
def set_qconfig(self, qconfig: QConfig): | def set_qconfig(self, qconfig: QConfig): | ||||
r""" | r""" | ||||
Set quantization related configs with ``qconfig``, including | Set quantization related configs with ``qconfig``, including | ||||
@@ -41,10 +44,36 @@ class QATModule(Module): | |||||
def safe_call(func): | def safe_call(func): | ||||
return func() if func is not None else None | return func() if func is not None else None | ||||
self.weight_observer = safe_call(qconfig.weight_observer) | |||||
self.act_observer = safe_call(qconfig.act_observer) | |||||
self.weight_fake_quant = safe_call(qconfig.weight_fake_quant) | |||||
self.act_fake_quant = safe_call(qconfig.act_fake_quant) | |||||
if self.with_act: | |||||
self.act_observer = safe_call(qconfig.act_observer) | |||||
self.act_fake_quant = safe_call(qconfig.act_fake_quant) | |||||
if self.with_weight: | |||||
self.weight_observer = safe_call(qconfig.weight_observer) | |||||
self.weight_fake_quant = safe_call(qconfig.weight_fake_quant) | |||||
def set_fake_quant(self, enable): | |||||
if self.with_act: | |||||
if enable: | |||||
self.act_fake_quant.enable() | |||||
else: | |||||
self.act_fake_quant.disable() | |||||
if self.with_weight: | |||||
if enable: | |||||
self.weight_fake_quant.enable() | |||||
else: | |||||
self.weight_fake_quant.disable() | |||||
def set_observer(self, enable): | |||||
if self.with_act: | |||||
if enable: | |||||
self.act_observer.enable() | |||||
else: | |||||
self.act_observer.disable() | |||||
if self.with_weight: | |||||
if enable: | |||||
self.weight_observer.enable() | |||||
else: | |||||
self.weight_observer.disable() | |||||
def _apply_fakequant_with_observer( | def _apply_fakequant_with_observer( | ||||
self, target: Tensor, fake_quant: FakeQuantize, observer: Observer | self, target: Tensor, fake_quant: FakeQuantize, observer: Observer | ||||
@@ -15,6 +15,10 @@ class QuantStub(Float.QuantStub, QATModule): | |||||
input after converted to :class:`~.QuantizedModule`. | input after converted to :class:`~.QuantizedModule`. | ||||
""" | """ | ||||
def __init__(self): | |||||
super().__init__() | |||||
self.with_weight = False | |||||
def forward(self, inp): | def forward(self, inp): | ||||
return self.apply_quant_activation(inp) | return self.apply_quant_activation(inp) | ||||
@@ -33,6 +37,11 @@ class DequantStub(Float.DequantStub, QATModule): | |||||
input after converted to :class:`~.QuantizedModule`. | input after converted to :class:`~.QuantizedModule`. | ||||
""" | """ | ||||
def __init__(self): | |||||
super().__init__() | |||||
self.with_weight = False | |||||
self.with_act = False | |||||
def forward(self, inp): | def forward(self, inp): | ||||
return inp | return inp | ||||
@@ -143,8 +143,7 @@ def disable_fake_quant(module: Module): | |||||
def fn(mod: Module): | def fn(mod: Module): | ||||
if isinstance(mod, QATModule): | if isinstance(mod, QATModule): | ||||
mod.act_fake_quant.disable() | |||||
mod.weight_fake_quant.disable() | |||||
mod.set_fake_quant(False) | |||||
module.apply(fn) | module.apply(fn) | ||||
@@ -158,8 +157,7 @@ def disable_observer(module: Module): | |||||
def fn(mod: Module): | def fn(mod: Module): | ||||
if isinstance(mod, QATModule): | if isinstance(mod, QATModule): | ||||
mod.act_observer.disable() | |||||
mod.weight_observer.disable() | |||||
self.set_observer(False) | |||||
module.apply(fn) | module.apply(fn) | ||||
@@ -173,8 +171,7 @@ def enable_fake_quant(module: Module): | |||||
def fn(mod: Module): | def fn(mod: Module): | ||||
if isinstance(mod, QATModule): | if isinstance(mod, QATModule): | ||||
mod.act_fake_quant.enable() | |||||
mod.weight_fake_quant.enable() | |||||
mod.set_fake_quant(True) | |||||
module.apply(fn) | module.apply(fn) | ||||
@@ -188,7 +185,6 @@ def enable_observer(module: Module): | |||||
def fn(mod: Module): | def fn(mod: Module): | ||||
if isinstance(mod, QATModule): | if isinstance(mod, QATModule): | ||||
mod.act_observer.enable() | |||||
mod.weight_observer.enable() | |||||
mod.set_observer(True) | |||||
module.apply(fn) | module.apply(fn) |