|
|
@@ -32,6 +32,9 @@ class QATModule(Module): |
|
|
|
self.weight_fake_quant = None # type: FakeQuantize |
|
|
|
self.act_fake_quant = None # type: FakeQuantize |
|
|
|
|
|
|
|
self.with_weight = True |
|
|
|
self.with_act = True |
|
|
|
|
|
|
|
def set_qconfig(self, qconfig: QConfig): |
|
|
|
r""" |
|
|
|
Set quantization related configs with ``qconfig``, including |
|
|
@@ -41,10 +44,36 @@ class QATModule(Module): |
|
|
|
def safe_call(func): |
|
|
|
return func() if func is not None else None |
|
|
|
|
|
|
|
self.weight_observer = safe_call(qconfig.weight_observer) |
|
|
|
self.act_observer = safe_call(qconfig.act_observer) |
|
|
|
self.weight_fake_quant = safe_call(qconfig.weight_fake_quant) |
|
|
|
self.act_fake_quant = safe_call(qconfig.act_fake_quant) |
|
|
|
if self.with_act: |
|
|
|
self.act_observer = safe_call(qconfig.act_observer) |
|
|
|
self.act_fake_quant = safe_call(qconfig.act_fake_quant) |
|
|
|
if self.with_weight: |
|
|
|
self.weight_observer = safe_call(qconfig.weight_observer) |
|
|
|
self.weight_fake_quant = safe_call(qconfig.weight_fake_quant) |
|
|
|
|
|
|
|
def set_fake_quant(self, enable): |
|
|
|
if self.with_act: |
|
|
|
if enable: |
|
|
|
self.act_fake_quant.enable() |
|
|
|
else: |
|
|
|
self.act_fake_quant.disable() |
|
|
|
if self.with_weight: |
|
|
|
if enable: |
|
|
|
self.weight_fake_quant.enable() |
|
|
|
else: |
|
|
|
self.weight_fake_quant.disable() |
|
|
|
|
|
|
|
def set_observer(self, enable): |
|
|
|
if self.with_act: |
|
|
|
if enable: |
|
|
|
self.act_observer.enable() |
|
|
|
else: |
|
|
|
self.act_observer.disable() |
|
|
|
if self.with_weight: |
|
|
|
if enable: |
|
|
|
self.weight_observer.enable() |
|
|
|
else: |
|
|
|
self.weight_observer.disable() |
|
|
|
|
|
|
|
def _apply_fakequant_with_observer( |
|
|
|
self, target: Tensor, fake_quant: FakeQuantize, observer: Observer |
|
|
|