
fake_quant.py 5.9 kB

import math
from typing import Union

from .. import functional as F
from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
from ..logger import get_logger
from ..module import Module
from ..tensor import Parameter, Tensor
from .utils import (
    LSQParams,
    QParams,
    QParamsModuleMixin,
    QuantMode,
    create_qparams,
    fake_quant_tensor,
    lsq_forward,
    tqt_forward,
)

logger = get_logger(__name__)

class _FakeQuantize(Module):
    def __init__(
        self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
    ):
        super().__init__()
        if isinstance(dtype, str):
            if not dtype in _builtin_quant_dtypes:
                raise ValueError(
                    "unknown dtype: {}, only support {}".format(
                        dtype, _builtin_quant_dtypes.keys()
                    )
                )
            dtype = _builtin_quant_dtypes[dtype]
        if "narrow_range" in kwargs:
            del kwargs["narrow_range"]
            logger.warning(
                "FakeQuantize currently has no narrow_range param "
                "so it is ignored here",
                exc_info=DeprecationWarning,
            )
        self.dtype = dtype
        self.qmin = dtype.qmin
        self.qmax = dtype.qmax
        self.enabled = enable

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False

    def fake_quant_forward(self, inp, qparams: QParams = None):
        raise NotImplementedError

    def normal_forward(self, inp, qparams: QParams = None):
        return inp

    def forward(self, inp, qparams: QParams = None):
        if self.enabled:
            return self.fake_quant_forward(inp, qparams=qparams)
        else:
            return self.normal_forward(inp, qparams=qparams)

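# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal illustration of how
# ``_FakeQuantize.forward`` dispatches to ``fake_quant_forward`` when enabled
# and falls back to ``normal_forward`` when disabled. The subclass and helper
# names below are invented for demonstration, and the "qint8" key is assumed
# to be present in ``_builtin_quant_dtypes``.
def _fake_quantize_dispatch_sketch():
    import megengine as mge  # assumes this module lives inside the megengine package

    class _IdentityFakeQuantize(_FakeQuantize):
        # stand-in fake-quant op; a real subclass would quantize/dequantize here
        def fake_quant_forward(self, inp, qparams: QParams = None):
            return inp * 1.0

    fq = _IdentityFakeQuantize(dtype="qint8", enable=True)
    x = mge.tensor([0.1, 0.2, 0.3])
    y_fake = fq(x)   # enabled: routed to fake_quant_forward
    fq.disable()
    y_plain = fq(x)  # disabled: normal_forward returns the input unchanged
    return y_fake, y_plain
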
class TQT(_FakeQuantize, QParamsModuleMixin):
    r"""TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds
    for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks.

    Args:
        dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
            quantization dtype of input.
        enable: whether to do ``normal_forward`` or ``fake_quant_forward``.
    """

    def __init__(
        self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
    ):
        super().__init__(dtype, enable, **kwargs)
        self.scale = Parameter(0.0, dtype="float32")

    def fake_quant_forward(self, inp, qparams: QParams = None):
        # when enabled, TQT does the fake-quant forward and finetunes the scale
        return tqt_forward(self.qmin, self.qmax, inp, self.scale)

    def set_qparams(self, qparams: QParams):
        assert (
            qparams.mode == QuantMode.SYMMERTIC
        ), "only symmetric quantization is supported by TQT"
        if qparams.scale is None:
            raise AssertionError("Can not get an initialized scale")
        # the threshold is stored as log2(scale) so that 2 ** scale recovers it
        self.scale[...] = F.log(qparams.scale) / math.log(2)

    def get_qparams(self):
        return create_qparams(QuantMode.SYMMERTIC, self.dtype, scale=2 ** self.scale)

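# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): initializing TQT from an
# observer-style symmetric scale and running a fake-quant forward pass. The
# helper name and the concrete numbers are illustrative; ``create_qparams`` is
# assumed to accept (mode, dtype_meta, scale=...) as used by ``get_qparams`` above.
def _tqt_usage_sketch():
    import megengine as mge

    tqt = TQT(dtype="qint8")
    # seed the trainable threshold; internally stored as log2(scale)
    tqt.set_qparams(
        create_qparams(QuantMode.SYMMERTIC, tqt.dtype, scale=mge.tensor(0.05))
    )
    x = mge.tensor([[-0.3, 0.0, 0.4]])
    y = tqt(x)  # fake-quantized output; gradients also reach tqt.scale
    return y, tqt.get_qparams()
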
class FakeQuantize(_FakeQuantize):
    r"""A module to do quant and dequant according to the observer's scale and zero_point.

    Args:
        dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
            quantization dtype of input.
        enable: whether to do ``normal_forward`` or ``fake_quant_forward``.
    """

    def fake_quant_forward(self, inp, qparams: QParams = None):
        assert (
            qparams.dtype_meta is self.dtype
        ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
            qparams.dtype_meta, self.dtype
        )
        return fake_quant_tensor(inp, qparams)

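# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): FakeQuantize only checks that
# the incoming qparams target the same dtype and then delegates to
# ``fake_quant_tensor``. The qparams would normally come from an observer; the
# helper name and values below are made up for illustration.
def _fake_quantize_usage_sketch():
    import megengine as mge

    fq = FakeQuantize(dtype="qint8")
    qparams = create_qparams(QuantMode.SYMMERTIC, fq.dtype, scale=mge.tensor(0.05))
    x = mge.tensor([[-0.3, 0.0, 0.4]])
    return fq(x, qparams=qparams)  # quant + dequant with the observer's scale
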
class LSQ(_FakeQuantize, QParamsModuleMixin):
    r"""LSQ: https://arxiv.org/pdf/1902.08153.pdf Estimating and scaling the
    task loss gradient at each weight and activation layer's quantizer step size.

    Args:
        dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
            quantization dtype of input.
        enable: whether to do ``normal_forward`` or ``fake_quant_forward``.
        eps: a small value to avoid division by zero. Default: 1e-5
    """

    def __init__(
        self,
        dtype: Union[str, QuantDtypeMeta],
        enable: bool = True,
        eps: float = 1e-5,
        **kwargs
    ):
        super().__init__(dtype=dtype, enable=enable, **kwargs)
        self.eps = Tensor(eps, dtype="float32")
        self.step_size = Parameter(1.0, dtype="float32")
        self.mode = None
        self.zero_point = Tensor(0.0, dtype="float32")
        self.grad_scale = Tensor(1.0, dtype="float32")

    def set_qparams(self, qparams: LSQParams):
        self.mode = qparams.mode
        if qparams.mode == QuantMode.ASYMMERTIC:
            self.zero_point = qparams.zero_point
        else:
            self.zero_point = Tensor([0.0], dtype="float32")
        if qparams.scale is None:
            raise AssertionError("Can not get an initialized scale")
        # subtract eps so that the effective step size (|step_size| + eps) matches the observer scale
        init_step_size = qparams.scale
        if init_step_size < self.eps:
            init_step_size = 0
        else:
            init_step_size = init_step_size - self.eps
        self.step_size = Parameter(init_step_size, dtype="float32")
        self.grad_scale = qparams.grad_scale

    def fake_quant_forward(self, inp, qparams: LSQParams = None):
        step_size = F.abs(self.step_size) + self.eps
        return lsq_forward(
            self.qmin, self.qmax, inp, step_size, self.zero_point, self.grad_scale
        )

    def get_qparams(self):
        return LSQParams(
            mode=self.mode,
            dtype_meta=self.dtype,
            scale=F.abs(self.step_size.detach()) + self.eps,
            zero_point=self.zero_point,
            grad_scale=self.grad_scale,
        )

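# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): LSQ is initialized from an
# LSQParams bundle (the keyword names mirror ``get_qparams`` above) and then
# learns the step size during training. The helper name and the grad_scale
# value here are placeholders; in practice grad_scale is typically derived
# from the tensor size and qmax.
def _lsq_usage_sketch():
    import megengine as mge

    lsq = LSQ(dtype="qint8")
    lsq.set_qparams(
        LSQParams(
            mode=QuantMode.SYMMERTIC,
            dtype_meta=lsq.dtype,
            scale=mge.tensor(0.05),
            zero_point=None,
            grad_scale=mge.tensor(1.0),
        )
    )
    x = mge.tensor([[-0.3, 0.0, 0.4]])
    y = lsq(x)  # fake-quantized with |step_size| + eps; step_size receives gradients
    return y, lsq.get_qparams()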