
observer.py

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod

import numpy as np

from .. import functional as F
from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, ones, tensor, zeros
from ..module import Module

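
# Round implements rounding with a straight-through gradient: round() has zero
# gradient almost everywhere, so backward() passes the incoming gradient
# through unchanged. This lets gradients flow through the rounding used in the
# zero_point computation further below.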
class Round(Function):
    def forward(self, x):
        return x.round()

    def backward(self, output_grads):
        return output_grads

class Observer(Module):
    r"""
    A base class for Observer Module.

    :param dtype: a string indicating which dtype to collect scale and zero_point for
    """

    def __init__(self, dtype="qint8"):
        super().__init__()
        if dtype not in _metadata_dict.keys():
            raise ValueError(
                "unknown dtype: {}, only support {}".format(
                    dtype, _metadata_dict.keys()
                )
            )
        self.dtype = dtype
        self.qmin = _metadata_dict[dtype].qmin
        self.qmax = _metadata_dict[dtype].qmax
        self.zero_point, self.scale = None, None
        self.enabled = True

    def get_dtype(self):
        scale, zero_point = self.get_qparams()
        numpy_scale = None if scale is None else scale.numpy()[0]
        numpy_zero_point = None if zero_point is None else zero_point.numpy()[0]
        return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False

    @abstractmethod
    def forward(self, x):
        pass

    @abstractmethod
    def get_qparams(self, **kwargs):
        pass

class IdentityObserver(Observer):
    r"""
    A test Observer that always returns scale=1 and zero_point=0.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # constant qparams: scale=1, zero_point=0
        self.zero_point = zeros((1), dtype="float32")
        self.scale = ones((1), dtype="float32")

    def forward(self, x):
        return x

    def get_qparams(self):
        return self.scale, self.zero_point

class MinMaxObserver(Observer):
    def __init__(self, symmetric=True, eps=0.00001, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.symmetric = symmetric
        if self.symmetric:
            # assert qmin + qmax == -1, 'when reduce_range, qmin + qmax should equal -1'
            self.zero_point = tensor((self.qmin + self.qmax + 1) // 2)
        self.min_val = Buffer(0.0, dtype=np.float32)
        self.max_val = Buffer(0.0, dtype=np.float32)
        self.scale_limit = eps
        # flags used by cond_take: the first forward pass selects the raw batch
        # min/max via first_flag; afterwards the flag is overwritten with
        # not_flag so the running min/max branch is selected instead.
        self.first_flag = Buffer(np.array([1, 0], dtype=np.int32))
        self.not_flag = Buffer(np.array([0, 1], dtype=np.int32))

    def set_min_max(self, tmp_min, tmp_max):
        # FIXME: cond_take will destroy shape, use reshape to reset shape
        tmp_min = tmp_min.reshape(1)
        tmp_max = tmp_max.reshape(1)
        if self.training:
            F.zero_grad(
                F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0)
            )
            F.zero_grad(
                F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0)
            )
            F.zero_grad(
                F.add_update(
                    self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0
                )
            )

        # FIXME: add_update is applied after the whole trace procedure in
        # `symbolic=True` mode, so use tmp_min/tmp_max to calculate and save
        # scale/zero_point for further calculation in FakeQuant.
        self.set_scale_zero_point(tmp_min, tmp_max)

    def set_scale_zero_point(self, tmp_min, tmp_max):
        if self.symmetric:
            symmetric_max_vals = F.maximum(-tmp_min, tmp_max)
            # clamp with scale_limit to avoid a vanishingly small scale early on
            self.scale = F.maximum(
                symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
            )
            # zero_point stays at the midpoint set in __init__
        else:
            # clamp with scale_limit to avoid a vanishingly small scale early on
            self.scale = F.maximum(
                (tmp_max - tmp_min) / (self.qmax - self.qmin), self.scale_limit
            )
            # calculate zero_point so that tmp_min maps to qmin
            self.zero_point = self.qmin - Round()((tmp_min / self.scale))

    def get_qparams(self):
        # scale and zero_point are runtime tensors rather than Buffers, so they
        # need to be re-calculated if min_val and max_val are loaded.
        if self.scale is None:
            self.set_scale_zero_point(self.min_val, self.max_val)
        return self.scale, self.zero_point

    def forward(self, x_orig):
        if self.enabled:
            # stop gradient
            x = F.zero_grad(x_orig)
            # find max and min
            tmp_min, _ = F.cond_take(
                self.first_flag, F.concat([x.min(), F.minimum(self.min_val, x.min())])
            )
            tmp_max, _ = F.cond_take(
                self.first_flag, F.concat([x.max(), F.maximum(self.max_val, x.max())])
            )
            self.set_min_max(tmp_min, tmp_max)
        return x_orig
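
# ExponentialMovingAverageObserver replaces the hard running min/max update
# with a momentum-weighted average, e.g.
#     new_min = momentum * min_val + (1 - momentum) * batch_min
# (and likewise for max), so a single outlier batch cannot blow up the range.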
class ExponentialMovingAverageObserver(MinMaxObserver):
    def __init__(self, momentum=0.9, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.momentum = Buffer(momentum)

    def set_momentum(self, momentum):
        self.momentum.set_value(momentum)

    def forward(self, x_orig):
        if self.enabled:
            # stop gradient
            x = F.zero_grad(x_orig)
            # exponential moving average of the running min/max
            tmp_min, _ = F.cond_take(
                self.first_flag,
                F.concat(
                    [
                        x.min(),
                        self.momentum * self.min_val + (1 - self.momentum) * x.min(),
                    ]
                ),
            )
            tmp_max, _ = F.cond_take(
                self.first_flag,
                F.concat(
                    [
                        x.max(),
                        self.momentum * self.max_val + (1 - self.momentum) * x.max(),
                    ]
                ),
            )
            self.set_min_max(tmp_min, tmp_max)
        return x_orig
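
For reference, a minimal calibration sketch. It assumes this file lives at megengine/quantization/observer.py (so the relative imports above resolve under megengine) and that the v0.x megengine.core.tensor API is available; treat the import paths as assumptions, not a documented entry point.

    import numpy as np
    from megengine.core import tensor
    from megengine.quantization.observer import MinMaxObserver  # assumed module path

    observer = MinMaxObserver(symmetric=True, dtype="qint8")
    for _ in range(8):
        batch = tensor(np.random.randn(4, 16).astype("float32"))
        observer(batch)  # records running min/max; the input passes through unchanged
    scale, zero_point = observer.get_qparams()
    qdtype = observer.get_dtype()  # quantized dtype carrying the collected qparams

Calling the observer never changes the data; it only accumulates statistics, which is why it can be inserted into a network during calibration and removed for export.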

The MegEngine package bundles the CUDA environment needed to run code on a GPU, so there are no separate CPU and GPU builds. To run GPU programs, make sure the machine has a GPU device with the driver installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
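
For reference, installation is typically a single pip command; the wheel index URL below is the one commonly given in the MegEngine docs, but treat it as an assumption and check the official site for the current address:

    python3 -m pip install megengine -f https://megengine.org.cn/whl/mge.html

The same wheel serves both CPU and GPU machines; at runtime MegEngine uses the GPU only if a device and driver are actually present.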