You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

qconfig.py 3.1 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. from functools import partial
  9. from ..module import Module
  10. from .fake_quant import FakeQuantize
  11. from .observer import ExponentialMovingAverageObserver, MinMaxObserver
  12. class QConfig:
  13. """
  14. A config class indicating how to do quantize toward :class:`~.QATModule`'s
  15. ``activation``, ``weight`` and ``bias``.
  16. And ``fake_quant`` parameter to indicate
  17. See :meth:`~.QATModule.set_qconfig` for detail usage.
  18. :param inp_observer: interface to instantiate an :class:`~.Observer` indicating
  19. how to collect scales and zero_point of input.
  20. :param weight_observer: similar to ``inp_observer`` but toward weight.
  21. :param act_observer: similar to ``inp_observer`` but toward activation.
  22. :param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
  23. how to do fake_quant calculation. can be invoked multi times to get different
  24. instance for each target tensor, for better control on enable and disable.
  25. :param bias_fake_quant: similar to ``fake_quant``, but usually need to set ``dtype``
  26. in advance, for bias's dtype is unable to be inferred from observer.
  27. Examples:
  28. .. code-block::
  29. # Default EMA QConfig for QAT.
  30. ema_fakequant_qconfig = QConfig(
  31. inp_observer=ExponentialMovingAverageObserver,
  32. weight_observer=ExponentialMovingAverageObserver,
  33. act_observer=ExponentialMovingAverageObserver,
  34. fake_quant=FakeQuantize,
  35. )
  36. """
  37. def __init__(
  38. self, act_observer, weight_observer, inp_observer, fake_quant, bias_fake_quant,
  39. ):
  40. if (
  41. isinstance(act_observer, Module)
  42. or isinstance(weight_observer, Module)
  43. or isinstance(inp_observer, Module)
  44. ):
  45. raise ValueError(
  46. "QConfig must not receive observer instance, please pass observer"
  47. " class generator using `partial(Observer, ...)` instead. Use"
  48. " partial(MyObserver, x=1) to override arguments to constructor if needed"
  49. )
  50. self.act_observer = act_observer
  51. self.weight_observer = weight_observer
  52. self.inp_observer = inp_observer
  53. self.fake_quant = fake_quant
  54. self.bias_fake_quant = bias_fake_quant
  55. # Default QAT QConfigs
  56. min_max_fakequant_qconfig = QConfig(
  57. inp_observer=MinMaxObserver,
  58. weight_observer=MinMaxObserver,
  59. act_observer=MinMaxObserver,
  60. fake_quant=FakeQuantize,
  61. bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
  62. )
  63. ema_fakequant_qconfig = QConfig(
  64. inp_observer=ExponentialMovingAverageObserver,
  65. weight_observer=MinMaxObserver,
  66. act_observer=ExponentialMovingAverageObserver,
  67. fake_quant=FakeQuantize,
  68. bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
  69. )

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台