|
|
@@ -70,18 +70,22 @@ class QATModule(Module): |
|
|
|
def _apply_fakequant_with_observer( |
|
|
|
self, target: Tensor, fake_quant: FakeQuantize, observer: Observer |
|
|
|
): |
|
|
|
# do observer |
|
|
|
if observer is None: |
|
|
|
return target |
|
|
|
oup = observer(target) |
|
|
|
q_dict = observer.get_qparams() |
|
|
|
q_dict = None |
|
|
|
oup = target |
|
|
|
else: |
|
|
|
q_dict = observer.get_qparams() |
|
|
|
oup = observer(target) |
|
|
|
# do fake quant |
|
|
|
if fake_quant is not None: |
|
|
|
oup = fake_quant(oup, q_dict) |
|
|
|
# use qparams of fake_quant if have. |
|
|
|
if hasattr(fake_quant, "get_qparams"): |
|
|
|
q_dict = fake_quant.get_qparams() |
|
|
|
# use qparams of fake_quant if have. |
|
|
|
if hasattr(fake_quant, "get_qparams"): |
|
|
|
q_dict = fake_quant.get_qparams() |
|
|
|
# set to tensor qparams. |
|
|
|
oup.q_dict.update(q_dict) |
|
|
|
if q_dict is not None: |
|
|
|
oup.q_dict.update(q_dict) |
|
|
|
return oup |
|
|
|
|
|
|
|
def apply_quant_weight(self, target: Tensor): |
|
|
@@ -100,42 +104,46 @@ class QATModule(Module): |
|
|
|
target, self.act_fake_quant, self.act_observer |
|
|
|
) |
|
|
|
|
|
|
|
def _get_method_result( |
|
|
|
self, method: str, fake_quant: FakeQuantize, observer: Observer |
|
|
|
): |
|
|
|
if hasattr(fake_quant, method): |
|
|
|
return getattr(fake_quant, method)() |
|
|
|
elif hasattr(observer, method): |
|
|
|
return getattr(observer, method)() |
|
|
|
return None |
|
|
|
|
|
|
|
def get_weight_dtype(self): |
|
|
|
r""" |
|
|
|
Get weight's quantization dtype as the method from ``qconfig``. |
|
|
|
""" |
|
|
|
if hasattr(self.weight_fake_quant, "get_dtype"): |
|
|
|
return self.weight_fake_quant.get_dtype() |
|
|
|
else: |
|
|
|
return self.weight_observer.get_dtype() |
|
|
|
return self._get_method_result( |
|
|
|
"get_dtype", self.weight_fake_quant, self.weight_observer |
|
|
|
) |
|
|
|
|
|
|
|
def get_activation_dtype(self): |
|
|
|
r""" |
|
|
|
Get activation's quantization dtype as the method from ``qconfig``. |
|
|
|
""" |
|
|
|
if hasattr(self.act_fake_quant, "get_dtype"): |
|
|
|
return self.act_fake_quant.get_dtype() |
|
|
|
else: |
|
|
|
return self.act_observer.get_dtype() |
|
|
|
|
|
|
|
def _get_qparams(self, fake_quant: FakeQuantize, observer: Observer): |
|
|
|
if hasattr(fake_quant, "get_qparams"): |
|
|
|
return fake_quant.get_qparams() |
|
|
|
elif observer is not None: |
|
|
|
return observer.get_qparams() |
|
|
|
return None |
|
|
|
return self._get_method_result( |
|
|
|
"get_dtype", self.act_fake_quant, self.act_observer |
|
|
|
) |
|
|
|
|
|
|
|
def get_weight_qparams(self): |
|
|
|
r""" |
|
|
|
Get weight's quantization parameters. |
|
|
|
""" |
|
|
|
return self._get_qparams(self.weight_fake_quant, self.weight_observer) |
|
|
|
return self._get_method_result( |
|
|
|
"get_qparams", self.weight_fake_quant, self.weight_observer |
|
|
|
) |
|
|
|
|
|
|
|
def get_activation_qparams(self): |
|
|
|
r""" |
|
|
|
Get activation's quantization parameters. |
|
|
|
""" |
|
|
|
return self._get_qparams(self.act_fake_quant, self.act_observer) |
|
|
|
return self._get_method_result( |
|
|
|
"get_qparams", self.act_fake_quant, self.act_observer |
|
|
|
) |
|
|
|
|
|
|
|
@classmethod |
|
|
|
@abstractmethod |
|
|
|