You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

observer.py 15 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import math
  9. from abc import abstractmethod
  10. import numpy as np
  11. from .. import functional as F
  12. from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
  13. from ..module import Module
  14. from ..tensor import Tensor
  15. from .utils import QuantMode, Round, get_qparam_dict
  16. class Observer(Module):
  17. r"""
  18. A base class for Observer Module.
  19. :param dtype: a string indicating to collect scale and zero_point of which dtype.
  20. :param narrow_range: whether the absolute value of ``qmin`` is the same as ``qmax``,
  21. instead of 1 greater. Usually True for weight and False for activation.
  22. """
  23. def __init__(self, dtype: str, narrow_range: bool = False):
  24. super().__init__()
  25. if dtype not in _metadata_dict.keys():
  26. raise ValueError(
  27. "unknown dtype: {}, only support {}".format(
  28. dtype, _metadata_dict.keys()
  29. )
  30. )
  31. self.dtype = dtype
  32. self.narrow_range = narrow_range
  33. self.qmin = (
  34. -_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin
  35. )
  36. self.qmax = _metadata_dict[dtype].qmax
  37. self.enabled = True
  38. def get_dtype(self):
  39. q_dict = self.get_qparams()
  40. numpy_scale = None if "scale" not in q_dict else q_dict["scale"].numpy()[0]
  41. numpy_zero_point = (
  42. None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()[0]
  43. )
  44. return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)
  45. def enable(self):
  46. self.enabled = True
  47. def disable(self):
  48. self.enabled = False
  49. def train(self, mode: bool = True, recursive: bool = True) -> None:
  50. super().train(mode, recursive)
  51. if mode:
  52. self.enable()
  53. else:
  54. self.disable()
  55. @abstractmethod
  56. def forward(self, x):
  57. pass
  58. @abstractmethod
  59. def get_qparams(self, **kwargs):
  60. pass
class MinMaxObserver(Observer):
    # Tracks the running min/max of every observed tensor and derives
    # scale / zero_point from that range.
    def __init__(
        self,
        mode=QuantMode.SYMMERTIC,
        eps=0.00001,
        dtype="qint8",
        narrow_range: bool = False,
    ):
        super().__init__(dtype, narrow_range)
        self.mode = mode
        # start at float32 +max/-max so the first batch always updates both
        self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32)
        self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32)
        # lower bound for scale, keeps it away from zero early in training
        self.scale_limit = eps

    def _calculate_qparams(self, inp_min_val, inp_max_val):
        """Compute the qparam dict for the given observed range."""
        # clamp with 0 so that zero is always representable in the range
        min_val = F.minimum(0.0, inp_min_val)
        max_val = F.maximum(0.0, inp_max_val)
        q_dict = get_qparam_dict(self.mode)
        # report the raw (unclamped) observed range
        q_dict["min_val"] = inp_min_val
        q_dict["max_val"] = inp_max_val
        # NOTE(review): this stores the bound method ``enable`` — presumably
        # the boolean ``self.enabled`` was intended; confirm against consumers.
        q_dict["enable_observer"] = self.enable
        if self.mode == QuantMode.SYMMERTIC:
            symmetric_max_vals = F.maximum(-min_val, max_val)
            # use maximum to avoid scale too small at the beginning
            q_dict["scale"] = F.maximum(
                symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
            )
        else:
            # use maximum to avoid scale too small at the beginning
            q_dict["scale"] = F.maximum(
                (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit,
            )
            # calculate zero_point (asymmetric mode only)
            q_dict["zero_point"] = self.qmin - Round()((min_val / q_dict["scale"]))
        return q_dict

    def get_qparams(self):
        return self._calculate_qparams(self.min_val, self.max_val)

    def forward(self, x_orig):
        # Identity on the data path; only updates internal statistics.
        if self.enabled:
            # stop gradient
            x = x_orig.detach()
            # find max and min
            self.min_val._reset(F.minimum(self.min_val, x.min()))
            self.max_val._reset(F.maximum(self.max_val, x.max()))
        return x_orig
  106. class ExponentialMovingAverageObserver(MinMaxObserver):
  107. def __init__(
  108. self,
  109. momentum=0.9,
  110. mode=QuantMode.SYMMERTIC,
  111. eps=0.00001,
  112. dtype="qint8",
  113. narrow_range: bool = False,
  114. ):
  115. super().__init__(mode, eps, dtype, narrow_range)
  116. self.momentum = Tensor(momentum)
  117. self.runtime_momentum = Tensor(0.0)
  118. def set_momentum(self, momentum):
  119. self.momentum._reset(momentum)
  120. def forward(self, x_orig):
  121. if self.enabled:
  122. # stop gradient
  123. x = x_orig.detach()
  124. # Exponential Moving Average
  125. self.min_val._reset(
  126. self.min_val * self.runtime_momentum
  127. + (1 - self.runtime_momentum) * x.min()
  128. )
  129. self.max_val._reset(
  130. self.max_val * self.runtime_momentum
  131. + (1 - self.runtime_momentum) * x.max()
  132. )
  133. self.runtime_momentum = self.momentum
  134. return x_orig
class HistogramObserver(MinMaxObserver):
    # Observer that accumulates a value histogram and searches for the
    # clip range that minimizes L2 quantization error, instead of using
    # the raw min/max.
    def __init__(
        self,
        bins=2048,
        upsample_rate=128,
        mode=QuantMode.SYMMERTIC,
        eps=0.00001,
        dtype="qint8",
        narrow_range: bool = False,
    ):
        super().__init__(mode, eps, dtype, narrow_range)
        self.bins = bins
        self.upsample_rate = upsample_rate
        # number of representable values in the destination quantized dtype
        self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
        # leading -1 is a sentinel meaning "no data yet" (see sideeffect_forward)
        self.histogram = Tensor([-1] + [0.0] * (bins - 1), dtype="float32")

    def _non_linear_param_search(self):
        r"""
        Non-linear parameter search.

        An approximation for L2 error minimization for selecting min/max.
        By selecting new min/max, we filter out outliers in input distribution.
        """
        np_min_val = self.min_val.numpy()[0]
        np_max_val = self.max_val.numpy()[0]
        np_histogram = self.histogram.numpy()
        assert len(np_histogram) == self.bins, "bins mistmatch"
        bin_width = (np_max_val - np_min_val) / self.bins

        def _get_norm(delta_begin, delta_end, density, norm_type):
            r"""
            Compute the norm of the values uniformly distributed between
            delta_begin and delta_end.

            norm = density * (integral_{begin, end} x^2)
                 = density * (end^3 - begin^3) / 3
            """
            assert norm_type == "L2", "Only L2 norms are currently supported"
            norm = 0.0
            if norm_type == "L2":
                norm = (
                    delta_end * delta_end * delta_end
                    - delta_begin * delta_begin * delta_begin
                ) / 3
            return density * norm

        def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
            r"""
            Compute the quantization error if we use start_bin to end_bin as the
            min and max to do the quantization.
            """
            norm = 0.0
            dst_bin_width = (
                bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
            )
            if dst_bin_width == 0.0:
                return 0.0
            for src_bin in range(self.bins):
                # distances from the beginning of first dst_bin to the beginning and
                # end of src_bin
                src_bin_begin = (src_bin - next_start_bin) * bin_width
                src_bin_end = src_bin_begin + bin_width
                # which dst_bins the beginning and end of src_bin belong to?
                dst_bin_of_begin = min(
                    self.dst_nbins - 1,
                    max(0.0, math.floor(src_bin_begin / dst_bin_width)),
                )
                dst_bin_of_end = min(
                    self.dst_nbins - 1,
                    max(0.0, math.floor(src_bin_end / dst_bin_width)),
                )
                dst_bin_of_begin_center = (
                    dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
                )
                density = np_histogram[src_bin] / bin_width
                if dst_bin_of_begin == dst_bin_of_end:
                    # if src_bin is entirely within 1 dst_bin
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = src_bin_end - dst_bin_of_begin_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
                else:
                    # head: from src_bin_begin to the right edge of its dst_bin
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = dst_bin_width / 2
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
                    # middle: dst_bins fully covered by src_bin
                    norm = norm + (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
                        -dst_bin_width / 2, dst_bin_width / 2, density, norm_type
                    )
                    dst_bin_of_end_center = (
                        dst_bin_of_end * dst_bin_width + dst_bin_width / 2
                    )
                    # tail: from the left edge of the last dst_bin to src_bin_end
                    delta_begin = -dst_bin_width / 2
                    delta_end = src_bin_end - dst_bin_of_end_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
            return norm

        # cumulative sum
        total = sum(np_histogram)
        cSum = np.cumsum(np_histogram, axis=0)
        stepsize = 1e-5  # granularity
        alpha = 0.0  # lower bound
        beta = 1.0  # upper bound
        start_bin = 0
        end_bin = self.bins - 1
        norm_min = float("inf")
        # Greedily shrink the [alpha, beta] quantile window; at each step
        # move whichever side trims more bins, and stop once the L2
        # quantization error starts increasing.
        while alpha < beta:
            # Find the next step
            next_alpha = alpha + stepsize
            next_beta = beta - stepsize
            # find the left and right bins between the quantile bounds
            l = start_bin
            r = end_bin
            while l < end_bin and cSum[l] < next_alpha * total:
                l = l + 1
            while r > start_bin and cSum[r] > next_beta * total:
                r = r - 1
            # decide the next move
            next_start_bin = start_bin
            next_end_bin = end_bin
            if (l - start_bin) > (end_bin - r):
                # move the start bin
                next_start_bin = l
                alpha = next_alpha
            else:
                # move the end bin
                next_end_bin = r
                beta = next_beta
            # no bin actually moved this step; skip the error evaluation
            if next_start_bin == start_bin and next_end_bin == end_bin:
                continue
            # calculate the quantization error using next_start_bin and next_end_bin
            norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")
            if norm > norm_min:
                break
            norm_min = norm
            start_bin = next_start_bin
            end_bin = next_end_bin
        # map the selected bin range back into value space
        new_min = self.min_val + Tensor(bin_width * start_bin, dtype=np.float32)
        new_max = self.min_val + Tensor(bin_width * (end_bin + 1), dtype=np.float32)
        return new_min, new_max

    def get_qparams(self):
        # search the histogram for the best clip range, then reuse the
        # min/max based qparam computation from MinMaxObserver
        new_min, new_max = self._non_linear_param_search()
        return self._calculate_qparams(new_min, new_max)

    def _combine_histograms(
        self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins
    ):
        """Merge ``new_hist`` (possibly covering a narrower range) into
        ``orig_hist`` via upsample-shift-downsample interpolation."""
        # First up-sample the histogram with new data by a factor of L
        # This creates an approximate probability density that's piecewise constant
        upsampled_histogram = new_hist.repeat(upsample_rate)
        # Now insert the upsampled histogram into the output
        # histogram, which is initialized with zeros.
        # The offset at which the histogram is introduced is determined
        # by the start index as the output histogram can cover a wider range
        histogram_with_output_range = np.zeros((Nbins * downsample_rate))
        histogram_with_output_range[
            start_idx : Nbins * upsample_rate + start_idx
        ] = upsampled_histogram
        # Compute integral histogram, double precision is needed to ensure
        # that there are no overflows
        integral_histogram = np.cumsum(histogram_with_output_range, 0)[
            downsample_rate - 1 :: downsample_rate
        ]
        # Finally perform interpolation
        shifted_integral_histogram = np.zeros((Nbins))
        shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1]
        interpolated_histogram = (
            integral_histogram - shifted_integral_histogram
        ) / upsample_rate
        orig_hist = orig_hist + interpolated_histogram
        return orig_hist

    def _adjust_min_max(self, combined_min, combined_max, upsample_rate):
        # We ensure that:
        # (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
        # This allows us to have a common grid of resolution s, where we can align
        # the input histogram
        # start_idx maps min_val to the histogram bin index.
        np_min_val = self.min_val.numpy()[0]
        np_max_val = self.max_val.numpy()[0]
        hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate)
        downsample_rate = int(
            np.ceil((combined_max - combined_min) / (self.bins * hist_bin_width))
        )
        # pad symmetrically so the combined range is an exact multiple of
        # the fine-grid bin width
        e = downsample_rate * (self.bins * hist_bin_width) - (
            combined_max - combined_min
        )
        combined_max = combined_max + e / 2
        combined_min = combined_min - e / 2
        start_idx = int(np.round((np_min_val - combined_min) / hist_bin_width))
        return combined_min, combined_max, downsample_rate, start_idx

    def sideeffect_forward(self, x_orig):
        # Pure-numpy state update: accumulates x_orig into self.histogram
        # and widens self.min_val / self.max_val as needed.
        x = x_orig.numpy()
        min_val = self.min_val.numpy()[0]
        max_val = self.max_val.numpy()[0]
        histogram = self.histogram.numpy()
        new_min = x.min()
        new_max = x.max()
        if histogram[0] == -1:
            # first batch: histogram still holds the -1 sentinel from __init__
            new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
        else:
            new_min = min(new_min, min_val)
            new_max = max(new_max, max_val)
            # combine the existing histogram and new histogram into 1 histogram
            # We do this by first upsampling the histogram to a dense grid
            # and then downsampling the histogram efficiently
            (new_min, new_max, downsample_rate, start_idx,) = self._adjust_min_max(
                new_min, new_max, self.upsample_rate
            )
            new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
            new_histogram = new_histogram.astype(np.float64)
            if new_min == min_val and new_max == max_val:
                # range unchanged: bin edges line up, just add the counts
                new_histogram += histogram
            else:
                new_histogram = self._combine_histograms(
                    new_histogram,
                    histogram,
                    self.upsample_rate,
                    downsample_rate,
                    start_idx,
                    self.bins,
                )
        self.histogram._reset(new_histogram)
        self.min_val._reset(new_min)
        self.max_val._reset(new_max)

    def forward(self, x_orig):
        # Identity on the data path; all work happens in sideeffect_forward.
        self.sideeffect_forward(x_orig)
        return x_orig

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台