You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

module.py 5.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. from abc import abstractmethod
  2. # avoid circular reference
  3. from ...quantization.fake_quant import FakeQuantize
  4. from ...quantization.observer import Observer
  5. from ...quantization.qconfig import QConfig
  6. from ...quantization.utils import fake_quant_bias
  7. from ...tensor import Tensor
  8. from ..module import Module
  9. class QATModule(Module):
  10. r"""Base class of quantized-float related :class:`~.Module`, basically for QAT and Calibration.
  11. Use :meth:`from_float_module` to generate a instance from float :class:`~.Module`.
  12. Or use :func:`~.quantize.quantize_qat` to do it recursively and automatically.
  13. Can also be converted to :class:`~.QuantizedModule` for deployment using
  14. :func:`~.quantize.quantize` further.
  15. """
  16. with_weight = True
  17. with_act = True
  18. def __init__(self, **kwargs):
  19. super().__init__(**kwargs)
  20. self.weight_observer = None # type: Observer
  21. self.act_observer = None # type: Observer
  22. self.weight_fake_quant = None # type: FakeQuantize
  23. self.act_fake_quant = None # type: FakeQuantize
  24. def __repr__(self):
  25. return "QAT." + super().__repr__()
  26. def set_qconfig(self, qconfig: QConfig):
  27. r"""Set quantization related configs with ``qconfig``, including
  28. observer and fake_quant for weight and activation.
  29. """
  30. def safe_call(func):
  31. return func() if func is not None else None
  32. if self.with_act:
  33. self.act_observer = safe_call(qconfig.act_observer)
  34. self.act_fake_quant = safe_call(qconfig.act_fake_quant)
  35. if self.with_weight:
  36. self.weight_observer = safe_call(qconfig.weight_observer)
  37. self.weight_fake_quant = safe_call(qconfig.weight_fake_quant)
  38. def _enable_exec(self, with_module, func, enable):
  39. if not with_module or not func:
  40. return
  41. if enable:
  42. func.enable()
  43. else:
  44. func.disable()
  45. def set_fake_quant(self, enable):
  46. self._enable_exec(self.with_act, self.act_fake_quant, enable)
  47. self._enable_exec(self.with_weight, self.weight_fake_quant, enable)
  48. def set_observer(self, enable):
  49. self._enable_exec(self.with_act, self.act_observer, enable)
  50. self._enable_exec(self.with_weight, self.weight_observer, enable)
  51. def _apply_fakequant_with_observer(
  52. self, target: Tensor, fake_quant: FakeQuantize, observer: Observer
  53. ):
  54. # do observer
  55. if observer is None:
  56. oup = target
  57. qparams = None
  58. else:
  59. oup = observer(target)
  60. qparams = observer.get_qparams()
  61. # do fake quant
  62. if fake_quant is not None:
  63. oup = fake_quant(oup, qparams)
  64. # use qparams of fake_quant if have.
  65. if hasattr(fake_quant, "get_qparams"):
  66. qparams = fake_quant.get_qparams()
  67. # set to tensor qparams.
  68. if qparams is not None:
  69. oup.qparams.update(qparams)
  70. return oup
  71. def apply_quant_weight(self, target: Tensor):
  72. r"""Apply weight's observer and fake_quant from ``qconfig`` on ``target``."""
  73. return self._apply_fakequant_with_observer(
  74. target, self.weight_fake_quant, self.weight_observer
  75. )
  76. def apply_quant_activation(self, target: Tensor):
  77. r"""Apply weight's observer and fake_quant from ``qconfig`` on ``target``."""
  78. return self._apply_fakequant_with_observer(
  79. target, self.act_fake_quant, self.act_observer
  80. )
  81. def apply_quant_bias(self, target: Tensor, inp: Tensor, w_qat: Tensor):
  82. r"""Use :func:`~.fake_quant_bias` to process ``target``. Only valid when
  83. ``act_fake_quant`` and ``weight_fake_quant`` are both enabled.
  84. """
  85. # bias should have the same dtype as activation, so act_fake_quant can also
  86. # decide whether to do bias fakequant
  87. if (
  88. self.act_fake_quant
  89. and self.act_fake_quant.enabled
  90. and self.weight_fake_quant
  91. and self.weight_fake_quant.enabled
  92. ):
  93. b_qat = fake_quant_bias(target, inp, w_qat)
  94. else:
  95. b_qat = target
  96. return b_qat
  97. def _get_method_result(
  98. self, method: str, fake_quant: FakeQuantize, observer: Observer
  99. ):
  100. if hasattr(fake_quant, method):
  101. return getattr(fake_quant, method)()
  102. elif hasattr(observer, method):
  103. return getattr(observer, method)()
  104. return None
  105. def get_weight_dtype(self):
  106. r"""Get weight's quantization dtype as the method from ``qconfig``."""
  107. return self._get_method_result(
  108. "get_quantized_dtype", self.weight_fake_quant, self.weight_observer
  109. )
  110. def get_activation_dtype(self):
  111. r"""Get activation's quantization dtype as the method from ``qconfig``."""
  112. return self._get_method_result(
  113. "get_quantized_dtype", self.act_fake_quant, self.act_observer
  114. )
  115. def get_weight_qparams(self):
  116. r"""Get weight's quantization parameters."""
  117. return self._get_method_result(
  118. "get_qparams", self.weight_fake_quant, self.weight_observer
  119. )
  120. def get_activation_qparams(self):
  121. r"""Get activation's quantization parameters."""
  122. return self._get_method_result(
  123. "get_qparams", self.act_fake_quant, self.act_observer
  124. )
  125. @classmethod
  126. @abstractmethod
  127. def from_float_module(cls, float_module: Module):
  128. r"""Return a :class:`~.QATModule` instance converted from
  129. a float :class:`~.Module` instance.
  130. """