You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

fake_quant.py 1.9 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. from copy import deepcopy
  2. from typing import Union
  3. from ..core.tensor.dtype import QuantDtypeMeta
  4. from ..quantization.fake_quant import QParamsModuleMixin, _FakeQuantize
  5. from ..quantization.utils import QParams, QuantMode, fake_quant_tensor
  6. class FakeQuantize(_FakeQuantize, QParamsModuleMixin):
  7. r"""A module to do quant and dequant according to :attr:`~.FakeQuantize.qparams`."""
  8. def __init__(
  9. self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
  10. ):
  11. super().__init__(dtype, enable, **kwargs)
  12. self.qparams = None
  13. def fake_quant_forward(self, inp, qparams: QParams = None):
  14. if qparams is None:
  15. qparams = self.get_qparams()
  16. assert (
  17. qparams.dtype_meta is self.dtype
  18. ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
  19. qparams.dtype_meta, self.dtype
  20. )
  21. return fake_quant_tensor(inp, qparams)
  22. def get_qparams(self):
  23. return self.qparams
  24. def set_qparams(self, qparams: QParams):
  25. r"""Initialize :attr:`~.FakeQuantize.qparams`.
  26. Args:
  27. qparams: used to set initial ``scale`` and ``zero_point``.
  28. """
  29. if qparams.scale is None:
  30. raise AssertionError("Can not get an initialized scale")
  31. scale = qparams.scale
  32. if qparams.dtype_meta is None:
  33. qparams.dtype_meta = self.dtype
  34. else:
  35. assert (
  36. qparams.dtype_meta is self.dtype
  37. ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
  38. qparams.dtype_meta, self.dtype
  39. )
  40. dtype_meta = qparams.dtype_meta
  41. zero_point = qparams.zero_point
  42. mode = qparams.mode
  43. self.qparams = QParams(mode, dtype_meta, scale, zero_point)