
utils.py 7.7 kB

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
from enum import Enum
from functools import partial, update_wrapper, wraps
from typing import Union

import numpy as np

from .. import functional as F
from ..autodiff import Function
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.tensor.dtype import (
    QuantDtypeMeta,
    _builtin_quant_dtypes,
    create_quantized_dtype,
)
from ..tensor import Tensor


class Round(Function):
    """
    The functional round has no gradient and thus cannot be used directly for
    quantization-aware training. We use :class:`~.Function` with an STE
    (Straight-Through Estimator) to implement the backward pass.
    """

    def forward(self, x):
        return F.round(x)

    def backward(self, output_grads):
        # STE: pass the incoming gradient through unchanged.
        return output_grads
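
# A minimal usage sketch (illustrative only, not part of the original file; the tensor
# values are arbitrary): the output follows the rounded forward pass, while gradients
# flow through as if the op were the identity.
#
# >>> import megengine as mge
# >>> x = mge.tensor([0.4, 1.6, -2.3])
# >>> y = Round()(x)          # forward: elementwise round -> [0., 2., -2.]
# >>> # under a GradManager, d(y)/d(x) is all ones thanks to the STE backward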


def tqt_forward(qmin, qmax, inp, scale):
    # Fake-quantize ``inp`` with the builtin TQT (Trained Quantization Thresholds) op.
    op = builtin.TQT(qmin=qmin, qmax=qmax)
    (output,) = apply(op, inp, scale)
    return output


def lsq_forward(qmin, qmax, inp, step_size, zero_point=None, scale_grad=None):
    # Fake-quantize ``inp`` with the builtin LSQ (Learned Step Size Quantization) op.
    if zero_point is None:
        zero_point = Tensor([0.0], dtype=np.float32)
    if scale_grad is None:
        scale_grad = Tensor([1.0], dtype=np.float32)
    op = builtin.LSQ(qmin=qmin, qmax=qmax)
    (output,) = apply(op, inp, step_size, zero_point, scale_grad)
    return output
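
# A hedged sketch of calling these wrappers directly (normally they are driven by the
# TQT/LSQ fake-quant modules); qmin/qmax below are the qint8 range and the parameter
# values are placeholders, not recommended settings:
#
# >>> import megengine as mge
# >>> x = mge.tensor(np.random.randn(2, 3).astype("float32"))
# >>> scale = mge.tensor([0.0])          # TQT's learnable scale (log2-domain in the TQT paper)
# >>> y_tqt = tqt_forward(-128, 127, x, scale)
# >>> step = mge.tensor([0.05])          # LSQ learns the step size directly
# >>> y_lsq = lsq_forward(-128, 127, x, step)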


def register_method_to_class(cls):
    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            return func(self, *args, **kwargs)

        if isinstance(func, partial):
            update_wrapper(func, func.func)
        setattr(cls, func.__name__, wrapper)
        return func

    return decorator
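
# A hedged usage sketch (``MyModule`` and ``describe`` are made-up names): the decorator
# attaches the function to ``cls`` as a method while returning the original function
# unchanged.
#
# >>> class MyModule:
# ...     pass
# >>> @register_method_to_class(MyModule)
# ... def describe(self):
# ...     return "patched in"
# >>> MyModule().describe()
# 'patched in'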


class QuantMode(Enum):
    """
    Quantization mode enumeration class.
    """

    SYMMERTIC = 1
    ASYMMERTIC = 2


class QParams:
    """
    To standardize FakeQuant, Observer and Tensor's qparams format. If custom
    qparams is needed, inherit this class and add custom ``__slots__``.
    """

    __slots__ = "mode", "dtype_meta", "scale", "zero_point"

    def __init__(
        self,
        mode: QuantMode,
        dtype_meta: QuantDtypeMeta,
        scale: Tensor,
        zero_point: Tensor,
    ):
        self.mode = mode
        self.dtype_meta = dtype_meta
        self.scale = scale
        self.zero_point = zero_point

    def update(self, qparams: "QParams"):
        for key in self.__slots__:
            setattr(self, key, getattr(qparams, key))

    def __eq__(self, other):
        if len(self.__slots__) != len(other.__slots__):
            return False
        for key in self.__slots__:
            if not hasattr(other, key) or getattr(self, key) != getattr(other, key):
                return False
        return True

    def __repr__(self):
        content = ", ".join(
            ["{}={}".format(key, getattr(self, key)) for key in self.__slots__]
        )
        return "QParams({})".format(content)


class LSQParams:
    """
    To standardize LSQ's qparams format. If custom
    qparams is needed, inherit this class and add custom ``__slots__``.
    """

    __slots__ = "mode", "dtype_meta", "scale", "zero_point", "grad_scale"

    def __init__(
        self,
        mode: QuantMode,
        dtype_meta: QuantDtypeMeta,
        scale: Tensor,
        zero_point: Tensor,
        grad_scale: Tensor,
    ):
        self.mode = mode
        self.dtype_meta = dtype_meta
        self.scale = scale
        self.zero_point = zero_point
        self.grad_scale = grad_scale

    def update(self, lsqparams: "LSQParams"):
        for key in self.__slots__:
            setattr(self, key, getattr(lsqparams, key))

    def __eq__(self, other):
        if len(self.__slots__) != len(other.__slots__):
            return False
        for key in self.__slots__:
            if not hasattr(other, key) or getattr(self, key) != getattr(other, key):
                return False
        return True

    def __repr__(self):
        content = ", ".join(
            ["{}={}".format(key, getattr(self, key)) for key in self.__slots__]
        )
        return "LSQParams({})".format(content)


class QParamsModuleMixin(abc.ABC):
    def get_quantized_dtype(self):
        qparams = self.get_qparams()
        dtype = qparams.dtype_meta
        scale = float(qparams.scale.numpy()) if qparams.scale is not None else None
        zero_point = (
            int(qparams.zero_point.numpy()) if qparams.zero_point is not None else None
        )
        return create_quantized_dtype(dtype, scale, zero_point)

    @abc.abstractmethod
    def get_qparams(self) -> QParams:
        pass
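
# A hedged sketch of implementing the mixin (``MyMinMaxObserver`` and its fixed scale
# are made up, and ``"qint8"`` is assumed to be a registered builtin quant dtype name):
# a subclass only needs ``get_qparams``; ``get_quantized_dtype`` then derives a concrete
# quantized dtype from it.
#
# >>> class MyMinMaxObserver(QParamsModuleMixin):
# ...     def get_qparams(self):
# ...         return create_qparams(QuantMode.SYMMERTIC, "qint8", scale=Tensor([0.01]))
# >>> MyMinMaxObserver().get_quantized_dtype()   # a qint8 dtype with scale=0.01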


_builtin_qparams = {
    QuantMode.SYMMERTIC: partial(QParams, mode=QuantMode.SYMMERTIC),
    QuantMode.ASYMMERTIC: partial(QParams, mode=QuantMode.ASYMMERTIC),
}


def create_qparams(
    mode: QuantMode = QuantMode.SYMMERTIC,
    dtype_meta: Union[str, QuantDtypeMeta] = None,
    scale: Tensor = None,
    zero_point: Tensor = None,
):
    """
    Return :class:`~.QParams` according to the mode.
    """
    if isinstance(dtype_meta, str):
        dtype_meta = _builtin_quant_dtypes[dtype_meta]
    if mode is None:
        return QParams(mode, dtype_meta, scale, zero_point)
    assert isinstance(mode, QuantMode)
    return _builtin_qparams[mode](
        dtype_meta=dtype_meta, scale=scale, zero_point=zero_point
    )
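
# A minimal sketch (the scale value is arbitrary; ``"qint8"`` is assumed to be one of
# the builtin quant dtype names): a string dtype is looked up in ``_builtin_quant_dtypes``
# and the mode selects the matching ``QParams`` factory.
#
# >>> qparams = create_qparams(QuantMode.SYMMERTIC, "qint8", scale=Tensor([0.05]))
# >>> qparams
# QParams(mode=QuantMode.SYMMERTIC, dtype_meta=..., scale=..., zero_point=None)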


def fake_quant_tensor(inp: Tensor, qparams: QParams) -> Tensor:
    """
    Apply fake quantization to the inp tensor.

    :param inp: the input tensor which needs to be fake quantized.
    :param qparams: the QParams to get mode, qmin, qmax, scale and zero_point from.
    """
    scale = qparams.scale
    if qparams.mode == QuantMode.ASYMMERTIC:
        zero_point = qparams.zero_point
    else:
        zero_point = Tensor([0.0], dtype=np.float32)
    qmin = qparams.dtype_meta.qmin
    qmax = qparams.dtype_meta.qmax
    op = builtin.FakeQuant(qmin=qmin, qmax=qmax)
    return apply(op, inp, scale, zero_point)[0]
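
# A hedged usage sketch (random data, arbitrary scale): the output stays float32, but
# its values are snapped to the representable grid defined by the qparams.
#
# >>> import megengine as mge
# >>> x = mge.tensor(np.random.randn(4, 4).astype("float32"))
# >>> qparams = create_qparams(QuantMode.SYMMERTIC, "qint8", scale=mge.tensor([0.1]))
# >>> x_fq = fake_quant_tensor(x, qparams)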


def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
    """
    Apply fake quantization to the bias, using the scale derived from the input
    tensor and the weight tensor; the quantized dtype is set to qint32.

    :param bias: the bias tensor which needs to be fake quantized.
    :param inp: the input tensor which contains the quantization parameters.
    :param w_qat: the weight tensor which contains the quantization parameters.

    .. warning::
        Only works for the symmetric quantization method now.
    """
    b_qat = bias
    if (
        getattr(inp, "qparams", None) is not None
        and getattr(w_qat, "qparams", None) is not None
        and bias is not None
    ):
        inp_params = inp.qparams
        w_params = w_qat.qparams
        if inp_params.scale is not None and w_params.scale is not None:
            assert inp_params.mode == w_params.mode, "incompatible QuantMode"
            # TODO: support quint8 dtype.
            assert (
                inp_params.dtype_meta.np_dtype_str == "int8"
                and w_params.dtype_meta.np_dtype_str == "int8"
            ), "fake_quant_bias only support int8 like dtype now"
            # use the same mode as the weight.
            # TODO: avoid hardcode
            b_dtype = _builtin_quant_dtypes["qint32"]
            b_param = create_qparams(
                w_params.mode, b_dtype, scale=inp_params.scale * w_params.scale
            )
            b_qat = fake_quant_tensor(bias, b_param)
            b_qat.qparams.update(b_param)
    return b_qat
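
# A hedged sketch of the intended call site (inside a QAT conv/linear module, where the
# activation ``inp`` and the fake-quantized weight ``w_qat`` already carry ``.qparams``
# filled in by their observers; the names below are placeholders):
#
# >>> b_qat = fake_quant_bias(bias, inp, w_qat)
# >>> b_qat.qparams.scale   # inp.qparams.scale * w_qat.qparams.scale, with qint32 dtype_meta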
