
observer.py 6.7 kB

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod

import numpy as np

from .. import functional as F
from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, ones, tensor, zeros
from ..module import Module


class Round(Function):
    """
    Round with a straight-through gradient: rounds in the forward pass and
    passes gradients through unchanged in the backward pass.
    """

    def forward(self, x):
        return x.round()

    def backward(self, output_grads):
        return output_grads


class Observer(Module):
    r"""
    A base class for Observer Module.

    :param dtype: a string indicating which dtype to collect scale and zero_point for
    """

    def __init__(self, dtype="qint8"):
        super().__init__()
        if dtype not in _metadata_dict.keys():
            raise ValueError(
                "unknown dtype: {}, only support {}".format(
                    dtype, _metadata_dict.keys()
                )
            )
        self.dtype = dtype
        self.qmin = _metadata_dict[dtype].qmin
        self.qmax = _metadata_dict[dtype].qmax
        self.zero_point, self.scale = None, None
        self.enabled = True

    def get_dtype(self):
        scale, zero_point = self.get_qparams()
        numpy_scale = None if scale is None else scale.numpy()[0]
        numpy_zero_point = None if zero_point is None else zero_point.numpy()[0]
        return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False

    def train(self, mode: bool = True) -> None:
        super().train(mode)
        if mode:
            self.enable()
        else:
            self.disable()

    @abstractmethod
    def forward(self, x):
        pass

    @abstractmethod
    def get_qparams(self, **kwargs):
        pass


class IdentityObserver(Observer):
    r"""
    A test Observer that always returns scale=1 and zero_point=0.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.scale = ones((1), dtype="float32")
        self.zero_point = zeros((1), dtype="float32")

    def forward(self, x):
        return x

    def get_qparams(self):
        return self.scale, self.zero_point


class MinMaxObserver(Observer):
    def __init__(self, symmetric=True, eps=0.00001, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.symmetric = symmetric
        if self.symmetric:
            # assert qmin + qmax == -1, 'when reduce_range, qmin + qmax should equal -1'
            self.zero_point = tensor((self.qmin + self.qmax + 1) // 2)
        self.min_val = Buffer(0.0, dtype=np.float32)
        self.max_val = Buffer(0.0, dtype=np.float32)
        self.scale_limit = eps
        # flags used by cond_take: the first forward pass selects the raw min/max
        # via first_flag, after which first_flag is overwritten with not_flag so
        # later passes select the running min/max branch instead.
        self.first_flag = Buffer(np.array([1, 0], dtype=np.int32))
        self.not_flag = Buffer(np.array([0, 1], dtype=np.int32))

    def set_min_max(self, tmp_min, tmp_max):
        # FIXME: cond_take will destroy the shape, so use reshape to restore it
        tmp_min = tmp_min.reshape(1)
        tmp_max = tmp_max.reshape(1)
        if self.training:
            F.zero_grad(
                F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0)
            )
            F.zero_grad(
                F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0)
            )
            F.zero_grad(
                F.add_update(
                    self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0
                )
            )
        # FIXME: add_update is applied after the whole trace procedure in
        # `symbolic=True` mode, so use tmp_min/tmp_max to calculate and save
        # scale/zero_point for further calculation in FakeQuant.
        self.set_scale_zero_point(tmp_min, tmp_max)

    def set_scale_zero_point(self, tmp_min, tmp_max):
        if self.symmetric:
            symmetric_max_vals = F.maximum(-tmp_min, tmp_max)
            # use maximum to keep the scale from becoming too small at the beginning
            self.scale = F.maximum(
                symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
            )
            # zero_point = self.zero_point
        else:
            # use maximum to keep the scale from becoming too small at the beginning
            self.scale = F.maximum(
                (tmp_max - tmp_min) / (self.qmax - self.qmin), self.scale_limit
            )
            # calculate zero_point
            self.zero_point = self.qmin - Round()(tmp_min / self.scale)

    def get_qparams(self):
        # scale and zero_point are runtime tensors rather than Buffers, so they
        # need to be recalculated if min_val and max_val are loaded.
        if self.scale is None:
            self.set_scale_zero_point(self.min_val, self.max_val)
        return self.scale, self.zero_point

    def forward(self, x_orig):
        if self.enabled:
            # stop gradient
            x = F.zero_grad(x_orig)
            # find max and min
            tmp_min, _ = F.cond_take(
                self.first_flag, F.concat([x.min(), F.minimum(self.min_val, x.min())])
            )
            tmp_max, _ = F.cond_take(
                self.first_flag, F.concat([x.max(), F.maximum(self.max_val, x.max())])
            )
            self.set_min_max(tmp_min, tmp_max)
        return x_orig
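
# A worked example of set_scale_zero_point's asymmetric branch, using
# illustrative numbers that are not taken from the source: for qint8,
# qmin = -128 and qmax = 127. If the observer has seen tmp_min = -1.0
# and tmp_max = 3.0, then
#     scale      = (3.0 - (-1.0)) / (127 - (-128)) = 4.0 / 255 ≈ 0.0157
#     zero_point = qmin - Round()(tmp_min / scale) = -128 - (-64) = -64
# so a real value r quantizes to q = round(r / scale) + zero_point,
# clipped to [qmin, qmax].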


class ExponentialMovingAverageObserver(MinMaxObserver):
    def __init__(self, momentum=0.9, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.momentum = Buffer(momentum)

    def set_momentum(self, momentum):
        self.momentum.set_value(momentum)

    def forward(self, x_orig):
        if self.enabled:
            # stop gradient
            x = F.zero_grad(x_orig)
            # Exponential Moving Average
            tmp_min, _ = F.cond_take(
                self.first_flag,
                F.concat(
                    [
                        x.min(),
                        self.momentum * self.min_val + (1 - self.momentum) * x.min(),
                    ]
                ),
            )
            tmp_max, _ = F.cond_take(
                self.first_flag,
                F.concat(
                    [
                        x.max(),
                        self.momentum * self.max_val + (1 - self.momentum) * x.max(),
                    ]
                ),
            )
            self.set_min_max(tmp_min, tmp_max)
        return x_orig
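
For context, a minimal usage sketch of the observer above. It assumes the file lives at megengine/quantization/observer.py (as its relative imports suggest) and uses made-up input data; treat the import path as an assumption rather than a documented API.

import numpy as np
from megengine.core import tensor
from megengine.quantization.observer import MinMaxObserver  # path assumed

# forward() records the running min/max as a side effect and returns its
# input unchanged, so the observer can be dropped into a network.
obs = MinMaxObserver(symmetric=True, dtype="qint8")
x = tensor(np.random.randn(4, 16).astype("float32"))
obs(x)

# The observed statistics determine the affine quantization parameters.
scale, zero_point = obs.get_qparams()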

The MegEngine installation package ships with the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU installed along with its driver. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
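
To confirm that the bundled CUDA environment can actually see a GPU, a one-line check may help (a sketch: `is_cuda_available` is the name in recent MegEngine releases and is assumed here to exist in your installed version):

import megengine as mge

# True when MegEngine can initialize a CUDA device; otherwise
# computation runs on the CPU.
print(mge.is_cuda_available())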