You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

observer.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import math
  9. from abc import abstractmethod
  10. import numpy as np
  11. from .. import functional as F
  12. from .._internal.dtype import _metadata_dict, get_quantized_dtype
  13. from ..core import Buffer, Function, tensor
  14. from ..jit import sideeffect
  15. from ..module import Module
  16. class Round(Function):
  17. def forward(self, x):
  18. return x.round()
  19. def backward(self, output_grads):
  20. return output_grads
  21. class Observer(Module):
  22. r"""
  23. A base class for Observer Module.
  24. :param dtype: a string indicating to collect scale and zero_point of which dtype
  25. """
  26. def __init__(self, dtype="qint8"):
  27. super().__init__()
  28. if dtype not in _metadata_dict.keys():
  29. raise ValueError(
  30. "unknown dtype: {}, only support {}".format(
  31. dtype, _metadata_dict.keys()
  32. )
  33. )
  34. self.dtype = dtype
  35. self.qmin = _metadata_dict[dtype].qmin
  36. self.qmax = _metadata_dict[dtype].qmax
  37. self.enabled = True
  38. def get_dtype(self):
  39. scale, zero_point = self.get_qparams()
  40. numpy_scale = None if scale is None else scale.numpy()[0]
  41. numpy_zero_point = None if zero_point is None else zero_point.numpy()[0]
  42. return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)
  43. def enable(self):
  44. self.enabled = True
  45. def disable(self):
  46. self.enabled = False
  47. def train(self, mode: bool = True) -> None:
  48. super().train(mode)
  49. if mode:
  50. self.enable()
  51. else:
  52. self.disable()
  53. @abstractmethod
  54. def forward(self, x):
  55. pass
  56. @abstractmethod
  57. def get_qparams(self, **kwargs):
  58. pass
class MinMaxObserver(Observer):
    r"""Observer that tracks the running minimum and maximum of observed
    tensors and derives quantization parameters from them.

    :param symmetric: if True, use a symmetric range around zero with a fixed
        zero_point; otherwise compute an asymmetric (affine) zero_point.
    :param eps: lower bound applied to the computed scale to avoid
        degenerate (too small) scales early in training.
    """

    def __init__(self, symmetric=True, eps=0.00001, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.symmetric = symmetric
        if self.symmetric:
            # assert qmin + qmax == -1, 'when reduce_range, qmin + qmax shoule equal -1'
            # midpoint of the quantized range; fixed zero_point for the
            # symmetric case
            self.zero_point = tensor((self.qmin + self.qmax + 1) // 2)
        # running statistics (scalars stored as 1-element float32 buffers)
        self.min_val = Buffer(0.0, dtype=np.float32)
        self.max_val = Buffer(0.0, dtype=np.float32)
        self.scale_limit = eps
        # flag is used by cond_take: the first forward uses first_flag (selects
        # the raw batch statistic), and afterwards it is overwritten with
        # not_flag (selects the running-statistic branch)
        self.first_flag = Buffer(np.array([1, 0], dtype=np.int32))
        self.not_flag = Buffer(np.array([0, 1], dtype=np.int32))

    def set_min_max(self, tmp_min, tmp_max):
        """Overwrite the running min/max buffers in place and flip the
        first-forward selector flag."""
        # FIXME: cond_take will destroy shape, use reshape to reset shape
        tmp_min = tmp_min.reshape(1)
        tmp_max = tmp_max.reshape(1)
        # alpha=0, beta=1, bias=0: plain overwrite of the buffer contents
        F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0)
        F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0)
        # after the first update, subsequent forwards take the running branch
        F.add_update(self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0)

    def _calculate_qparams(self, inp_min_val, inp_max_val):
        """Derive (scale, zero_point) from a min/max pair.

        The range is widened so it always contains 0 (required so that zero
        is exactly representable after quantization).
        """
        min_val = F.minimum(0.0, inp_min_val)
        max_val = F.maximum(0.0, inp_max_val)
        if self.symmetric:
            symmetric_max_vals = F.maximum(-min_val, max_val)
            # use maximum to avoid scale too small at the begin
            scale = F.maximum(
                symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
            )
            zero_point = self.zero_point
        else:
            # use maximum to avoid scale too small at the begin
            scale = F.maximum(
                (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit,
            )
            # calculate zero_point
            zero_point = self.qmin - Round()((min_val / scale))
        return scale, zero_point

    def get_qparams(self):
        return self._calculate_qparams(self.min_val, self.max_val)

    def forward(self, x_orig):
        """Update running min/max from ``x_orig`` (when enabled) and return
        the input unchanged — observers are transparent in the data path."""
        if self.enabled:
            # stop gradient
            x = F.zero_grad(x_orig)
            # find max and min: first forward takes the raw batch statistic,
            # later forwards fold it into the running min/max (flag selection
            # via cond_take, see __init__)
            tmp_min, _ = F.cond_take(
                self.first_flag, F.concat([x.min(), F.minimum(self.min_val, x.min())])
            )
            tmp_max, _ = F.cond_take(
                self.first_flag, F.concat([x.max(), F.maximum(self.max_val, x.max())])
            )
            self.set_min_max(tmp_min, tmp_max)
        return x_orig
class ExponentialMovingAverageObserver(MinMaxObserver):
    r"""MinMax observer whose statistics follow an exponential moving average
    of the per-batch min/max instead of the global extrema.

    :param momentum: smoothing factor; larger values weight the running
        statistics more heavily relative to the current batch.
    """

    def __init__(self, momentum=0.9, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # stored as a Buffer so it can be changed at runtime via set_momentum
        self.momentum = Buffer(momentum)

    def set_momentum(self, momentum):
        # update the smoothing factor in place
        self.momentum.set_value(momentum)

    def forward(self, x_orig):
        """EMA-update running min/max (when enabled); returns input unchanged."""
        if self.enabled:
            # stop gradient
            x = F.zero_grad(x_orig)
            # Exponential Moving Average: the first forward (selected via
            # first_flag, see MinMaxObserver.__init__) uses the raw batch
            # statistic; later forwards blend it with the running value
            tmp_min, _ = F.cond_take(
                self.first_flag,
                F.concat(
                    [
                        x.min(),
                        self.momentum * self.min_val + (1 - self.momentum) * x.min(),
                    ]
                ),
            )
            tmp_max, _ = F.cond_take(
                self.first_flag,
                F.concat(
                    [
                        x.max(),
                        self.momentum * self.max_val + (1 - self.momentum) * x.max(),
                    ]
                ),
            )
            self.set_min_max(tmp_min, tmp_max)
        return x_orig
class HistogramObserver(MinMaxObserver):
    r"""Observer that accumulates a histogram of observed values and searches
    for min/max clipping bounds that approximately minimize the L2
    quantization error (outlier-robust, unlike plain min/max).

    :param bins: number of histogram bins used to model the distribution.
    :param upsample_rate: up-sampling factor used when merging histograms
        that cover different value ranges.
    :param dtype: target quantized dtype (forwarded to the base observer).
    """

    def __init__(self, bins=2048, upsample_rate=128, dtype="qint8", *args, **kwargs):
        super().__init__(dtype=dtype, *args, **kwargs)
        self.bins = bins
        self.upsample_rate = upsample_rate
        # number of bins of the destination (quantized) value domain
        self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
        # created lazily on the first forward (float64 array of length `bins`)
        self.histogram = None

    def _non_linear_param_search(self):
        r"""Non-linear parameter search.
        An approximation for L2 error minimization for selecting min/max.
        By selecting new min/max, we filter out outliers in input distribution.
        Returns the (new_min, new_max) clipping bounds.
        """

        def _get_norm(delta_begin, delta_end, density, norm_type):
            r"""
            Compute the norm of the values uniformly distributed between
            delta_begin and delta_end.
            norm = density * (integral_{begin, end} x^2)
                 = density * (end^3 - begin^3) / 3
            """
            assert norm_type == "L2", "Only L2 norms are currently supported"
            norm = 0.0
            if norm_type == "L2":
                norm = (
                    delta_end * delta_end * delta_end
                    - delta_begin * delta_begin * delta_begin
                ) / 3
            return density * norm

        def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
            r"""
            Compute the quantization error if we use start_bin to end_bin as the
            min and max to do the quantization.
            """
            np_min_val = self.min_val.numpy()[0]
            np_max_val = self.max_val.numpy()[0]
            bin_width = (np_max_val - np_min_val) / self.bins
            norm = 0.0
            # width of one destination bin when [start, end] is mapped onto
            # the dst_nbins quantization levels
            dst_bin_width = (
                bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
            )
            if dst_bin_width == 0.0:
                # degenerate range: everything maps to one value, no error
                return 0.0
            for src_bin in range(self.bins):
                # distances from the beginning of first dst_bin to the beginning and
                # end of src_bin
                src_bin_begin = (src_bin - next_start_bin) * bin_width
                src_bin_end = src_bin_begin + bin_width
                # which dst_bins the beginning and end of src_bin belong to?
                dst_bin_of_begin = min(
                    self.dst_nbins - 1,
                    max(0.0, math.floor(src_bin_begin / dst_bin_width)),
                )
                dst_bin_of_end = min(
                    self.dst_nbins - 1,
                    max(0.0, math.floor(src_bin_end / dst_bin_width)),
                )
                dst_bin_of_begin_center = (
                    dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
                )
                # approximate the source distribution as piecewise constant
                density = self.histogram[src_bin] / bin_width
                if dst_bin_of_begin == dst_bin_of_end:
                    # if src_bin is entirely within 1 dst_bin
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = src_bin_end - dst_bin_of_begin_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
                else:
                    # src_bin spans several dst_bins: accumulate the error of
                    # the first partial bin, the fully covered middle bins,
                    # and the last partial bin separately
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = dst_bin_width / 2
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
                    norm = norm + (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
                        -dst_bin_width / 2, dst_bin_width / 2, density, norm_type
                    )
                    dst_bin_of_end_center = (
                        dst_bin_of_end * dst_bin_width + dst_bin_width / 2
                    )
                    delta_begin = -dst_bin_width / 2
                    delta_end = src_bin_end - dst_bin_of_end_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
            return norm

        assert len(self.histogram) == self.bins, "bins mistmatch"
        bin_width = (self.max_val - self.min_val) / self.bins
        # cumulative sum
        total = sum(self.histogram)
        cSum = np.cumsum(self.histogram, axis=0)
        stepsize = 1e-5  # granularity
        alpha = 0.0  # lower bound
        beta = 1.0  # upper bound
        start_bin = 0
        end_bin = self.bins - 1
        norm_min = float("inf")
        # greedily shrink [alpha, beta] quantile bounds while the resulting
        # quantization error keeps decreasing
        while alpha < beta:
            # Find the next step
            next_alpha = alpha + stepsize
            next_beta = beta - stepsize
            # find the left and right bins between the quantile bounds
            l = start_bin
            r = end_bin
            while l < end_bin and cSum[l] < next_alpha * total:
                l = l + 1
            while r > start_bin and cSum[r] > next_beta * total:
                r = r - 1
            # decide the next move: shrink whichever side trims more bins
            next_start_bin = start_bin
            next_end_bin = end_bin
            if (l - start_bin) > (end_bin - r):
                # move the start bin
                next_start_bin = l
                alpha = next_alpha
            else:
                # move the end bin
                next_end_bin = r
                beta = next_beta
            if next_start_bin == start_bin and next_end_bin == end_bin:
                continue
            # calculate the quantization error using next_start_bin and next_end_bin
            norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")
            if norm > norm_min:
                # error started increasing: stop at the previous bounds
                break
            norm_min = norm
            start_bin = next_start_bin
            end_bin = next_end_bin
        new_min = self.min_val + bin_width * start_bin
        new_max = self.min_val + bin_width * (end_bin + 1)
        return new_min, new_max

    def get_qparams(self):
        # run the search each time qparams are requested, then reuse the
        # base-class scale/zero_point computation
        new_min, new_max = self._non_linear_param_search()
        return self._calculate_qparams(new_min, new_max)

    def _combine_histograms(
        self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins
    ):
        """Merge ``new_hist`` (covering the old, narrower range) into
        ``orig_hist`` (covering the combined range) by up-sampling, shifting,
        and integral-based down-sampling."""
        # First up-sample the histogram with new data by a factor of L
        # This creates an approximate probability density that's piecewise constant
        upsampled_histogram = new_hist.repeat(upsample_rate)
        # Now insert the upsampled histogram into the output
        # histogram, which is initialized with zeros.
        # The offset at which the histogram is introduced is determined
        # by the start index as the output histogram can cover a wider range
        histogram_with_output_range = np.zeros((Nbins * downsample_rate))
        histogram_with_output_range[
            start_idx : Nbins * upsample_rate + start_idx
        ] = upsampled_histogram
        # Compute integral histogram, double precision is needed to ensure
        # that there are no overflows
        integral_histogram = np.cumsum(histogram_with_output_range, 0)[
            downsample_rate - 1 :: downsample_rate
        ]
        # Finally perform interpolation
        shifted_integral_histogram = np.zeros((Nbins))
        shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1]
        interpolated_histogram = (
            integral_histogram - shifted_integral_histogram
        ) / upsample_rate
        orig_hist = orig_hist + interpolated_histogram
        return orig_hist

    def _adjust_min_max(self, combined_min, combined_max, upsample_rate):
        """Snap the combined range onto a common fine grid so the old and new
        histograms can be merged exactly; returns the adjusted range plus the
        down-sample rate and bin offset used by ``_combine_histograms``."""
        # We ensure that:
        # (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
        # This allows us to have a common grid of resolution s, where we can align
        # the input histogram
        # start_idx maps min_val to the histogram bin index.
        np_min_val = self.min_val.numpy()[0]
        np_max_val = self.max_val.numpy()[0]
        hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate)
        downsample_rate = int(
            np.ceil((combined_max - combined_min) / (self.bins * hist_bin_width))
        )
        # pad the range symmetrically so it is an exact multiple of the grid
        e = downsample_rate * (self.bins * hist_bin_width) - (
            combined_max - combined_min
        )
        combined_max = combined_max + e / 2
        combined_min = combined_min - e / 2
        start_idx = int(np.round((np_min_val - combined_min) / hist_bin_width))
        return combined_min, combined_max, downsample_rate, start_idx

    @sideeffect
    def sideeffect_forward(self, x_orig):
        # Host-side (numpy) statistics update; decorated @sideeffect,
        # presumably so it still executes under JIT tracing — confirm against
        # megengine.jit semantics.
        x = x_orig.numpy()
        min_val = self.min_val.numpy()[0]
        max_val = self.max_val.numpy()[0]
        if min_val == 0 or max_val == 0:
            # first observation (buffers still at their 0.0 init values):
            # seed min/max and the histogram from this batch
            min_val = x.min()
            max_val = x.max()
            self.min_val.set_value(min_val)
            self.max_val.set_value(max_val)
            self.histogram, _ = np.histogram(x, self.bins, (min_val, max_val))
            self.histogram = self.histogram.astype(np.float64)
        else:
            new_min = x.min()
            new_max = x.max()
            combined_min = min(new_min, min_val)
            combined_max = max(new_max, max_val)
            # combine the existing histogram and new histogram into 1 histogram
            # We do this by first upsampling the histogram to a dense grid
            # and then downsampling the histogram efficiently
            (
                combined_min,
                combined_max,
                downsample_rate,
                start_idx,
            ) = self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
            combined_histogram, _ = np.histogram(
                x, self.bins, (combined_min, combined_max)
            )
            combined_histogram = combined_histogram.astype(np.float64)
            if combined_min == min_val and combined_max == max_val:
                # range unchanged: bins line up, plain elementwise addition
                combined_histogram += self.histogram
            else:
                combined_histogram = self._combine_histograms(
                    combined_histogram,
                    self.histogram,
                    self.upsample_rate,
                    downsample_rate,
                    start_idx,
                    self.bins,
                )
            self.histogram = combined_histogram
            self.min_val.set_value(combined_min)
            self.max_val.set_value(combined_max)

    def forward(self, x_orig):
        # observers are transparent: update statistics, return input unchanged
        self.sideeffect_forward(x_orig)
        return x_orig

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台