
utils.py 7.0 kB

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
from enum import Enum
from functools import partial, update_wrapper, wraps
from typing import Union

import numpy as np

from .. import functional as F
from ..autodiff import Function
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.tensor.dtype import (
    QuantDtypeMeta,
    _builtin_quant_dtypes,
    create_quantized_dtype,
)
from ..tensor import Tensor

class Round(Function):
    r"""The functional round has no gradient and cannot be used for
    quantization-aware training. We use Function and STE
    (Straight-Through Estimator) to implement backward propagation.
    """

    def forward(self, x):
        return F.round(x)

    def backward(self, output_grads):
        return output_grads
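
# A minimal illustration (hypothetical values, not part of the original file):
# for x = Tensor([1.2, 2.7]), Round()(x) yields [1., 3.] in the forward pass,
# while the backward pass returns the incoming gradient unchanged
# (straight-through estimator), so the op stays trainable.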

def tqt_forward(qmin, qmax, inp, scale):
    op = builtin.TQT(qmin=qmin, qmax=qmax)
    (output,) = apply(op, inp, scale)
    return output

def lsq_forward(qmin, qmax, inp, step_size, zero_point=None, scale_grad=None):
    if zero_point is None:
        zero_point = Tensor([0.0], dtype=np.float32)
    if scale_grad is None:
        scale_grad = Tensor([1.0], dtype=np.float32)
    op = builtin.LSQ(qmin=qmin, qmax=qmax)
    (output,) = apply(op, inp, step_size, zero_point, scale_grad)
    return output

def register_method_to_class(cls):
    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            return func(self, *args, **kwargs)

        if isinstance(func, partial):
            update_wrapper(func, func.func)
        setattr(cls, func.__name__, wrapper)
        return func

    return decorator
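
# Usage sketch (with a hypothetical class ``Foo``):
#
#     @register_method_to_class(Foo)
#     def bar(self, x):
#         return x
#
# After decoration, ``Foo().bar(1)`` dispatches through the wrapper with the
# instance bound as ``self`` and returns ``1``.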

class QuantMode(Enum):
    r"""Quantization mode enumerate class."""

    SYMMERTIC = 1
    ASYMMERTIC = 2

class QParams:
    r"""To standardize FakeQuant, Observer and Tensor's qparams format. If custom
    qparams are needed, inherit this class and add custom ``__slots__``.
    """

    __slots__ = "mode", "dtype_meta", "scale", "zero_point"

    def __init__(
        self,
        mode: QuantMode,
        dtype_meta: QuantDtypeMeta,
        scale: Tensor,
        zero_point: Tensor,
    ):
        self.mode = mode
        self.dtype_meta = dtype_meta
        self.scale = scale
        self.zero_point = zero_point

    def update(self, qparams: "QParams"):
        for key in self.__slots__:
            setattr(self, key, getattr(qparams, key))

    def __eq__(self, other):
        if len(self.__slots__) != len(other.__slots__):
            return False
        for key in self.__slots__:
            if not hasattr(other, key) or getattr(self, key) != getattr(other, key):
                return False
        return True

    def __repr__(self):
        content = ", ".join(
            ["{}={}".format(key, getattr(self, key)) for key in self.__slots__]
        )
        return "QParams({})".format(content)

class LSQParams(QParams):
    r"""LSQ qparams with extra grad_scale slot."""

    __slots__ = "mode", "dtype_meta", "scale", "zero_point", "grad_scale"

    def __init__(
        self,
        mode: QuantMode,
        dtype_meta: QuantDtypeMeta,
        scale: Tensor,
        zero_point: Tensor,
        grad_scale: Tensor,
    ):
        super().__init__(mode, dtype_meta, scale, zero_point)
        self.grad_scale = grad_scale

class QParamsModuleMixin(abc.ABC):
    def get_quantized_dtype(self):
        qparams = self.get_qparams()
        dtype = qparams.dtype_meta
        scale = float(qparams.scale.numpy()) if qparams.scale is not None else None
        zero_point = (
            int(qparams.zero_point.numpy()) if qparams.zero_point is not None else None
        )
        return create_quantized_dtype(dtype, scale, zero_point)

    @abc.abstractmethod
    def get_qparams(self) -> QParams:
        pass
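
# Sketch of a concrete user of the mixin (hypothetical observer class, not
# part of the original file): implement get_qparams() and
# get_quantized_dtype() comes for free.
#
#     class MyObserver(QParamsModuleMixin):
#         def get_qparams(self):
#             return create_qparams(
#                 QuantMode.SYMMERTIC, "qint8", scale=Tensor([0.1])
#             )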

_builtin_qparams = {
    QuantMode.SYMMERTIC: partial(QParams, mode=QuantMode.SYMMERTIC),
    QuantMode.ASYMMERTIC: partial(QParams, mode=QuantMode.ASYMMERTIC),
}

def create_qparams(
    mode: QuantMode = QuantMode.SYMMERTIC,
    dtype_meta: Union[str, QuantDtypeMeta] = None,
    scale: Tensor = None,
    zero_point: Tensor = None,
):
    r"""Create a QParams object with the given quantization parameters.

    Args:
        mode: the quantization mode, see :class:`QuantMode`.
        dtype_meta: a QuantDtypeMeta instance, or the name of a built-in
            quantized dtype such as "qint8".
        scale: the scale tensor.
        zero_point: the zero point tensor, used in asymmetric mode.
    """
    if isinstance(dtype_meta, str):
        dtype_meta = _builtin_quant_dtypes[dtype_meta]
    if mode is None:
        return QParams(mode, dtype_meta, scale, zero_point)
    assert isinstance(mode, QuantMode)
    return _builtin_qparams[mode](
        dtype_meta=dtype_meta, scale=scale, zero_point=zero_point
    )
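
# Example (hypothetical values, assuming "quint8" is registered in
# _builtin_quant_dtypes): asymmetric qparams carry a zero point as well.
#
#     qparams = create_qparams(
#         QuantMode.ASYMMERTIC,
#         "quint8",
#         scale=Tensor([0.05]),
#         zero_point=Tensor([128.0]),
#     )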

def fake_quant_tensor(inp: Tensor, qparams: QParams) -> Tensor:
    """Apply fake quantization to the inp tensor.

    Args:
        inp: the input tensor to be fake-quantized.
        qparams: provides the mode, qmin, qmax, scale and zero_point.
    """
    scale = qparams.scale
    if qparams.mode == QuantMode.ASYMMERTIC:
        zero_point = qparams.zero_point
    else:
        zero_point = Tensor([0.0], dtype=np.float32)
    qmin = qparams.dtype_meta.qmin
    qmax = qparams.dtype_meta.qmax

    op = builtin.FakeQuant(qmin=qmin, qmax=qmax)
    return apply(op, inp, scale, zero_point)[0]

def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
    """Apply fake quantization to bias, using a scale derived from the input
    tensor and the weight tensor; the quantized dtype is set to qint32.

    Args:
        bias: the bias tensor to be fake-quantized.
        inp: the input tensor which contains the quantization parameters.
        w_qat: the weight tensor which contains the quantization parameters.

    Warning:
        Only works for the symmetric quantization method now.
    """
    b_qat = bias
    if (
        getattr(inp, "qparams", None) is not None
        and getattr(w_qat, "qparams", None) is not None
        and bias is not None
    ):
        inp_params = inp.qparams
        w_params = w_qat.qparams
        if inp_params.scale is not None and w_params.scale is not None:
            assert inp_params.mode == w_params.mode, "incompatible QuantMode"
            # TODO: support quint8 dtype.
            assert (
                inp_params.dtype_meta.np_dtype_str == "int8"
                and w_params.dtype_meta.np_dtype_str == "int8"
            ), "fake_quant_bias only supports int8-like dtypes now"
            # use the same mode as the weight.
            # TODO: avoid hardcode
            b_dtype = _builtin_quant_dtypes["qint32"]
            b_param = create_qparams(
                w_params.mode, b_dtype, scale=inp_params.scale * w_params.scale
            )
            b_qat = fake_quant_tensor(bias, b_param)
            b_qat.qparams.update(b_param)
    return b_qat
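
A minimal end-to-end sketch of the helpers above (not part of utils.py; the module path follows the file's location in the MegEngine tree, and the tensor values are illustrative):

    import numpy as np
    from megengine import Tensor
    from megengine.quantization.utils import (
        QuantMode,
        create_qparams,
        fake_quant_tensor,
    )

    # Fake-quantize an activation with symmetric qint8 parameters: values are
    # scaled, clipped to [qmin, qmax] and snapped to the quantization grid,
    # while gradients pass straight through during training.
    x = Tensor(np.random.randn(4).astype(np.float32))
    qparams = create_qparams(QuantMode.SYMMERTIC, "qint8", scale=Tensor([0.05]))
    x_fq = fake_quant_tensor(x, qparams)
    print(x_fq.numpy())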

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU and that its driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
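To confirm that a GPU is visible before running GPU programs, a quick check (assuming a standard MegEngine install):

    import megengine

    # True when MegEngine's bundled CUDA runtime can see a usable GPU.
    print(megengine.is_cuda_available())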