You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

fake_quant.py 3.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import math
  9. from .. import functional as F
  10. from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
  11. from ..module import Module
  12. from ..tensor import Parameter, Tensor
  13. from .utils import QuantMode, fake_quant_tensor, get_qparam_dict, tqt_forward
  14. class _FakeQuantize(Module):
  15. r"""
  16. A Basic Fake Quant module.
  17. :param dtype: a string indicating the target quantization type of input.
  18. :param narrow_range: whether the absolute value of ``qmin`` is the same as ``qmax``,
  19. instead of 1 greater. Usually True for weight and False for activation.
  20. :param enable: whether do ``normal_forward`` or ``fake_quant_forward``.
  21. """
  22. def __init__(
  23. self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs
  24. ):
  25. super().__init__()
  26. if not dtype in _metadata_dict.keys():
  27. raise ValueError(
  28. "unknown dtype: {}, only support {}".format(
  29. dtype, _metadata_dict.keys()
  30. )
  31. )
  32. self.dtype = dtype
  33. self.narrow_range = narrow_range
  34. self.qmin = (
  35. -_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin
  36. )
  37. self.qmax = _metadata_dict[dtype].qmax
  38. self.enabled = enable
  39. def enable(self):
  40. self.enabled = True
  41. def disable(self):
  42. self.enabled = False
  43. def fake_quant_forward(self, inp, q_dict=None):
  44. return inp
  45. def normal_foward(self, inp, q_dict=None):
  46. return inp
  47. def forward(self, inp, q_dict=None):
  48. if self.enabled:
  49. return self.fake_quant_forward(inp, q_dict=q_dict)
  50. else:
  51. return self.normal_foward(inp, q_dict=q_dict)
  52. class TQT(_FakeQuantize):
  53. r"""
  54. TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds
  55. for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks.
  56. """
  57. def __init__(
  58. self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs
  59. ):
  60. super().__init__(dtype, narrow_range, enable, **kwargs)
  61. self.scale = Parameter(0.0, dtype="float32")
  62. def fake_quant_forward(self, inp, q_dict=None):
  63. # when enable, TQT will do fakequant forward, finetune the scale
  64. return tqt_forward(self.qmin, self.qmax, inp, self.scale)
  65. def get_qparams(self):
  66. q_dict = get_qparam_dict(QuantMode.SYMMERTIC)
  67. q_dict["scale"] = 2 ** self.scale.detach()
  68. return q_dict
  69. def set_qparams(self, q_dict):
  70. assert (
  71. q_dict["mode"] == QuantMode.SYMMERTIC
  72. ), "only symmetric quantization is supported by TQT"
  73. if "scale" not in q_dict or q_dict["scale"] is None:
  74. raise AssertionError("Can not get an initialized scale")
  75. self.scale._reset(F.log(q_dict["scale"]) / math.log(2))
  76. def get_dtype(self):
  77. q_dict = self.get_qparams()
  78. scale = None if "scale" not in q_dict else q_dict["scale"].numpy()
  79. zero_point = (
  80. None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()
  81. )
  82. return get_quantized_dtype(self.dtype, scale, zero_point)
  83. class FakeQuantize(_FakeQuantize):
  84. r"""
  85. A module to do quant and dequant according to observer's scale and zero_point.
  86. """
  87. def fake_quant_forward(self, inp, q_dict=None):
  88. return fake_quant_tensor(inp, self.qmin, self.qmax, q_dict)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。