You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

qconfig.py 2.5 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from ..module import Module
  9. from .fake_quant import FakeQuantize
  10. from .observer import (
  11. ExponentialMovingAverageObserver,
  12. HistogramObserver,
  13. MinMaxObserver,
  14. )
  15. class QConfig:
  16. """
  17. A config class indicating how to do quantize toward :class:`~.QATModule`'s
  18. ``activation`` and ``weight``.
  19. And ``fake_quant`` parameter to indicate
  20. See :meth:`~.QATModule.set_qconfig` for detail usage.
  21. :param weight_observer: interface to instantiate an :class:`~.Observer` indicating
  22. - how to collect scales and zero_point of wegiht.
  23. :param act_observer: similar to ``weight_observer`` but toward activation.
  24. :param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
  25. how to do fake_quant calculation. can be invoked multi times to get different
  26. instance for each target tensor, for better control on enable and disable.
  27. Examples:
  28. .. code-block::
  29. # Default EMA QConfig for QAT.
  30. ema_fakequant_qconfig = QConfig(
  31. weight_observer=MinMaxObserver,
  32. act_observer=ExponentialMovingAverageObserver,
  33. fake_quant=FakeQuantize,
  34. )
  35. """
  36. def __init__(
  37. self, act_observer, weight_observer, fake_quant,
  38. ):
  39. if isinstance(act_observer, Module) or isinstance(weight_observer, Module):
  40. raise ValueError(
  41. "QConfig must not receive observer instance, please pass observer"
  42. " class generator using `partial(Observer, ...)` instead. Use"
  43. " partial(MyObserver, x=1) to override arguments to constructor if needed"
  44. )
  45. self.act_observer = act_observer
  46. self.weight_observer = weight_observer
  47. self.fake_quant = fake_quant
  48. # Default QAT QConfigs
  49. min_max_fakequant_qconfig = QConfig(
  50. weight_observer=MinMaxObserver,
  51. act_observer=MinMaxObserver,
  52. fake_quant=FakeQuantize,
  53. )
  54. ema_fakequant_qconfig = QConfig(
  55. weight_observer=MinMaxObserver,
  56. act_observer=ExponentialMovingAverageObserver,
  57. fake_quant=FakeQuantize,
  58. )
  59. calibration_qconfig = QConfig(
  60. weight_observer=MinMaxObserver, act_observer=HistogramObserver, fake_quant=None,
  61. )

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台