
utils.py 7.8 kB

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
from enum import Enum
from functools import partial, update_wrapper, wraps
from typing import Union

import numpy as np

from .. import functional as F
from ..autodiff import Function
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.tensor.dtype import (
    QuantDtypeMeta,
    _builtin_quant_dtypes,
    create_quantized_dtype,
)
from ..tensor import Tensor


class Round(Function):
    r"""The functional round has no gradient and cannot be used for quantization-aware training.
    We use Function and STE (Straight-Through Estimator) to implement backward propagation.
    """

    def forward(self, x):
        return F.round(x)

    def backward(self, output_grads):
        return output_grads
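# Example (illustrative sketch): rounding happens in the forward pass while the
# STE backward lets gradients flow through unchanged, e.g.
#
#     x = Tensor([0.4, 1.6])
#     y = Round()(x)   # y == [0., 2.]; the gradient w.r.t. x is treated as identity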
def tqt_forward(qmin, qmax, inp, scale):
    # TQT ("Trained Quantization Thresholds"): fake-quantize ``inp`` with a learnable scale.
    op = builtin.TQT(qmin=qmin, qmax=qmax)
    (output,) = apply(op, inp, scale)
    return output


def lsq_forward(qmin, qmax, inp, step_size, zero_point=None, scale_grad=None):
    # LSQ ("Learned Step Size Quantization"): fake-quantize ``inp`` with a learnable step size.
    if zero_point is None:
        zero_point = Tensor([0.0], dtype=np.float32)
    if scale_grad is None:
        scale_grad = Tensor([1.0], dtype=np.float32)
    op = builtin.LSQ(qmin=qmin, qmax=qmax)
    (output,) = apply(op, inp, step_size, zero_point, scale_grad)
    return output


def register_method_to_class(cls):
    # Decorator factory: attach the decorated function to ``cls`` as a method.
    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            return func(self, *args, **kwargs)

        if isinstance(func, partial):
            update_wrapper(func, func.func)
        setattr(cls, func.__name__, wrapper)
        return func

    return decorator
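# Example (illustrative sketch, ``SomeModule`` is a hypothetical class): attach an
# extra method to an existing class without editing its definition, e.g.
#
#     @register_method_to_class(SomeModule)
#     def extra_method(self):
#         return self.get_qparams()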
class QuantMode(Enum):
    r"""Quantization mode enumerate class."""

    SYMMERTIC = 1
    ASYMMERTIC = 2


class QParams:
    r"""To standardize FakeQuant, Observer and Tensor's qparams format. If custom
    qparams is needed, inherit this class and add custom ``__slots__``.
    """

    __slots__ = "mode", "dtype_meta", "scale", "zero_point"

    def __init__(
        self,
        mode: QuantMode,
        dtype_meta: QuantDtypeMeta,
        scale: Tensor,
        zero_point: Tensor,
    ):
        self.mode = mode
        self.dtype_meta = dtype_meta
        self.scale = scale
        self.zero_point = zero_point

    def update(self, qparams: "QParams"):
        for key in self.__slots__:
            setattr(self, key, getattr(qparams, key))

    def __eq__(self, other):
        if len(self.__slots__) != len(other.__slots__):
            return False
        for key in self.__slots__:
            if not hasattr(other, key) or getattr(self, key) != getattr(other, key):
                return False
        return True

    def __repr__(self):
        content = ", ".join(
            ["{}={}".format(key, getattr(self, key)) for key in self.__slots__]
        )
        return "QParams({})".format(content)


class LSQParams:
    r"""To standardize LSQ's qparams format. If custom
    qparams is needed, inherit this class and add custom ``__slots__``.
    """

    __slots__ = "mode", "dtype_meta", "scale", "zero_point", "grad_scale"

    def __init__(
        self,
        mode: QuantMode,
        dtype_meta: QuantDtypeMeta,
        scale: Tensor,
        zero_point: Tensor,
        grad_scale: Tensor,
    ):
        self.mode = mode
        self.dtype_meta = dtype_meta
        self.scale = scale
        self.zero_point = zero_point
        self.grad_scale = grad_scale

    def update(self, lsqparams: "LSQParams"):
        for key in self.__slots__:
            setattr(self, key, getattr(lsqparams, key))

    def __eq__(self, other):
        if len(self.__slots__) != len(other.__slots__):
            return False
        for key in self.__slots__:
            if not hasattr(other, key) or getattr(self, key) != getattr(other, key):
                return False
        return True

    def __repr__(self):
        content = ", ".join(
            ["{}={}".format(key, getattr(self, key)) for key in self.__slots__]
        )
        return "LSQParams({})".format(content)


class QParamsModuleMixin(abc.ABC):
    def get_quantized_dtype(self):
        qparams = self.get_qparams()
        dtype = qparams.dtype_meta
        scale = float(qparams.scale.numpy()) if qparams.scale is not None else None
        zero_point = (
            int(qparams.zero_point.numpy()) if qparams.zero_point is not None else None
        )
        return create_quantized_dtype(dtype, scale, zero_point)

    @abc.abstractmethod
    def get_qparams(self) -> QParams:
        pass


_builtin_qparams = {
    QuantMode.SYMMERTIC: partial(QParams, mode=QuantMode.SYMMERTIC),
    QuantMode.ASYMMERTIC: partial(QParams, mode=QuantMode.ASYMMERTIC),
}
def create_qparams(
    mode: QuantMode = QuantMode.SYMMERTIC,
    dtype_meta: Union[str, QuantDtypeMeta] = None,
    scale: Tensor = None,
    zero_point: Tensor = None,
):
    r"""
    Args:
        mode: quantization mode, a :class:`QuantMode` value.
        dtype_meta: a :class:`QuantDtypeMeta` instance, or the name of a builtin quantized dtype.
        scale: the scale tensor.
        zero_point: the zero point tensor, used in asymmetric mode.
    """
    if isinstance(dtype_meta, str):
        dtype_meta = _builtin_quant_dtypes[dtype_meta]
    if mode is None:
        return QParams(mode, dtype_meta, scale, zero_point)
    assert isinstance(mode, QuantMode)
    return _builtin_qparams[mode](
        dtype_meta=dtype_meta, scale=scale, zero_point=zero_point
    )
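# Example (illustrative sketch): build symmetric qint8 qparams from a known scale, e.g.
#
#     qparams = create_qparams(
#         QuantMode.SYMMERTIC, "qint8", scale=Tensor([0.01], dtype=np.float32)
#     )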
def fake_quant_tensor(inp: Tensor, qparams: QParams) -> Tensor:
    """Apply fake quantization to the inp tensor.

    Args:
        inp: the input tensor which needs to be fake quantized.
        qparams: to get mode, qmin, qmax, scale and zero_point from.
    """
    scale = qparams.scale
    if qparams.mode == QuantMode.ASYMMERTIC:
        zero_point = qparams.zero_point
    else:
        zero_point = Tensor([0.0], dtype=np.float32)
    qmin = qparams.dtype_meta.qmin
    qmax = qparams.dtype_meta.qmax
    op = builtin.FakeQuant(qmin=qmin, qmax=qmax)
    return apply(op, inp, scale, zero_point)[0]
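# Example (illustrative sketch, reusing the ``qparams`` built above): the output
# stays float32 but is snapped onto the qint8 grid, e.g.
#
#     x = Tensor(np.random.randn(4).astype(np.float32))
#     x_fq = fake_quant_tensor(x, qparams)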
def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
    """Apply fake quantization to bias, using the scale derived from the input
    tensor and the weight tensor; the quantized dtype is set to qint32.

    Args:
        bias: the bias tensor which needs to be fake quantized.
        inp: the input tensor which contains the quantization parameters.
        w_qat: the weight tensor which contains the quantization parameters.

    Warning:
        Only works for the symmetric quantization method for now.
    """
    b_qat = bias
    if (
        getattr(inp, "qparams", None) is not None
        and getattr(w_qat, "qparams", None) is not None
        and bias is not None
    ):
        inp_params = inp.qparams
        w_params = w_qat.qparams
        if inp_params.scale is not None and w_params.scale is not None:
            assert inp_params.mode == w_params.mode, "incompatible QuantMode"
            # TODO: support quint8 dtype.
            assert (
                inp_params.dtype_meta.np_dtype_str == "int8"
                and w_params.dtype_meta.np_dtype_str == "int8"
            ), "fake_quant_bias only support int8 like dtype now"
            # use the same mode with weight.
            # TODO: avoid hardcode
            b_dtype = _builtin_quant_dtypes["qint32"]
            b_param = create_qparams(
                w_params.mode, b_dtype, scale=inp_params.scale * w_params.scale
            )
            b_qat = fake_quant_tensor(bias, b_param)
            b_qat.qparams.update(b_param)
    return b_qat
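# Example (illustrative sketch, ``inp_qat`` and ``w_qat`` are hypothetical tensors
# carrying qparams): the bias is fake-quantized to qint32 with
# scale = input_scale * weight_scale, e.g.
#
#     b_fq = fake_quant_bias(bias, inp_qat, w_qat)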

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU device and that the driver is installed. If you want to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
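A minimal check sketch, assuming the standard megengine package is installed and exposes is_cuda_available at the top level:

import megengine as mge

# Prints whether a CUDA-capable GPU is visible to MegEngine.
print("GPU available:", mge.is_cuda_available())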