You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

qconfig.py 5.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from collections import namedtuple
  9. from functools import partial
  10. from ..module import Module
  11. from .fake_quant import TQT, FakeQuantize
  12. from .observer import (
  13. ExponentialMovingAverageObserver,
  14. HistogramObserver,
  15. MinMaxObserver,
  16. PassiveObserver,
  17. SyncExponentialMovingAverageObserver,
  18. SyncMinMaxObserver,
  19. )
  20. # use namedtuple to make class immutable, comparable and easy to print
  21. class QConfig(
  22. namedtuple(
  23. "QConfig",
  24. ["weight_observer", "act_observer", "weight_fake_quant", "act_fake_quant"],
  25. )
  26. ):
  27. r"""A config class indicating how to do quantize toward :class:`~.QATModule` 's
  28. ``activation`` and ``weight``. See :meth:`~.QATModule.set_qconfig` for detail usage.
  29. Args:
  30. weight_observer: interface to instantiate an :class:`~.Observer` indicating
  31. how to collect scales and zero_point of wegiht.
  32. act_observer: similar to ``weight_observer`` but toward activation.
  33. weight_fake_quant: interface to instantiate a :class:`~.quantization.fake_quant.FakeQuantize` indicating
  34. how to do fake_quant calculation.
  35. act_observer: similar to ``weight_fake_quant`` but toward activation.
  36. Examples:
  37. .. code-block::
  38. # Default EMA QConfig for QAT.
  39. ema_fakequant_qconfig = QConfig(
  40. weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
  41. act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8"),
  42. weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
  43. act_fake_quant=partial(FakeQuantize, dtype="qint8"),
  44. )
  45. Each parameter is a ``class`` rather than an instance. And we recommand using ``functools.partial``
  46. to add initialization parameters of the ``class``, so that don't need to provide parameters in
  47. :meth:`~.QATModule.set_qconfig`.
  48. Usually we choose narrow version dtype (like ``qint8_narrow``) for weight related
  49. paramters and normal version for activation related ones. For the result of
  50. multiplication and addition as ``a * b + c * d``, if four variables are all -128 of
  51. dtype ``qint8``, then the result will be ``2^15`` and cause overflow.
  52. Weights are commonly calculated in this way, so need to narrow qmin to -127.
  53. """
  54. def __new__(cls, weight_observer, act_observer, weight_fake_quant, act_fake_quant):
  55. if isinstance(act_observer, Module) or isinstance(weight_observer, Module):
  56. raise ValueError(
  57. "QConfig must not receive observer instance, please pass observer"
  58. " class generator using `partial(Observer, ...)` instead. Use"
  59. " partial(MyObserver, x=1) to override arguments to constructor if needed"
  60. )
  61. return super().__new__(
  62. cls, weight_observer, act_observer, weight_fake_quant, act_fake_quant
  63. )
  64. min_max_fakequant_qconfig = QConfig(
  65. weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
  66. act_observer=partial(MinMaxObserver, dtype="qint8"),
  67. weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
  68. act_fake_quant=partial(FakeQuantize, dtype="qint8"),
  69. )
  70. ema_fakequant_qconfig = QConfig(
  71. weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
  72. act_observer=partial(ExponentialMovingAverageObserver, dtype="qint8"),
  73. weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
  74. act_fake_quant=partial(FakeQuantize, dtype="qint8"),
  75. )
  76. sync_ema_fakequant_qconfig = QConfig(
  77. weight_observer=partial(SyncMinMaxObserver, dtype="qint8_narrow"),
  78. act_observer=partial(SyncExponentialMovingAverageObserver, dtype="qint8"),
  79. weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
  80. act_fake_quant=partial(FakeQuantize, dtype="qint8"),
  81. )
  82. ema_lowbit_fakequant_qconfig = QConfig(
  83. weight_observer=partial(MinMaxObserver, dtype="qint4"),
  84. act_observer=partial(ExponentialMovingAverageObserver, dtype="qint4"),
  85. weight_fake_quant=partial(FakeQuantize, dtype="qint4"),
  86. act_fake_quant=partial(FakeQuantize, dtype="qint4"),
  87. )
  88. calibration_qconfig = QConfig(
  89. weight_observer=partial(MinMaxObserver, dtype="qint8_narrow"),
  90. act_observer=partial(HistogramObserver, dtype="qint8"),
  91. weight_fake_quant=None,
  92. act_fake_quant=None,
  93. )
  94. tqt_qconfig = QConfig(
  95. weight_observer=None,
  96. act_observer=None,
  97. weight_fake_quant=partial(TQT, dtype="qint8_narrow"),
  98. act_fake_quant=partial(TQT, dtype="qint8"),
  99. )
  100. passive_qconfig = QConfig(
  101. weight_observer=partial(PassiveObserver, dtype="qint8_narrow"),
  102. act_observer=partial(PassiveObserver, dtype="qint8"),
  103. weight_fake_quant=partial(FakeQuantize, dtype="qint8_narrow"),
  104. act_fake_quant=partial(FakeQuantize, dtype="qint8"),
  105. )
  106. easyquant_qconfig = passive_qconfig

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台