
utils.py 6.2 kB

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import abc
from enum import Enum
from functools import partial, update_wrapper, wraps
from typing import Union

import numpy as np

from .. import functional as F
from ..autodiff import Function
from ..core._imperative_rt.core2 import apply
from ..core.ops import builtin
from ..core.tensor.dtype import (
    QuantDtypeMeta,
    _builtin_quant_dtypes,
    create_quantized_dtype,
)
from ..tensor import Tensor


class Round(Function):
    """
    The functional round has no grad and cannot be used for quantization-aware training.
    We use Function and STE (Straight-Through Estimator) to implement backward propagation.
    """

    def forward(self, x):
        return F.round(x)

    def backward(self, output_grads):
        return output_grads


def tqt_forward(qmin, qmax, inp, scale):
    op = builtin.TQT(qmin=qmin, qmax=qmax)
    (output,) = apply(op, inp, scale)
    return output


def register_method_to_class(cls):
    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            return func(self, *args, **kwargs)

        if isinstance(func, partial):
            update_wrapper(func, func.func)
        setattr(cls, func.__name__, wrapper)
        return func

    return decorator


class QuantMode(Enum):
    """
    Quantization mode enumerate class.
    """

    SYMMERTIC = 1
    ASYMMERTIC = 2


class QParams:
    """
    To standardize FakeQuant, Observer and Tensor's qparams format. If custom
    qparams are needed, inherit this class and add custom ``__slots__``.
    """

    __slots__ = "mode", "dtype_meta", "scale", "zero_point"

    def __init__(
        self,
        mode: QuantMode,
        dtype_meta: QuantDtypeMeta,
        scale: Tensor,
        zero_point: Tensor,
    ):
        self.mode = mode
        self.dtype_meta = dtype_meta
        self.scale = scale
        self.zero_point = zero_point

    def update(self, qparams: "QParams"):
        for key in self.__slots__:
            setattr(self, key, getattr(qparams, key))

    def __eq__(self, other):
        if len(self.__slots__) != len(other.__slots__):
            return False
        for key in self.__slots__:
            if not hasattr(other, key) or getattr(self, key) != getattr(other, key):
                return False
        return True

    def __repr__(self):
        content = ", ".join(
            ["{}={}".format(key, getattr(self, key)) for key in self.__slots__]
        )
        return "QParams({})".format(content)


class QParamsModuleMixin(abc.ABC):
    def get_quantized_dtype(self):
        qparams = self.get_qparams()
        dtype = qparams.dtype_meta
        scale = float(qparams.scale.numpy()) if qparams.scale is not None else None
        zero_point = (
            int(qparams.zero_point.numpy()) if qparams.zero_point is not None else None
        )
        return create_quantized_dtype(dtype, scale, zero_point)

    @abc.abstractmethod
    def get_qparams(self) -> QParams:
        pass


_builtin_qparams = {
    QuantMode.SYMMERTIC: partial(QParams, mode=QuantMode.SYMMERTIC),
    QuantMode.ASYMMERTIC: partial(QParams, mode=QuantMode.ASYMMERTIC),
}


def create_qparams(
    mode: QuantMode = QuantMode.SYMMERTIC,
    dtype_meta: Union[str, QuantDtypeMeta] = None,
    scale: Tensor = None,
    zero_point: Tensor = None,
):
    """
    Return :class:`~.QParams` according to the mode.
    """
    if isinstance(dtype_meta, str):
        dtype_meta = _builtin_quant_dtypes[dtype_meta]
    if mode is None:
        return QParams(mode, dtype_meta, scale, zero_point)
    assert isinstance(mode, QuantMode)
    return _builtin_qparams[mode](
        dtype_meta=dtype_meta, scale=scale, zero_point=zero_point
    )


def fake_quant_tensor(inp: Tensor, qparams: QParams) -> Tensor:
    """
    Apply fake quantization to the inp tensor.

    :param inp: the input tensor which needs to be faked.
    :param qparams: to get mode, qmin, qmax, scale and zero_point from.
    """
    scale = qparams.scale
    if qparams.mode == QuantMode.ASYMMERTIC:
        zero_point = qparams.zero_point
    else:
        zero_point = Tensor([0.0], dtype=np.float32)
    qmin = qparams.dtype_meta.qmin
    qmax = qparams.dtype_meta.qmax

    op = builtin.FakeQuant(qmin=qmin, qmax=qmax)
    return apply(op, inp, scale, zero_point)[0]


def fake_quant_bias(bias: Tensor, inp: Tensor, w_qat: Tensor) -> Tensor:
    """
    Apply fake quantization to the bias, using the scale derived from the input
    tensor and the weight tensor; the quantized dtype is set to qint32.

    :param bias: the bias tensor which needs to be faked.
    :param inp: the input tensor which contains the quantization parameters.
    :param w_qat: the weight tensor which contains the quantization parameters.

    .. warning::
        Only works with the symmetric quantization method for now.
    """
    b_qat = bias
    if (
        getattr(inp, "qparams", None) is not None
        and getattr(w_qat, "qparams", None) is not None
        and bias is not None
    ):
        inp_params = inp.qparams
        w_params = w_qat.qparams
        if inp_params.scale is not None and w_params.scale is not None:
            assert inp_params.mode == w_params.mode, "incompatible QuantMode"
            # TODO: support quint8 dtype.
            assert (
                inp_params.dtype_meta.np_dtype_str == "int8"
                and w_params.dtype_meta.np_dtype_str == "int8"
            ), "fake_quant_bias only support int8 like dtype now"
            # use the same mode with weight.
            # TODO: avoid hardcode
            b_dtype = _builtin_quant_dtypes["qint32"]
            b_param = create_qparams(
                w_params.mode, b_dtype, scale=inp_params.scale * w_params.scale
            )
            b_qat = fake_quant_tensor(bias, b_param)
            b_qat.qparams.update(b_param)

    return b_qat
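As a quick illustration of the STE trick in Round, the following minimal sketch (assuming MegEngine is installed and this module is importable as megengine.quantization.utils; the input values are arbitrary examples) rounds a tensor in the forward pass while letting gradients flow through unchanged in the backward pass.

import megengine as mge
from megengine.autodiff import GradManager
from megengine.quantization.utils import Round

x = mge.Tensor([0.4, 1.5, 2.7])
gm = GradManager()
gm.attach([x])
with gm:
    y = Round()(x)        # forward: elementwise rounding via F.round
    gm.backward(y.sum())  # backward: STE passes gradients through unchanged
print(y.numpy())          # rounded values, e.g. [0., 2., 3.]
print(x.grad.numpy())     # all ones, since d(round(x))/dx is treated as identity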

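In the same spirit, a hedged sketch of create_qparams and fake_quant_tensor: the "qint8" key is expected to be present in _builtin_quant_dtypes, and the 0.1 scale and the input values are made up for illustration.

import numpy as np
import megengine as mge
from megengine.quantization.utils import QuantMode, create_qparams, fake_quant_tensor

inp = mge.Tensor(np.array([-0.33, 0.07, 0.52], dtype="float32"))
qparams = create_qparams(
    mode=QuantMode.SYMMERTIC,
    dtype_meta="qint8",        # resolved through _builtin_quant_dtypes
    scale=mge.Tensor([0.1]),   # symmetric mode needs no zero_point
)
out = fake_quant_tensor(inp, qparams)  # quantize-dequantize round trip
print(out.numpy())  # values snapped to multiples of 0.1, e.g. [-0.3, 0.1, 0.5]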
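Finally, register_method_to_class attaches a free function to an existing class as a bound method. A toy sketch, where the Widget class and the area function are hypothetical names used only for illustration:

from megengine.quantization.utils import register_method_to_class

class Widget:
    def __init__(self, width, height):
        self.width = width
        self.height = height

@register_method_to_class(Widget)
def area(self):
    return self.width * self.height

print(Widget(3, 4).area())  # 12: area is now available as a method on Widget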