You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

module.py 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from abc import abstractmethod
  9. # avoid circular reference
  10. from ...quantization.fake_quant import FakeQuantize
  11. from ...quantization.observer import Observer
  12. from ...quantization.qconfig import QConfig
  13. from ...quantization.utils import fake_quant_bias
  14. from ...tensor import Tensor
  15. from ..module import Module
  16. class QATModule(Module):
  17. r"""
  18. Base class of quantized-float related :class:`~.Module`, basically for QAT and Calibration.
  19. Use :meth:`from_float_module` to generate a instance from float :class:`~.Module`.
  20. Or use :func:`~.quantize.quantize_qat` to do it recursively and automatically.
  21. Can also be converted to :class:`~.QuantizedModule` for deployment using
  22. :func:`~.quantize.quantize` further.
  23. """
  24. with_weight = True
  25. with_act = True
  26. def __init__(self, **kwargs):
  27. super().__init__(**kwargs)
  28. self.weight_observer = None # type: Observer
  29. self.act_observer = None # type: Observer
  30. self.weight_fake_quant = None # type: FakeQuantize
  31. self.act_fake_quant = None # type: FakeQuantize
  32. def __repr__(self):
  33. return "QAT." + super().__repr__()
  34. def set_qconfig(self, qconfig: QConfig):
  35. r"""
  36. Set quantization related configs with ``qconfig``, including
  37. observer and fake_quant for weight and activation.
  38. """
  39. def safe_call(func):
  40. return func() if func is not None else None
  41. if self.with_act:
  42. self.act_observer = safe_call(qconfig.act_observer)
  43. self.act_fake_quant = safe_call(qconfig.act_fake_quant)
  44. if self.with_weight:
  45. self.weight_observer = safe_call(qconfig.weight_observer)
  46. self.weight_fake_quant = safe_call(qconfig.weight_fake_quant)
  47. def _enable_exec(self, with_module, func, enable):
  48. if not with_module or not func:
  49. return
  50. if enable:
  51. func.enable()
  52. else:
  53. func.disable()
  54. def set_fake_quant(self, enable):
  55. self._enable_exec(self.with_act, self.act_fake_quant, enable)
  56. self._enable_exec(self.with_weight, self.weight_fake_quant, enable)
  57. def set_observer(self, enable):
  58. self._enable_exec(self.with_act, self.act_observer, enable)
  59. self._enable_exec(self.with_weight, self.weight_observer, enable)
  60. def _apply_fakequant_with_observer(
  61. self, target: Tensor, fake_quant: FakeQuantize, observer: Observer
  62. ):
  63. # do observer
  64. if observer is None:
  65. oup = target
  66. qparams = None
  67. else:
  68. oup = observer(target)
  69. qparams = observer.get_qparams()
  70. # do fake quant
  71. if fake_quant is not None:
  72. oup = fake_quant(oup, qparams)
  73. # use qparams of fake_quant if have.
  74. if hasattr(fake_quant, "get_qparams"):
  75. qparams = fake_quant.get_qparams()
  76. # set to tensor qparams.
  77. if qparams is not None:
  78. oup.qparams.update(qparams)
  79. return oup
  80. def apply_quant_weight(self, target: Tensor):
  81. r"""
  82. Apply weight's observer and fake_quant from ``qconfig`` on ``target``.
  83. """
  84. return self._apply_fakequant_with_observer(
  85. target, self.weight_fake_quant, self.weight_observer
  86. )
  87. def apply_quant_activation(self, target: Tensor):
  88. r"""
  89. Apply weight's observer and fake_quant from ``qconfig`` on ``target``.
  90. """
  91. return self._apply_fakequant_with_observer(
  92. target, self.act_fake_quant, self.act_observer
  93. )
  94. def apply_quant_bias(self, target: Tensor, inp: Tensor, w_qat: Tensor):
  95. r"""
  96. Use :func:`~.fake_quant_bias` to process ``target``. Only valid when
  97. ``act_fake_quant`` and ``weight_fake_quant`` are both enabled.
  98. """
  99. # bias should have the same dtype as activation, so act_fake_quant can also
  100. # decide whether to do bias fakequant
  101. if (
  102. self.act_fake_quant
  103. and self.act_fake_quant.enabled
  104. and self.weight_fake_quant
  105. and self.weight_fake_quant.enabled
  106. ):
  107. b_qat = fake_quant_bias(target, inp, w_qat)
  108. else:
  109. b_qat = target
  110. return b_qat
  111. def _get_method_result(
  112. self, method: str, fake_quant: FakeQuantize, observer: Observer
  113. ):
  114. if hasattr(fake_quant, method):
  115. return getattr(fake_quant, method)()
  116. elif hasattr(observer, method):
  117. return getattr(observer, method)()
  118. return None
  119. def get_weight_dtype(self):
  120. r"""
  121. Get weight's quantization dtype as the method from ``qconfig``.
  122. """
  123. return self._get_method_result(
  124. "get_quantized_dtype", self.weight_fake_quant, self.weight_observer
  125. )
  126. def get_activation_dtype(self):
  127. r"""
  128. Get activation's quantization dtype as the method from ``qconfig``.
  129. """
  130. return self._get_method_result(
  131. "get_quantized_dtype", self.act_fake_quant, self.act_observer
  132. )
  133. def get_weight_qparams(self):
  134. r"""
  135. Get weight's quantization parameters.
  136. """
  137. return self._get_method_result(
  138. "get_qparams", self.weight_fake_quant, self.weight_observer
  139. )
  140. def get_activation_qparams(self):
  141. r"""
  142. Get activation's quantization parameters.
  143. """
  144. return self._get_method_result(
  145. "get_qparams", self.act_fake_quant, self.act_observer
  146. )
  147. @classmethod
  148. @abstractmethod
  149. def from_float_module(cls, float_module: Module):
  150. r"""
  151. Return a :class:`~.QATModule` instance converted from
  152. a float :class:`~.Module` instance.
  153. """

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。