You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

qconfig.py 4.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. from collections import namedtuple
  2. from functools import partial
  3. from ..module import Module
  4. from .fake_quant import TQT, FakeQuantize
  5. from .observer import (
  6. ExponentialMovingAverageObserver,
  7. HistogramObserver,
  8. MinMaxObserver,
  9. PassiveObserver,
  10. SyncExponentialMovingAverageObserver,
  11. SyncMinMaxObserver,
  12. )
  13. # use namedtuple to make class immutable, comparable and easy to print
  class QConfig(
      namedtuple(
          "QConfig",
          ["weight_observer", "act_observer", "weight_fake_quant", "act_fake_quant"],
      )
  ):
      r"""A config class indicating how to quantize the ``activation`` and ``weight``
      of a :class:`~.QATModule`. See :meth:`~.QATModule.set_qconfig` for detailed usage.

      Args:
          weight_observer: interface to instantiate an :class:`~.Observer` indicating
              how to collect scales and zero_point of weight.
          act_observer: similar to ``weight_observer`` but toward activation.
          weight_fake_quant: interface to instantiate a
              :class:`~.quantization.fake_quant.FakeQuantize` indicating
              how to do fake_quant calculation.
          act_fake_quant: similar to ``weight_fake_quant`` but toward activation.

      Raises:
          ValueError: if ``weight_observer`` or ``act_observer`` is a ``Module``
              instance rather than a class / ``partial`` factory. The two
              fake_quant arguments are not checked.

      Examples:

          .. code-block::

              # Default EMA QConfig for QAT.
              ema_fakequant_qconfig = QConfig(
                  weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
                  act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8"),
                  weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
                  act_fake_quant=partial(FakeQuantize, dtype="qint8"),
              )

      Each parameter is a ``class`` rather than an instance. We recommend using
      ``functools.partial`` to bind the initialization parameters of the ``class``,
      so that they do not need to be provided in :meth:`~.QATModule.set_qconfig`.

      Usually we choose the narrow version dtype (like ``qint8_narrow``) for weight
      related parameters and the normal version for activation related ones. For the
      result of multiplication and addition as ``a * b + c * d``, if four variables
      are all -128 of dtype ``qint8``, then the result will be ``2^15`` and cause
      overflow. Weights are commonly calculated in this way, so qmin needs to be
      narrowed to -127.
      """

      # Without an empty ``__slots__`` every instance of a namedtuple subclass gets
      # its own ``__dict__``, which silently allows ``cfg.anything = value`` and
      # defeats the immutability this class is built on.
      __slots__ = ()

      def __new__(cls, weight_observer, act_observer, weight_fake_quant, act_fake_quant):
          # Observers must be passed as factories (classes or partials); an already
          # constructed Module instance is rejected up front.
          if isinstance(act_observer, Module) or isinstance(weight_observer, Module):
              raise ValueError(
                  "QConfig must not receive observer instance, please pass observer"
                  " class generator using `partial(Observer, ...)` instead. Use"
                  " partial(MyObserver, x=1) to override arguments to constructor if needed"
              )
          return super().__new__(
              cls, weight_observer, act_observer, weight_fake_quant, act_fake_quant
          )
  # Preset: MinMaxObserver + FakeQuantize on both sides.
  # Weight-side factories use dtype "qint8_narrow", activation-side use "qint8".
  min_max_fakequant_qconfig = QConfig(
      # weight side
      weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
      weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
      # activation side
      act_observer=partial(MinMaxObserver, dtype="qint8"),
      act_fake_quant=partial(FakeQuantize, dtype="qint8"),
  )
  # Preset: MinMaxObserver for weights, ExponentialMovingAverageObserver for
  # activations, FakeQuantize on both sides ("qint8_narrow" / "qint8").
  ema_fakequant_qconfig = QConfig(
      # weight side
      weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
      weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
      # activation side
      act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8"),
      act_fake_quant=partial(FakeQuantize, dtype="qint8"),
  )
  # Preset: same layout as the EMA preset but built from the Sync* observer
  # classes (SyncMinMaxObserver / SyncExponentialMovingAverageObserver).
  sync_ema_fakequant_qconfig = QConfig(
      # weight side
      weight_observer=partial(SyncMinMaxObserver, dtype="qint8_narrow"),
      weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
      # activation side
      act_observer=partial(SyncExponentialMovingAverageObserver, dtype="qint8"),
      act_fake_quant=partial(FakeQuantize, dtype="qint8"),
  )
  # Preset: low-bit variant of the EMA preset; all four factories use dtype "qint4".
  ema_lowbit_fakequant_qconfig = QConfig(
      # weight side
      weight_observer=partial(MinMaxObserver, dtype="qint4"),
      weight_fake_quant=partial(FakeQuantize, dtype="qint4"),
      # activation side
      act_observer=partial(ExponentialMovingAverageObserver, dtype="qint4"),
      act_fake_quant=partial(FakeQuantize, dtype="qint4"),
  )
  # Preset: observers only. Both fake-quant slots are None; weights use
  # MinMaxObserver ("qint8_narrow"), activations use HistogramObserver ("qint8").
  calibration_qconfig = QConfig(
      # weight side
      weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
      weight_fake_quant=None,
      # activation side
      act_observer=partial(HistogramObserver, dtype="qint8"),
      act_fake_quant=None,
  )
  # Preset: fake-quant only. Both observer slots are None; TQT fills the
  # fake-quant slots ("qint8_narrow" for weights, "qint8" for activations).
  tqt_qconfig = QConfig(
      # weight side
      weight_observer=None,
      weight_fake_quant=partial(TQT, dtype="qint8_narrow"),
      # activation side
      act_observer=None,
      act_fake_quant=partial(TQT, dtype="qint8"),
  )
  # Preset: PassiveObserver + FakeQuantize on both sides
  # ("qint8_narrow" for weights, "qint8" for activations).
  passive_qconfig = QConfig(
      # weight side
      weight_observer=partial(PassiveObserver, dtype="qint8_narrow"),
      weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
      # activation side
      act_observer=partial(PassiveObserver, dtype="qint8"),
      act_fake_quant=partial(FakeQuantize, dtype="qint8"),
  )

  # Alias: a second name bound to the very same QConfig object as passive_qconfig.
  easyquant_qconfig = passive_qconfig