
utils.py

import abc
from enum import Enum
from functools import partial, update_wrapper, wraps
from typing import Union

import numpy as np

from .. import functional as F
from ..autodiff import Function
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.tensor.dtype import (
    QuantDtypeMeta,
    _builtin_quant_dtypes,
    create_quantized_dtype,
)
from ..tensor import Tensor


class Round(Function):
    r"""The functional round has no gradient and therefore cannot be used
    directly for quantization-aware training. We use :class:`Function` and
    STE (Straight-Through Estimator) to implement the backward pass.
    """

    def forward(self, x):
        return F.round(x)

    def backward(self, output_grads):
        return output_grads
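

# Usage sketch (illustrative, not part of the original file): rounding happens
# in the forward pass while the gradient passes through unchanged (STE):
#
#     x = Tensor([0.4, 1.6], dtype=np.float32)
#     y = Round()(x)  # forward: [0.0, 2.0]; backward: identity gradient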


def tqt_forward(qmin, qmax, inp, scale):
    op = builtin.TQT(qmin=qmin, qmax=qmax)
    (output,) = apply(op, inp, scale)
    return output
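

# Usage sketch (illustrative): fake-quantize through the builtin TQT op with
# int8-like bounds. The exact interpretation of ``scale`` follows the builtin
# op (TQT trains a log-domain threshold), so treat it as an assumption here:
#
#     inp = Tensor(np.random.randn(4).astype(np.float32))
#     scale = Tensor([0.0], dtype=np.float32)
#     out = tqt_forward(-128, 127, inp, scale)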


def lsq_forward(qmin, qmax, inp, step_size, zero_point=None, scale_grad=None):
    if zero_point is None:
        zero_point = Tensor([0.0], dtype=np.float32)
    if scale_grad is None:
        scale_grad = Tensor([1.0], dtype=np.float32)
    op = builtin.LSQ(qmin=qmin, qmax=qmax)
    (output,) = apply(op, inp, step_size, zero_point, scale_grad)
    return output
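

# Usage sketch (illustrative): LSQ-style fake quantization with a learnable
# step size; ``zero_point`` and ``scale_grad`` fall back to 0.0 and 1.0:
#
#     inp = Tensor(np.random.randn(4).astype(np.float32))
#     step_size = Tensor([0.1], dtype=np.float32)
#     out = lsq_forward(-128, 127, inp, step_size)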


def register_method_to_class(cls):
    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            return func(self, *args, **kwargs)

        if isinstance(func, partial):
            update_wrapper(func, func.func)
        setattr(cls, func.__name__, wrapper)
        return func

    return decorator
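

# Usage sketch (illustrative, ``Foo`` and ``greet`` are hypothetical names):
# attach a free function to an existing class as a bound method.
#
#     class Foo:
#         pass
#
#     @register_method_to_class(Foo)
#     def greet(self):
#         return "hello"
#
#     Foo().greet()  # -> "hello"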


class QuantMode(Enum):
    r"""Quantization mode enumerate class."""

    SYMMERTIC = 1
    ASYMMERTIC = 2


class QParams:
    r"""To standardize the qparams format of FakeQuant, Observer and Tensor.
    If custom qparams are needed, inherit this class and add custom
    ``__slots__``.
    """

    __slots__ = "mode", "dtype_meta", "scale", "zero_point"

    def __init__(
        self,
        mode: QuantMode,
        dtype_meta: QuantDtypeMeta,
        scale: Tensor,
        zero_point: Tensor,
    ):
        self.mode = mode
        self.dtype_meta = dtype_meta
        self.scale = scale
        self.zero_point = zero_point

    def update(self, qparams: "QParams"):
        for key in self.__slots__:
            setattr(self, key, getattr(qparams, key))

    def __eq__(self, other):
        if len(self.__slots__) != len(other.__slots__):
            return False
        for key in self.__slots__:
            if not hasattr(other, key) or getattr(self, key) != getattr(other, key):
                return False
        return True

    def __repr__(self):
        content = ", ".join(
            ["{}={}".format(key, getattr(self, key)) for key in self.__slots__]
        )
        return "QParams({})".format(content)


class LSQParams(QParams):
    r"""LSQ qparams with an extra ``grad_scale`` slot."""

    __slots__ = "mode", "dtype_meta", "scale", "zero_point", "grad_scale"

    def __init__(
        self,
        mode: QuantMode,
        dtype_meta: QuantDtypeMeta,
        scale: Tensor,
        zero_point: Tensor,
        grad_scale: Tensor,
    ):
        super().__init__(mode, dtype_meta, scale, zero_point)
        self.grad_scale = grad_scale


class QParamsModuleMixin(abc.ABC):
    def get_quantized_dtype(self):
        qparams = self.get_qparams()
        dtype = qparams.dtype_meta
        scale = float(qparams.scale.numpy()) if qparams.scale is not None else None
        zero_point = (
            int(qparams.zero_point.numpy()) if qparams.zero_point is not None else None
        )
        return create_quantized_dtype(dtype, scale, zero_point)

    @abc.abstractmethod
    def get_qparams(self) -> QParams:
        pass
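

# Usage sketch (illustrative, ``MyObserver`` is a hypothetical subclass): a
# class mixing in QParamsModuleMixin only needs to implement ``get_qparams``:
#
#     class MyObserver(QParamsModuleMixin):
#         def get_qparams(self) -> QParams:
#             return create_qparams(
#                 QuantMode.SYMMERTIC,
#                 "qint8",
#                 scale=Tensor([0.01], dtype=np.float32),
#             )
#
#     MyObserver().get_quantized_dtype()  # a qint8 dtype with scale 0.01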


_builtin_qparams = {
    QuantMode.SYMMERTIC: partial(QParams, mode=QuantMode.SYMMERTIC),
    QuantMode.ASYMMERTIC: partial(QParams, mode=QuantMode.ASYMMERTIC),
}


def create_qparams(
    mode: QuantMode = QuantMode.SYMMERTIC,
    dtype_meta: Union[str, QuantDtypeMeta] = None,
    scale: Tensor = None,
    zero_point: Tensor = None,
):
    r"""
    Args:
        mode: the quantization mode, a :class:`QuantMode` member.
        dtype_meta: a :class:`QuantDtypeMeta` instance, or the name of a
            builtin quantized dtype such as ``"qint8"``.
        scale: the scale tensor.
        zero_point: the zero-point tensor, used in asymmetric mode.
    """
    if isinstance(dtype_meta, str):
        dtype_meta = _builtin_quant_dtypes[dtype_meta]
    if mode is None:
        return QParams(mode, dtype_meta, scale, zero_point)
    assert isinstance(mode, QuantMode)
    return _builtin_qparams[mode](
        dtype_meta=dtype_meta, scale=scale, zero_point=zero_point
    )
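

# Usage sketch (illustrative): build symmetric qint8 qparams from a builtin
# dtype name:
#
#     qparams = create_qparams(
#         QuantMode.SYMMERTIC, "qint8", scale=Tensor([0.01], dtype=np.float32)
#     )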


def fake_quant_tensor(inp: Tensor, qparams: QParams) -> Tensor:
    """Apply fake quantization to the inp tensor.

    Args:
        inp: the input tensor to be fake-quantized.
        qparams: provides the mode, qmin, qmax, scale and zero_point.
    """
    scale = qparams.scale
    if qparams.mode == QuantMode.ASYMMERTIC:
        zero_point = qparams.zero_point
    else:
        zero_point = Tensor([0.0], dtype=np.float32)
    qmin = qparams.dtype_meta.qmin
    qmax = qparams.dtype_meta.qmax
    op = builtin.FakeQuant(qmin=qmin, qmax=qmax)
    return apply(op, inp, scale, zero_point)[0]
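

# Usage sketch (illustrative): with the symmetric qint8 qparams built above,
# the result stays float32 but is snapped onto the quantization grid:
#
#     inp = Tensor(np.random.randn(4).astype(np.float32))
#     out = fake_quant_tensor(inp, qparams)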


def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
    """Apply fake quantization to bias, using a scale derived from the input
    tensor and the weight tensor; the quantized dtype is set to qint32.

    Args:
        bias: the bias tensor to be fake-quantized.
        inp: the input tensor which contains the quantization parameters.
        w_qat: the weight tensor which contains the quantization parameters.

    Warning:
        Only works for the symmetric quantization method now.
    """
    b_qat = bias
    if (
        getattr(inp, "qparams", None) is not None
        and getattr(w_qat, "qparams", None) is not None
        and bias is not None
    ):
        inp_params = inp.qparams
        w_params = w_qat.qparams
        if inp_params.scale is not None and w_params.scale is not None:
            assert inp_params.mode == w_params.mode, "incompatible QuantMode"
            # TODO: support quint8 dtype.
            assert (
                inp_params.dtype_meta.np_dtype_str == "int8"
                and w_params.dtype_meta.np_dtype_str == "int8"
            ), "fake_quant_bias only supports int8-like dtypes now"
            # use the same mode as the weight.
            # TODO: avoid hardcode
            b_dtype = _builtin_quant_dtypes["qint32"]
            b_param = create_qparams(
                w_params.mode, b_dtype, scale=inp_params.scale * w_params.scale
            )
            b_qat = fake_quant_tensor(bias, b_param)
            b_qat.qparams.update(b_param)
    return b_qat
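

# Usage sketch (illustrative): assuming ``inp`` and ``w_qat`` already carry
# symmetric qint8 ``qparams`` with scales set, the bias is fake-quantized to
# qint32 with scale = inp_scale * w_scale; otherwise it is returned unchanged:
#
#     b_qat = fake_quant_bias(bias, inp, w_qat)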