You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

qconfig.py 2.4 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from functools import partial
  9. from ..module import Module
  10. from .fake_quant import FakeQuantize
  11. from .observer import ExponentialMovingAverageObserver, MinMaxObserver
  12. class QConfig:
  13. """
  14. A config class indicating how to do quantize toward :class:`~.QATModule`'s
  15. ``activation`` and ``weight``.
  16. And ``fake_quant`` parameter to indicate
  17. See :meth:`~.QATModule.set_qconfig` for detail usage.
  18. :param weight_observer: interface to instantiate an :class:`~.Observer` indicating
  19. - how to collect scales and zero_point of wegiht.
  20. :param act_observer: similar to ``weight_observer`` but toward activation.
  21. :param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
  22. how to do fake_quant calculation. can be invoked multi times to get different
  23. instance for each target tensor, for better control on enable and disable.
  24. Examples:
  25. .. code-block::
  26. # Default EMA QConfig for QAT.
  27. ema_fakequant_qconfig = QConfig(
  28. weight_observer=MinMaxObserver,
  29. act_observer=ExponentialMovingAverageObserver,
  30. fake_quant=FakeQuantize,
  31. )
  32. """
  33. def __init__(
  34. self, act_observer, weight_observer, fake_quant,
  35. ):
  36. if isinstance(act_observer, Module) or isinstance(weight_observer, Module):
  37. raise ValueError(
  38. "QConfig must not receive observer instance, please pass observer"
  39. " class generator using `partial(Observer, ...)` instead. Use"
  40. " partial(MyObserver, x=1) to override arguments to constructor if needed"
  41. )
  42. self.act_observer = act_observer
  43. self.weight_observer = weight_observer
  44. self.fake_quant = fake_quant
  45. # Default QAT QConfigs
  46. min_max_fakequant_qconfig = QConfig(
  47. weight_observer=MinMaxObserver,
  48. act_observer=MinMaxObserver,
  49. fake_quant=FakeQuantize,
  50. )
  51. ema_fakequant_qconfig = QConfig(
  52. weight_observer=MinMaxObserver,
  53. act_observer=ExponentialMovingAverageObserver,
  54. fake_quant=FakeQuantize,
  55. )

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。