Browse Source

fix(mge/quantization): handle empty Observer in QATModule

GitOrigin-RevId: e8a62297bc
tags/v1.0.0-rc1
Megvii Engine Team Xinran Xu 4 years ago
parent
commit
b43f6a2602
1 changed files with 32 additions and 24 deletions
  1. +32
    -24
      python_module/megengine/module/qat/module.py

+ 32
- 24
python_module/megengine/module/qat/module.py View File

@@ -70,18 +70,22 @@ class QATModule(Module):
def _apply_fakequant_with_observer(
self, target: Tensor, fake_quant: FakeQuantize, observer: Observer
):
# do observer
if observer is None:
return target
oup = observer(target)
q_dict = observer.get_qparams()
q_dict = None
oup = target
else:
q_dict = observer.get_qparams()
oup = observer(target)
# do fake quant
if fake_quant is not None:
oup = fake_quant(oup, q_dict)
# use qparams of fake_quant if have.
if hasattr(fake_quant, "get_qparams"):
q_dict = fake_quant.get_qparams()
# use qparams of fake_quant if have.
if hasattr(fake_quant, "get_qparams"):
q_dict = fake_quant.get_qparams()
# set to tensor qparams.
oup.q_dict.update(q_dict)
if q_dict is not None:
oup.q_dict.update(q_dict)
return oup

def apply_quant_weight(self, target: Tensor):
@@ -100,42 +104,46 @@ class QATModule(Module):
target, self.act_fake_quant, self.act_observer
)

def _get_method_result(
self, method: str, fake_quant: FakeQuantize, observer: Observer
):
if hasattr(fake_quant, method):
return getattr(fake_quant, method)()
elif hasattr(observer, method):
return getattr(observer, method)()
return None

def get_weight_dtype(self):
r"""
Get weight's quantization dtype as the method from ``qconfig``.
"""
if hasattr(self.weight_fake_quant, "get_dtype"):
return self.weight_fake_quant.get_dtype()
else:
return self.weight_observer.get_dtype()
return self._get_method_result(
"get_dtype", self.weight_fake_quant, self.weight_observer
)

def get_activation_dtype(self):
r"""
Get activation's quantization dtype as the method from ``qconfig``.
"""
if hasattr(self.act_fake_quant, "get_dtype"):
return self.act_fake_quant.get_dtype()
else:
return self.act_observer.get_dtype()

def _get_qparams(self, fake_quant: FakeQuantize, observer: Observer):
if hasattr(fake_quant, "get_qparams"):
return fake_quant.get_qparams()
elif observer is not None:
return observer.get_qparams()
return None
return self._get_method_result(
"get_dtype", self.act_fake_quant, self.act_observer
)

def get_weight_qparams(self):
r"""
Get weight's quantization parameters.
"""
return self._get_qparams(self.weight_fake_quant, self.weight_observer)
return self._get_method_result(
"get_qparams", self.weight_fake_quant, self.weight_observer
)

def get_activation_qparams(self):
r"""
Get activation's quantization parameters.
"""
return self._get_qparams(self.act_fake_quant, self.act_observer)
return self._get_method_result(
"get_qparams", self.act_fake_quant, self.act_observer
)

@classmethod
@abstractmethod


Loading…
Cancel
Save