You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

qconfig.py 6.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from functools import partial
  9. from ..module import Module
  10. from .fake_quant import TQT, FakeQuantize
  11. from .observer import (
  12. ExponentialMovingAverageObserver,
  13. HistogramObserver,
  14. MinMaxObserver,
  15. PassiveObserver,
  16. SyncExponentialMovingAverageObserver,
  17. SyncMinMaxObserver,
  18. )
  19. class QConfig:
  20. r"""
  21. A config class indicating how to do quantize toward :class:`~.QATModule`'s
  22. ``activation`` and ``weight``. See :meth:`~.QATModule.set_qconfig` for detail usage.
  23. :param weight_observer: interface to instantiate an :class:`~.Observer` indicating
  24. how to collect scales and zero_point of wegiht.
  25. :param act_observer: similar to ``weight_observer`` but toward activation.
  26. :param weight_fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
  27. how to do fake_quant calculation.
  28. :param act_observer: similar to ``weight_fake_quant`` but toward activation.
  29. Examples:
  30. .. code-block::
  31. # Default EMA QConfig for QAT.
  32. ema_fakequant_qconfig = QConfig(
  33. weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
  34. act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False),
  35. weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
  36. act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
  37. )
  38. Each parameter is a ``class`` rather than an instance. And we recommand using ``functools.partial``
  39. to add initialization parameters of the ``class``, so that don't need to provide parameters in
  40. :meth:`~.QATModule.set_qconfig`.
  41. Usually we set ``narrow_range`` of weight related paramters to ``True`` and of activation related
  42. parameters to ``False``. For the result of multiplication and addition as ``a * b + c * d``, if
  43. four variables are all -128 of dtype ``qint8``, then the result will be ``2^15`` and cause overflow.
  44. Weights are commonly calculated in this way, so needed to narrow the range.
  45. """
  46. def __init__(
  47. self, weight_observer, act_observer, weight_fake_quant, act_fake_quant
  48. ):
  49. if isinstance(act_observer, Module) or isinstance(weight_observer, Module):
  50. raise ValueError(
  51. "QConfig must not receive observer instance, please pass observer"
  52. " class generator using `partial(Observer, ...)` instead. Use"
  53. " partial(MyObserver, x=1) to override arguments to constructor if needed"
  54. )
  55. self.weight_observer = weight_observer
  56. self.act_observer = act_observer
  57. self.weight_fake_quant = weight_fake_quant
  58. self.act_fake_quant = act_fake_quant
  59. def __eq__(self, other):
  60. def eq(a, b):
  61. if isinstance(a, partial) and isinstance(b, partial):
  62. return all(
  63. [a.func == b.func, a.args == b.args, a.keywords == b.keywords]
  64. )
  65. else:
  66. return a == b
  67. return (
  68. eq(self.weight_observer, other.weight_observer)
  69. and eq(self.act_observer, other.act_observer)
  70. and eq(self.weight_fake_quant, other.weight_fake_quant)
  71. and eq(self.act_fake_quant, other.act_fake_quant)
  72. )
  73. min_max_fakequant_qconfig = QConfig(
  74. weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
  75. act_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=False),
  76. weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
  77. act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
  78. )
  79. ema_fakequant_qconfig = QConfig(
  80. weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
  81. act_observer=partial(
  82. ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False
  83. ),
  84. weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
  85. act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
  86. )
  87. sync_ema_fakequant_qconfig = QConfig(
  88. weight_observer=partial(SyncMinMaxObserver, dtype="qint8", narrow_range=True),
  89. act_observer=partial(
  90. SyncExponentialMovingAverageObserver, dtype="qint8", narrow_range=False
  91. ),
  92. weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
  93. act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
  94. )
  95. ema_lowbit_fakequant_qconfig = QConfig(
  96. weight_observer=partial(MinMaxObserver, dtype="qint4", narrow_range=False),
  97. act_observer=partial(
  98. ExponentialMovingAverageObserver, dtype="qint4", narrow_range=False
  99. ),
  100. weight_fake_quant=partial(FakeQuantize, dtype="qint4", narrow_range=False),
  101. act_fake_quant=partial(FakeQuantize, dtype="qint4", narrow_range=False),
  102. )
  103. calibration_qconfig = QConfig(
  104. weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
  105. act_observer=partial(HistogramObserver, dtype="qint8", narrow_range=False),
  106. weight_fake_quant=None,
  107. act_fake_quant=None,
  108. )
  109. tqt_qconfig = QConfig(
  110. weight_observer=None,
  111. act_observer=None,
  112. weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
  113. act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False),
  114. )
  115. passive_qconfig = QConfig(
  116. weight_observer=partial(PassiveObserver, dtype="qint8", narrow_range=True),
  117. act_observer=partial(PassiveObserver, dtype="qint8", narrow_range=False),
  118. weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
  119. act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
  120. )
  121. easyquant_qconfig = passive_qconfig

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台