You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

fake_quant.py 6.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import math
  9. from typing import Union
  10. from .. import functional as F
  11. from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
  12. from ..logger import get_logger
  13. from ..module import Module
  14. from ..tensor import Parameter, Tensor
  15. from .utils import (
  16. LSQParams,
  17. QParams,
  18. QParamsModuleMixin,
  19. QuantMode,
  20. create_qparams,
  21. fake_quant_tensor,
  22. lsq_forward,
  23. tqt_forward,
  24. )
  25. logger = get_logger(__name__)
  26. class _FakeQuantize(Module):
  27. def __init__(
  28. self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
  29. ):
  30. super().__init__()
  31. if isinstance(dtype, str):
  32. if not dtype in _builtin_quant_dtypes:
  33. raise ValueError(
  34. "unknown dtype: {}, only support {}".format(
  35. dtype, _builtin_quant_dtypes.keys()
  36. )
  37. )
  38. dtype = _builtin_quant_dtypes[dtype]
  39. if "narrow_range" in kwargs:
  40. del kwargs["narrow_range"]
  41. logger.warning(
  42. "FakeQuantize currently has no narrow_range param "
  43. "so it is ignored here",
  44. exc_info=DeprecationWarning,
  45. )
  46. self.dtype = dtype
  47. self.qmin = dtype.qmin
  48. self.qmax = dtype.qmax
  49. self.enabled = enable
  50. def enable(self):
  51. self.enabled = True
  52. def disable(self):
  53. self.enabled = False
  54. def fake_quant_forward(self, inp, qparams: QParams = None):
  55. raise NotImplementedError
  56. def normal_forward(self, inp, qparams: QParams = None):
  57. return inp
  58. def forward(self, inp, qparams: QParams = None):
  59. if self.enabled:
  60. return self.fake_quant_forward(inp, qparams=qparams)
  61. else:
  62. return self.normal_forward(inp, qparams=qparams)
  63. class TQT(_FakeQuantize, QParamsModuleMixin):
  64. r"""TQT: https://arxiv.org/abs/1903.08066 Trained Quantization Thresholds
  65. for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks.
  66. Args:
  67. dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
  68. quantization dtype of input.
  69. enable: whether do ``normal_forward`` or ``fake_quant_forward``.
  70. """
  71. def __init__(
  72. self, dtype: Union[str, QuantDtypeMeta], enable: bool = True, **kwargs
  73. ):
  74. super().__init__(dtype, enable, **kwargs)
  75. self.scale = Parameter(0.0, dtype="float32")
  76. def fake_quant_forward(self, inp, qparams: QParams = None):
  77. # when enable, TQT will do fakequant forward, finetune the scale
  78. return tqt_forward(self.qmin, self.qmax, inp, self.scale)
  79. def set_qparams(self, qparams: QParams):
  80. assert (
  81. qparams.mode == QuantMode.SYMMERTIC
  82. ), "only symmetric quantization is supported by TQT"
  83. if qparams.scale is None:
  84. raise AssertionError("Can not get an initialized scale")
  85. self.scale[...] = F.log(qparams.scale) / math.log(2)
  86. def get_qparams(self):
  87. return create_qparams(QuantMode.SYMMERTIC, self.dtype, scale=2 ** self.scale)
  88. class FakeQuantize(_FakeQuantize):
  89. r"""A module to do quant and dequant according to observer's scale and zero_point.
  90. Args:
  91. dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
  92. quantization dtype of input.
  93. enable: whether do ``normal_forward`` or ``fake_quant_forward``.
  94. """
  95. def fake_quant_forward(self, inp, qparams: QParams = None):
  96. assert (
  97. qparams.dtype_meta is self.dtype
  98. ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
  99. qparams.dtype_meta, self.dtype
  100. )
  101. return fake_quant_tensor(inp, qparams)
  102. class LSQ(_FakeQuantize, QParamsModuleMixin):
  103. r"""LSQ: https://arxiv.org/pdf/1902.08153.pdf Estimating and scaling the
  104. task loss gradient at each weight and activation layer's quantizer step size
  105. Args:
  106. dtype: a string or :class:`~.QuantDtypeMeta` indicating the target
  107. quantization dtype of input.
  108. enable: whether do ``normal_forward`` or ``fake_quant_forward``.
  109. eps: a small value to avoid division by zero. Default: 1e-5
  110. """
  111. def __init__(
  112. self,
  113. dtype: Union[str, QuantDtypeMeta],
  114. enable: bool = True,
  115. eps: float = 1e-5,
  116. **kwargs
  117. ):
  118. super().__init__(dtype=dtype, enable=enable, **kwargs)
  119. self.eps = Tensor(eps, dtype="float32")
  120. self.step_size = Parameter(1.0, dtype="float32")
  121. self.mode = None
  122. self.zero_point = Tensor(0.0, dtype="float32")
  123. self.grad_scale = Tensor(1.0, dtype="float32")
  124. def set_qparams(self, qparams: LSQParams):
  125. self.mode = qparams.mode
  126. if qparams.mode == QuantMode.ASYMMERTIC:
  127. self.zero_point = qparams.zero_point
  128. else:
  129. self.zero_point = Tensor([0.0], dtype="float32")
  130. if qparams.scale is None:
  131. raise AssertionError("Can not get an initialized scale")
  132. init_step_size = qparams.scale
  133. if init_step_size < self.eps:
  134. init_step_size = 0
  135. else:
  136. init_step_size = init_step_size - self.eps
  137. self.step_size = Parameter(init_step_size, dtype="float32")
  138. self.grad_scale = qparams.grad_scale
  139. def fake_quant_forward(self, inp, qparams: LSQParams = None):
  140. step_size = F.abs(self.step_size) + self.eps
  141. return lsq_forward(
  142. self.qmin, self.qmax, inp, step_size, self.zero_point, self.grad_scale
  143. )
  144. def get_qparams(self):
  145. return LSQParams(
  146. mode=self.mode,
  147. dtype_meta=self.dtype,
  148. scale=F.abs(self.step_size.detach()) + self.eps,
  149. zero_point=self.zero_point,
  150. grad_scale=self.grad_scale,
  151. )

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台