|
|
@@ -11,6 +11,7 @@ from abc import abstractmethod |
|
|
|
from ...quantization.fake_quant import FakeQuantize |
|
|
|
from ...quantization.observer import Observer |
|
|
|
from ...quantization.qconfig import QConfig |
|
|
|
from ...quantization.utils import fake_quant_bias |
|
|
|
from ...tensor import Tensor |
|
|
|
from ..module import Module |
|
|
|
|
|
|
@@ -107,6 +108,24 @@ class QATModule(Module): |
|
|
|
target, self.act_fake_quant, self.act_observer |
|
|
|
) |
|
|
|
|
|
|
|
def apply_quant_bias(self, target: Tensor, inp: Tensor, w_qat: Tensor): |
|
|
|
r""" |
|
|
|
Use :func:`~.fake_quant_bias` to process ``target``. Only valid when |
|
|
|
``act_fake_quant`` and ``weight_fake_quant`` are both enabled. |
|
|
|
""" |
|
|
|
# bias should have the same dtype as activation, so act_fake_quant can also |
|
|
|
# decide whether to do bias fakequant |
|
|
|
if ( |
|
|
|
self.act_fake_quant |
|
|
|
and self.act_fake_quant.enabled |
|
|
|
and self.weight_fake_quant |
|
|
|
and self.weight_fake_quant.enabled |
|
|
|
): |
|
|
|
b_qat = fake_quant_bias(target, inp, w_qat) |
|
|
|
else: |
|
|
|
b_qat = target |
|
|
|
return b_qat |
|
|
|
|
|
|
|
def _get_method_result( |
|
|
|
self, method: str, fake_quant: FakeQuantize, observer: Observer |
|
|
|
): |
|
|
|