You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

observer.py 16 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import math
  9. from abc import abstractmethod
  10. from enum import Enum
  11. import numpy as np
  12. from .. import functional as F
  13. from .._internal.dtype import _metadata_dict, get_quantized_dtype
  14. from ..core import Buffer, Function, tensor
  15. from ..jit import sideeffect
  16. from ..module import Module
  17. class Round(Function):
  18. def forward(self, x):
  19. return x.round()
  20. def backward(self, output_grads):
  21. return output_grads
  22. class Observer(Module):
  23. r"""
  24. A base class for Observer Module.
  25. :param dtype: a string indicating to collect scale and zero_point of which dtype
  26. """
  27. def __init__(self, dtype="qint8"):
  28. super().__init__()
  29. if dtype not in _metadata_dict.keys():
  30. raise ValueError(
  31. "unknown dtype: {}, only support {}".format(
  32. dtype, _metadata_dict.keys()
  33. )
  34. )
  35. self.dtype = dtype
  36. self.qmin = _metadata_dict[dtype].qmin
  37. self.qmax = _metadata_dict[dtype].qmax
  38. self.enabled = True
  39. def get_dtype(self):
  40. q_dict = self.get_qparams()
  41. numpy_scale = None if "scale" not in q_dict else q_dict["scale"].numpy()[0]
  42. numpy_zero_point = (
  43. None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()[0]
  44. )
  45. return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)
  46. def enable(self):
  47. self.enabled = True
  48. def disable(self):
  49. self.enabled = False
  50. def train(self, mode: bool = True) -> None:
  51. super().train(mode)
  52. if mode:
  53. self.enable()
  54. else:
  55. self.disable()
  56. @abstractmethod
  57. def forward(self, x):
  58. pass
  59. @abstractmethod
  60. def get_qparams(self, **kwargs):
  61. pass
  62. class ObserverMode(Enum):
  63. SYMMERTIC = 1
  64. ASYMMERTIC = 2
  65. def create_observer_dict(mode):
  66. if mode == ObserverMode.SYMMERTIC:
  67. return {
  68. "mode": ObserverMode.SYMMERTIC,
  69. "scale": None,
  70. }
  71. else:
  72. return {
  73. "mode": ObserverMode.ASYMMERTIC,
  74. "scale": None,
  75. "zero_point": None,
  76. }
  77. class MinMaxObserver(Observer):
  78. def __init__(self, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"):
  79. super().__init__(dtype)
  80. self.mode = mode
  81. self.min_val = Buffer(0.0, dtype=np.float32)
  82. self.max_val = Buffer(0.0, dtype=np.float32)
  83. self.scale_limit = eps
  84. # flag is used by cond_take, first time will be first flag, and after will be set as not_flag
  85. self.first_flag = Buffer(np.array([1, 0], dtype=np.int32))
  86. self.not_flag = Buffer(np.array([0, 1], dtype=np.int32))
  87. def set_min_max(self, tmp_min, tmp_max):
  88. # FIXME: cond_take will destory shape, use reshape to reset shape
  89. tmp_min = tmp_min.reshape(1)
  90. tmp_max = tmp_max.reshape(1)
  91. F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0)
  92. F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0)
  93. F.add_update(self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0)
  94. def _calculate_qparams(self, inp_min_val, inp_max_val):
  95. min_val = F.minimum(0.0, inp_min_val)
  96. max_val = F.maximum(0.0, inp_max_val)
  97. q_dict = create_observer_dict(self.mode)
  98. if self.mode == ObserverMode.SYMMERTIC:
  99. symmetric_max_vals = F.maximum(-min_val, max_val)
  100. # use maximun to avoid scale too small at the begin
  101. q_dict["scale"] = F.maximum(
  102. symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
  103. )
  104. # zero_point = self.zero_point
  105. else:
  106. # use maximun to avoid scale too small at the begin
  107. q_dict["scale"] = F.maximum(
  108. (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit,
  109. )
  110. # caculate zero_point
  111. q_dict["zero_point"] = self.qmin - Round()((min_val / scale))
  112. return q_dict
  113. def get_qparams(self):
  114. return self._calculate_qparams(self.min_val, self.max_val)
  115. def forward(self, x_orig):
  116. if self.enabled:
  117. # stop gradient
  118. x = F.zero_grad(x_orig)
  119. # find max and min
  120. tmp_min, _ = F.cond_take(
  121. self.first_flag, F.concat([x.min(), F.minimum(self.min_val, x.min())])
  122. )
  123. tmp_max, _ = F.cond_take(
  124. self.first_flag, F.concat([x.max(), F.maximum(self.max_val, x.max())])
  125. )
  126. self.set_min_max(tmp_min, tmp_max)
  127. return x_orig
  128. class ExponentialMovingAverageObserver(MinMaxObserver):
  129. def __init__(
  130. self, momentum=0.9, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"
  131. ):
  132. super().__init__(mode, eps, dtype)
  133. self.momentum = Buffer(momentum)
  134. def set_momentum(self, momentum):
  135. self.momentum.set_value(momentum)
  136. def forward(self, x_orig):
  137. if self.enabled:
  138. # stop gradient
  139. x = F.zero_grad(x_orig)
  140. # Exponential Moving Average
  141. tmp_min, _ = F.cond_take(
  142. self.first_flag,
  143. F.concat(
  144. [
  145. x.min(),
  146. self.momentum * self.min_val + (1 - self.momentum) * x.min(),
  147. ]
  148. ),
  149. )
  150. tmp_max, _ = F.cond_take(
  151. self.first_flag,
  152. F.concat(
  153. [
  154. x.max(),
  155. self.momentum * self.max_val + (1 - self.momentum) * x.max(),
  156. ]
  157. ),
  158. )
  159. self.set_min_max(tmp_min, tmp_max)
  160. return x_orig
  161. class HistogramObserver(MinMaxObserver):
  162. def __init__(
  163. self,
  164. bins=2048,
  165. upsample_rate=128,
  166. dtype="qint8",
  167. mode=ObserverMode.SYMMERTIC,
  168. eps=0.00001,
  169. ):
  170. super().__init__(mode, eps, dtype)
  171. self.bins = bins
  172. self.upsample_rate = upsample_rate
  173. self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
  174. self.histogram = None
  175. def _non_linear_param_search(self):
  176. r"""Non-linear parameter search.
  177. An approximation for L2 error minimization for selecting min/max.
  178. By selecting new min/max, we filter out outliers in input distribution.
  179. """
  180. def _get_norm(delta_begin, delta_end, density, norm_type):
  181. r"""
  182. Compute the norm of the values uniformaly distributed between
  183. delta_begin and delta_end.
  184. norm = density * (integral_{begin, end} x^2)
  185. = density * (end^3 - begin^3) / 3
  186. """
  187. assert norm_type == "L2", "Only L2 norms are currently supported"
  188. norm = 0.0
  189. if norm_type == "L2":
  190. norm = (
  191. delta_end * delta_end * delta_end
  192. - delta_begin * delta_begin * delta_begin
  193. ) / 3
  194. return density * norm
  195. def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
  196. r"""
  197. Compute the quantization error if we use start_bin to end_bin as the
  198. min and max to do the quantization.
  199. """
  200. np_min_val = self.min_val.numpy()[0]
  201. np_max_val = self.max_val.numpy()[0]
  202. bin_width = (np_max_val - np_min_val) / self.bins
  203. norm = 0.0
  204. dst_bin_width = (
  205. bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
  206. )
  207. if dst_bin_width == 0.0:
  208. return 0.0
  209. for src_bin in range(self.bins):
  210. # distances from the beginning of first dst_bin to the beginning and
  211. # end of src_bin
  212. src_bin_begin = (src_bin - next_start_bin) * bin_width
  213. src_bin_end = src_bin_begin + bin_width
  214. # which dst_bins the beginning and end of src_bin belong to?
  215. dst_bin_of_begin = min(
  216. self.dst_nbins - 1,
  217. max(0.0, math.floor(src_bin_begin / dst_bin_width)),
  218. )
  219. dst_bin_of_end = min(
  220. self.dst_nbins - 1,
  221. max(0.0, math.floor(src_bin_end / dst_bin_width)),
  222. )
  223. dst_bin_of_begin_center = (
  224. dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
  225. )
  226. density = self.histogram[src_bin] / bin_width
  227. if dst_bin_of_begin == dst_bin_of_end:
  228. # if src_bin is entirely within 1 dst_bin
  229. delta_begin = src_bin_begin - dst_bin_of_begin_center
  230. delta_end = src_bin_end - dst_bin_of_begin_center
  231. norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
  232. else:
  233. delta_begin = src_bin_begin - dst_bin_of_begin_center
  234. delta_end = dst_bin_width / 2
  235. norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
  236. norm = norm + (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
  237. -dst_bin_width / 2, dst_bin_width / 2, density, norm_type
  238. )
  239. dst_bin_of_end_center = (
  240. dst_bin_of_end * dst_bin_width + dst_bin_width / 2
  241. )
  242. delta_begin = -dst_bin_width / 2
  243. delta_end = src_bin_end - dst_bin_of_end_center
  244. norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
  245. return norm
  246. assert len(self.histogram) == self.bins, "bins mistmatch"
  247. bin_width = (self.max_val - self.min_val) / self.bins
  248. # cumulative sum
  249. total = sum(self.histogram)
  250. cSum = np.cumsum(self.histogram, axis=0)
  251. stepsize = 1e-5 # granularity
  252. alpha = 0.0 # lower bound
  253. beta = 1.0 # upper bound
  254. start_bin = 0
  255. end_bin = self.bins - 1
  256. norm_min = float("inf")
  257. while alpha < beta:
  258. # Find the next step
  259. next_alpha = alpha + stepsize
  260. next_beta = beta - stepsize
  261. # find the left and right bins between the quantile bounds
  262. l = start_bin
  263. r = end_bin
  264. while l < end_bin and cSum[l] < next_alpha * total:
  265. l = l + 1
  266. while r > start_bin and cSum[r] > next_beta * total:
  267. r = r - 1
  268. # decide the next move
  269. next_start_bin = start_bin
  270. next_end_bin = end_bin
  271. if (l - start_bin) > (end_bin - r):
  272. # move the start bin
  273. next_start_bin = l
  274. alpha = next_alpha
  275. else:
  276. # move the end bin
  277. next_end_bin = r
  278. beta = next_beta
  279. if next_start_bin == start_bin and next_end_bin == end_bin:
  280. continue
  281. # calculate the quantization error using next_start_bin and next_end_bin
  282. norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")
  283. if norm > norm_min:
  284. break
  285. norm_min = norm
  286. start_bin = next_start_bin
  287. end_bin = next_end_bin
  288. new_min = self.min_val + bin_width * start_bin
  289. new_max = self.min_val + bin_width * (end_bin + 1)
  290. return new_min, new_max
  291. def get_qparams(self):
  292. new_min, new_max = self._non_linear_param_search()
  293. return self._calculate_qparams(new_min, new_max)
  294. def _combine_histograms(
  295. self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins
  296. ):
  297. # First up-sample the histogram with new data by a factor of L
  298. # This creates an approximate probability density thats piecwise constant
  299. upsampled_histogram = new_hist.repeat(upsample_rate)
  300. # Now insert the upsampled histogram into the output
  301. # histogram, which is initialized with zeros.
  302. # The offset at which the histogram is introduced is determined
  303. # by the start index as the output histogram can cover a wider range
  304. histogram_with_output_range = np.zeros((Nbins * downsample_rate))
  305. histogram_with_output_range[
  306. start_idx : Nbins * upsample_rate + start_idx
  307. ] = upsampled_histogram
  308. # Compute integral histogram, double precision is needed to ensure
  309. # that there are no overflows
  310. integral_histogram = np.cumsum(histogram_with_output_range, 0)[
  311. downsample_rate - 1 :: downsample_rate
  312. ]
  313. # Finally perform interpolation
  314. shifted_integral_histogram = np.zeros((Nbins))
  315. shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1]
  316. interpolated_histogram = (
  317. integral_histogram - shifted_integral_histogram
  318. ) / upsample_rate
  319. orig_hist = orig_hist + interpolated_histogram
  320. return orig_hist
  321. def _adjust_min_max(self, combined_min, combined_max, upsample_rate):
  322. # We ensure that:
  323. # (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
  324. # This allows us to have a common grid of resolution s, where we can align
  325. # the input histogram
  326. # start_idx maps min_val to the histogram bin index.
  327. np_min_val = self.min_val.numpy()[0]
  328. np_max_val = self.max_val.numpy()[0]
  329. hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate)
  330. downsample_rate = int(
  331. np.ceil((combined_max - combined_min) / (self.bins * hist_bin_width))
  332. )
  333. e = downsample_rate * (self.bins * hist_bin_width) - (
  334. combined_max - combined_min
  335. )
  336. combined_max = combined_max + e / 2
  337. combined_min = combined_min - e / 2
  338. start_idx = int(np.round((np_min_val - combined_min) / hist_bin_width))
  339. return combined_min, combined_max, downsample_rate, start_idx
  340. @sideeffect
  341. def sideeffect_forward(self, x_orig):
  342. x = x_orig.numpy()
  343. min_val = self.min_val.numpy()[0]
  344. max_val = self.max_val.numpy()[0]
  345. if min_val == 0 or max_val == 0:
  346. min_val = x.min()
  347. max_val = x.max()
  348. self.min_val.set_value(min_val)
  349. self.max_val.set_value(max_val)
  350. self.histogram, _ = np.histogram(x, self.bins, (min_val, max_val))
  351. self.histogram = self.histogram.astype(np.float64)
  352. else:
  353. new_min = x.min()
  354. new_max = x.max()
  355. combined_min = min(new_min, min_val)
  356. combined_max = max(new_max, max_val)
  357. # combine the existing histogram and new histogram into 1 histogram
  358. # We do this by first upsampling the histogram to a dense grid
  359. # and then downsampling the histogram efficiently
  360. (
  361. combined_min,
  362. combined_max,
  363. downsample_rate,
  364. start_idx,
  365. ) = self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
  366. combined_histogram, _ = np.histogram(
  367. x, self.bins, (combined_min, combined_max)
  368. )
  369. combined_histogram = combined_histogram.astype(np.float64)
  370. if combined_min == min_val and combined_max == max_val:
  371. combined_histogram += self.histogram
  372. else:
  373. combined_histogram = self._combine_histograms(
  374. combined_histogram,
  375. self.histogram,
  376. self.upsample_rate,
  377. downsample_rate,
  378. start_idx,
  379. self.bins,
  380. )
  381. self.histogram = combined_histogram
  382. self.min_val.set_value(combined_min)
  383. self.max_val.set_value(combined_max)
  384. def forward(self, x_orig):
  385. self.sideeffect_forward(x_orig)
  386. return x_orig

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台