You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

qconfig.py 2.7 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from ..module import Module
  9. from .fake_quant import TQT, FakeQuantize
  10. from .observer import (
  11. ExponentialMovingAverageObserver,
  12. HistogramObserver,
  13. MinMaxObserver,
  14. )
  15. class QConfig:
  16. r"""
  17. A config class indicating how to do quantize toward :class:`~.QATModule`'s
  18. ``activation`` and ``weight``. See :meth:`~.QATModule.set_qconfig` for detail usage.
  19. :param weight_observer: interface to instantiate an :class:`~.Observer` indicating
  20. how to collect scales and zero_point of wegiht.
  21. :param act_observer: similar to ``weight_observer`` but toward activation.
  22. :param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
  23. how to do fake_quant calculation. can be invoked multi times to get different
  24. instance for each target tensor, for better control on enable and disable.
  25. Examples:
  26. .. code-block::
  27. # Default EMA QConfig for QAT.
  28. ema_fakequant_qconfig = QConfig(
  29. weight_observer=MinMaxObserver,
  30. act_observer=ExponentialMovingAverageObserver,
  31. fake_quant=FakeQuantize,
  32. )
  33. """
  34. def __init__(
  35. self, act_observer, weight_observer, fake_quant,
  36. ):
  37. if isinstance(act_observer, Module) or isinstance(weight_observer, Module):
  38. raise ValueError(
  39. "QConfig must not receive observer instance, please pass observer"
  40. " class generator using `partial(Observer, ...)` instead. Use"
  41. " partial(MyObserver, x=1) to override arguments to constructor if needed"
  42. )
  43. self.act_observer = act_observer
  44. self.weight_observer = weight_observer
  45. self.fake_quant = fake_quant
  46. tqt_quant_qconfig = QConfig(
  47. weight_observer=ExponentialMovingAverageObserver,
  48. act_observer=ExponentialMovingAverageObserver,
  49. fake_quant=TQT,
  50. )
  51. # Default QAT QConfigs
  52. min_max_fakequant_qconfig = QConfig(
  53. weight_observer=MinMaxObserver,
  54. act_observer=MinMaxObserver,
  55. fake_quant=FakeQuantize,
  56. )
  57. ema_fakequant_qconfig = QConfig(
  58. weight_observer=MinMaxObserver,
  59. act_observer=ExponentialMovingAverageObserver,
  60. fake_quant=FakeQuantize,
  61. )
  62. calibration_qconfig = QConfig(
  63. weight_observer=MinMaxObserver, act_observer=HistogramObserver, fake_quant=None,
  64. )

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。