
observer.py 20 kB

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math
from abc import abstractmethod
from copy import deepcopy
from typing import Union

import numpy as np

from .. import functional as F
from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
from ..distributed import WORLD, get_rank, is_distributed
from ..functional.distributed import all_reduce_max, all_reduce_min
from ..logger import get_logger
from ..module import Module
from ..tensor import Tensor
from .utils import QParams, QParamsModuleMixin, QuantMode, create_qparams

logger = get_logger(__name__)

class Observer(Module, QParamsModuleMixin):
    r"""A base class for Observer Modules. Used to record an input tensor's
    statistics for quantization.

    Args:
        dtype: a string indicating which dtype to collect scale and zero_point of.
    """

    def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
        super().__init__()
        if isinstance(dtype, str):
            if dtype not in _builtin_quant_dtypes:
                raise ValueError(
                    "unknown dtype: {}, only support {}".format(
                        dtype, _builtin_quant_dtypes.keys()
                    )
                )
            dtype = _builtin_quant_dtypes[dtype]
        if "narrow_range" in kwargs:
            del kwargs["narrow_range"]
            logger.warning(
                "FakeQuantize currently has no narrow_range param "
                "so it is ignored here",
                exc_info=DeprecationWarning,
            )
        self.dtype = dtype
        self.qmin = dtype.qmin
        self.qmax = dtype.qmax
        self.enabled = True

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False

    def train(self, mode: bool = True, recursive: bool = True) -> None:
        super().train(mode, recursive)
        if mode:
            self.enable()
        else:
            self.disable()

    @abstractmethod
    def forward(self, x):
        pass

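# Illustrative note (not part of the original file): train()/eval() toggle
# statistic collection via enable()/disable(), so a quantization-aware training
# loop only needs the standard module mode switches:
#
#   net.train()  # observers enabled: statistics keep updating
#   net.eval()   # observers disabled: statistics are frozen
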
class MinMaxObserver(Observer):
    r"""An Observer Module that records the input tensor's running min and max
    values to calculate the scale.

    Args:
        mode: set quantization mode.
        eps: a lower bound of the scale to avoid the division by zero problem.
        dtype: a string indicating which dtype to collect scale and zero_point of.
    """

    def __init__(
        self,
        mode: QuantMode = QuantMode.SYMMERTIC,
        eps: float = 0.00001,
        dtype: Union[str, QuantDtypeMeta] = "qint8",
        **kwargs
    ):
        super().__init__(dtype, **kwargs)
        self.mode = mode
        self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32)
        self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32)
        self.scale_limit = eps

    def _calculate_qparams(self, inp_min_val, inp_max_val):
        min_val = F.minimum(0.0, inp_min_val)
        max_val = F.maximum(0.0, inp_max_val)
        if self.mode == QuantMode.SYMMERTIC:
            symmetric_max_vals = F.maximum(-min_val, max_val)
            # use maximum to avoid the scale being too small at the beginning
            scale = F.maximum(
                symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
            )
            zero_point = None
        else:
            # use maximum to avoid the scale being too small at the beginning
            scale = F.maximum(
                (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit
            )
            # calculate zero_point
            zero_point = self.qmin - F.round((min_val / scale))
        return create_qparams(self.mode, self.dtype, scale=scale, zero_point=zero_point)

    def get_qparams(self):
        return self._calculate_qparams(self.min_val, self.max_val)

    def forward(self, x_orig):
        if self.enabled:
            # stop gradient
            x = x_orig.detach()
            # find max and min
            self.min_val[...] = F.minimum(self.min_val, x.min())
            self.max_val[...] = F.maximum(self.max_val, x.max())
        return x_orig

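# Illustrative usage (not part of the original file): a minimal sketch, assuming
# this module is importable as megengine.quantization.observer, of how a
# MinMaxObserver accumulates statistics in forward and derives qparams.
#
#   import numpy as np
#   import megengine as mge
#   from megengine.quantization.observer import MinMaxObserver
#
#   obs = MinMaxObserver()  # symmetric mode and "qint8" by default
#   _ = obs(mge.Tensor(np.array([-1.0, 0.5, 2.0], dtype=np.float32)))
#   qparams = obs.get_qparams()
#   # in symmetric mode: scale = max(|min|, max, 0) / ((qmax - qmin) / 2),
#   # i.e. 2.0 / 127.5 if qint8 spans [-128, 127]
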
class SyncMinMaxObserver(MinMaxObserver):
    r"""A distributed version of :class:`~.MinMaxObserver`.

    Args:
        mode: set quantization mode.
        eps: a lower bound of the scale to avoid the division by zero problem.
        dtype: a string indicating which dtype to collect scale and zero_point of.
    """

    def forward(self, x_orig):
        if self.enabled:
            x = x_orig.detach()
            if is_distributed():
                min_x = all_reduce_min(x.min(), WORLD)
                max_x = all_reduce_max(x.max(), WORLD)
            else:
                min_x = x.min()
                max_x = x.max()
            self.min_val[...] = F.minimum(self.min_val, min_x)
            self.max_val[...] = F.maximum(self.max_val, max_x)
        return x_orig

class ExponentialMovingAverageObserver(MinMaxObserver):
    r"""A :class:`~.MinMaxObserver` with momentum support for min/max updating.

    Args:
        momentum: momentum ratio for min/max updating.
        mode: set quantization mode.
        eps: a lower bound of the scale to avoid the division by zero problem.
        dtype: a string indicating which dtype to collect scale and zero_point of.
    """

    def __init__(
        self,
        momentum: float = 0.9,
        mode: QuantMode = QuantMode.SYMMERTIC,
        eps: float = 0.00001,
        dtype: Union[str, QuantDtypeMeta] = "qint8",
        **kwargs
    ):
        super().__init__(mode, eps, dtype, **kwargs)
        self.momentum = Tensor(momentum, dtype="float32")
        # used to avoid if-clauses in the first forward, which are not supported
        # in trace mode
        self.runtime_momentum = Tensor(0.0)

    def set_momentum(self, momentum):
        self.momentum = Tensor(momentum, dtype="float32")

    def forward(self, x_orig):
        if self.enabled:
            # stop gradient
            x = x_orig.detach()
            # exponential moving average
            self.min_val[...] = (
                self.min_val * self.runtime_momentum
                + (1 - self.runtime_momentum) * x.min()
            )
            self.max_val[...] = (
                self.max_val * self.runtime_momentum
                + (1 - self.runtime_momentum) * x.max()
            )
            self.runtime_momentum[...] = self.momentum
        return x_orig

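# Illustrative note (not part of the original file): since runtime_momentum
# starts at 0, the first forward adopts the batch statistics outright, and each
# later call blends them as min_val <- m * min_val + (1 - m) * batch_min.
# For m = 0.9 and successive batch minimums 1.0 then 0.0:
#   step 1: min_val = 0.0 * init + 1.0 * 1.0 = 1.0   (runtime_momentum is 0)
#   step 2: min_val = 0.9 * 1.0 + 0.1 * 0.0 = 0.9    (runtime_momentum is now m)
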
class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
    r"""A distributed version of :class:`~.ExponentialMovingAverageObserver`.

    Args:
        momentum: momentum ratio for min/max updating.
        mode: set quantization mode.
        eps: a lower bound of the scale to avoid the division by zero problem.
        dtype: a string indicating which dtype to collect scale and zero_point of.
    """

    def forward(self, x_orig):
        if self.enabled:
            x = x_orig.detach()
            if is_distributed():
                min_x = all_reduce_min(x.min(), WORLD)
                max_x = all_reduce_max(x.max(), WORLD)
            else:
                min_x = x.min()
                max_x = x.max()
            self.min_val[...] = (
                self.min_val * self.runtime_momentum
                + (1 - self.runtime_momentum) * min_x
            )
            self.max_val[...] = (
                self.max_val * self.runtime_momentum
                + (1 - self.runtime_momentum) * max_x
            )
            self.runtime_momentum[...] = self.momentum
        return x_orig

class HistogramObserver(MinMaxObserver):
    r"""A :class:`~.MinMaxObserver` using a running histogram of tensor values
    for min/max updating. Usually used for calibration quantization.

    Args:
        bins: number of bins to use for the histogram.
        upsample_rate: which ratio to interpolate histograms in.
        mode: set quantization mode.
        eps: a lower bound of the scale to avoid the division by zero problem.
        dtype: a string indicating which dtype to collect scale and zero_point of.
    """

    def __init__(
        self,
        bins: int = 2048,
        upsample_rate: int = 128,
        mode: QuantMode = QuantMode.SYMMERTIC,
        eps: float = 0.00001,
        dtype: Union[str, QuantDtypeMeta] = "qint8",
        **kwargs
    ):
        super().__init__(mode, eps, dtype, **kwargs)
        self.bins = bins
        self.upsample_rate = upsample_rate
        # use the resolved dtype meta so a QuantDtypeMeta argument also works
        self.dst_nbins = self.dtype.qmax - self.dtype.qmin + 1
        # the leading -1 marks an uninitialized histogram (see sideeffect_forward)
        self.histogram = Tensor([-1] + [0.0] * (bins - 1), dtype="float32")

    def _non_linear_param_search(self):
        r"""Non-linear parameter search.

        An approximation for L2 error minimization for selecting min/max.
        By selecting new min/max, we filter out outliers in the input distribution.
        """
        np_min_val = self.min_val.numpy()
        np_max_val = self.max_val.numpy()
        np_histogram = self.histogram.numpy()
        assert len(np_histogram) == self.bins, "bins mismatch"
        bin_width = (np_max_val - np_min_val) / self.bins

        def _get_norm(delta_begin, delta_end, density, norm_type):
            r"""Compute the norm of the values uniformly distributed between
            delta_begin and delta_end.

            norm = density * (integral_{begin, end} x^2)
                 = density * (end^3 - begin^3) / 3
            """
            assert norm_type == "L2", "Only L2 norms are currently supported"
            norm = 0.0
            if norm_type == "L2":
                norm = (
                    delta_end * delta_end * delta_end
                    - delta_begin * delta_begin * delta_begin
                ) / 3
            return density * norm

        def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
            r"""Compute the quantization error if we use start_bin to end_bin as
            the min and max to do the quantization.
            """
            norm = 0.0
            dst_bin_width = (
                bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
            )
            if dst_bin_width == 0.0:
                return 0.0
            for src_bin in range(self.bins):
                # distances from the beginning of the first dst_bin to the
                # beginning and end of src_bin
                src_bin_begin = (src_bin - next_start_bin) * bin_width
                src_bin_end = src_bin_begin + bin_width
                # which dst_bins do the beginning and end of src_bin belong to?
                dst_bin_of_begin = min(
                    self.dst_nbins - 1,
                    max(0.0, math.floor(src_bin_begin / dst_bin_width)),
                )
                dst_bin_of_end = min(
                    self.dst_nbins - 1,
                    max(0.0, math.floor(src_bin_end / dst_bin_width)),
                )
                dst_bin_of_begin_center = (
                    dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
                )
                density = np_histogram[src_bin] / bin_width
                if dst_bin_of_begin == dst_bin_of_end:
                    # if src_bin is entirely within 1 dst_bin
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = src_bin_end - dst_bin_of_begin_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
                else:
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = dst_bin_width / 2
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
                    norm = norm + (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
                        -dst_bin_width / 2, dst_bin_width / 2, density, norm_type
                    )
                    dst_bin_of_end_center = (
                        dst_bin_of_end * dst_bin_width + dst_bin_width / 2
                    )
                    delta_begin = -dst_bin_width / 2
                    delta_end = src_bin_end - dst_bin_of_end_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
            return norm

        # cumulative sum
        total = sum(np_histogram)
        cSum = np.cumsum(np_histogram, axis=0)
        stepsize = 1e-5  # granularity
        alpha = 0.0  # lower bound
        beta = 1.0  # upper bound
        start_bin = 0
        end_bin = self.bins - 1
        norm_min = float("inf")
        while alpha < beta:
            # find the next step
            next_alpha = alpha + stepsize
            next_beta = beta - stepsize
            # find the left and right bins between the quantile bounds
            l = start_bin
            r = end_bin
            while l < end_bin and cSum[l] < next_alpha * total:
                l = l + 1
            while r > start_bin and cSum[r] > next_beta * total:
                r = r - 1
            # decide the next move
            next_start_bin = start_bin
            next_end_bin = end_bin
            if (l - start_bin) > (end_bin - r):
                # move the start bin
                next_start_bin = l
                alpha = next_alpha
            else:
                # move the end bin
                next_end_bin = r
                beta = next_beta
            if next_start_bin == start_bin and next_end_bin == end_bin:
                continue
            # calculate the quantization error using next_start_bin and next_end_bin
            norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")
            if norm > norm_min:
                break
            norm_min = norm
            start_bin = next_start_bin
            end_bin = next_end_bin

        new_min = self.min_val + Tensor(bin_width * start_bin, dtype=np.float32)
        new_max = self.min_val + Tensor(bin_width * (end_bin + 1), dtype=np.float32)
        return new_min, new_max

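    # Illustrative note (not part of the original file): the loop above shrinks
    # [start_bin, end_bin] greedily. Each step advances whichever endpoint can
    # skip more bins while giving up at most `stepsize` of the total mass, and
    # the search stops as soon as the simulated L2 quantization error starts to
    # rise, so the returned range keeps the bulk of the distribution while
    # clipping outliers.
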
    def get_qparams(self):
        new_min, new_max = self._non_linear_param_search()
        return self._calculate_qparams(new_min, new_max)

    def _combine_histograms(
        self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins
    ):
        # First up-sample the histogram with new data by a factor of L.
        # This creates an approximate probability density that's piecewise constant.
        upsampled_histogram = new_hist.repeat(upsample_rate)
        # Now insert the upsampled histogram into the output histogram, which is
        # initialized with zeros. The offset at which the histogram is introduced
        # is determined by the start index, as the output histogram can cover a
        # wider range.
        histogram_with_output_range = np.zeros((Nbins * downsample_rate))
        histogram_with_output_range[
            start_idx : Nbins * upsample_rate + start_idx
        ] = upsampled_histogram
        # Compute the integral histogram; double precision is needed to ensure
        # that there are no overflows.
        integral_histogram = np.cumsum(histogram_with_output_range, 0)[
            downsample_rate - 1 :: downsample_rate
        ]
        # Finally perform interpolation.
        shifted_integral_histogram = np.zeros((Nbins))
        shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1]
        interpolated_histogram = (
            integral_histogram - shifted_integral_histogram
        ) / upsample_rate
        orig_hist = orig_hist + interpolated_histogram
        return orig_hist

    def _adjust_min_max(self, combined_min, combined_max, upsample_rate):
        # We ensure that:
        # (combined_max - combined_min) / (downsample_rate * Nbins)
        #     = (max - min) / (upsample_rate * Nbins)
        # This allows us to have a common grid of resolution s, where we can
        # align the input histogram.
        # start_idx maps min_val to the histogram bin index.
        np_min_val = self.min_val.numpy()
        np_max_val = self.max_val.numpy()
        hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate)
        downsample_rate = int(
            np.ceil((combined_max - combined_min) / (self.bins * hist_bin_width))
        )
        e = downsample_rate * (self.bins * hist_bin_width) - (
            combined_max - combined_min
        )
        combined_max = combined_max + e / 2
        combined_min = combined_min - e / 2
        start_idx = int(np.round((np_min_val - combined_min) / hist_bin_width))
        return combined_min, combined_max, downsample_rate, start_idx

    def sideeffect_forward(self, x_orig):
        x = x_orig.numpy()
        min_val = self.min_val.numpy()
        max_val = self.max_val.numpy()
        histogram = self.histogram.numpy()
        new_min = x.min()
        new_max = x.max()
        if histogram[0] == -1:
            new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
        else:
            new_min = min(new_min, min_val)
            new_max = max(new_max, max_val)
            # combine the existing histogram and the new histogram into one.
            # We do this by first upsampling the histogram to a dense grid
            # and then downsampling the histogram efficiently.
            (new_min, new_max, downsample_rate, start_idx) = self._adjust_min_max(
                new_min, new_max, self.upsample_rate
            )
            new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
            new_histogram = new_histogram.astype(np.float64)
            if new_min == min_val and new_max == max_val:
                new_histogram += histogram
            else:
                new_histogram = self._combine_histograms(
                    new_histogram,
                    histogram,
                    self.upsample_rate,
                    downsample_rate,
                    start_idx,
                    self.bins,
                )
        self.histogram = Tensor(new_histogram, dtype="float32")
        self.min_val = Tensor(new_min, dtype="float32")
        self.max_val = Tensor(new_max, dtype="float32")

    def forward(self, x_orig):
        self.sideeffect_forward(x_orig)
        return x_orig

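# Illustrative usage (not part of the original file): a minimal calibration
# sketch; `calibration_batches` is a hypothetical iterable of numpy arrays and
# the import path is assumed as in the MinMaxObserver sketch above.
#
#   import megengine as mge
#   from megengine.quantization.observer import HistogramObserver
#
#   obs = HistogramObserver(bins=2048)
#   for batch in calibration_batches:
#       _ = obs(mge.Tensor(batch))   # sideeffect_forward updates the histogram
#   qparams = obs.get_qparams()      # runs the non-linear min/max search
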
class PassiveObserver(Observer):
    r"""An Observer that supports setting :attr:`scale` directly."""

    def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
        super().__init__(dtype, **kwargs)
        self.qparams = None
        self.orig_scale = None

    @property
    def scale(self):
        return self.qparams.scale

    @scale.setter
    def scale(self, value: np.ndarray):
        assert np.all(value > 0)
        self.qparams.scale[...] = Tensor(value)

    def get_qparams(self):
        return self.qparams

    def set_qparams(self, qparams: QParams):
        r"""Set the ``qparams``.

        Args:
            qparams: used to set the initial scale.
        """
        self.qparams = deepcopy(qparams)
        if qparams.scale is None:
            raise AssertionError("Can not get an initialized scale")
        if qparams.dtype_meta is None:
            qparams.dtype_meta = self.dtype
        else:
            assert (
                qparams.dtype_meta is self.dtype
            ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
                qparams.dtype_meta, self.dtype
            )
        self.orig_scale = qparams.scale.numpy()

    def forward(self, x):
        r"""Just return the input because :attr:`qparams` is set by :func:`~.apply_easy_quant`."""
        return x
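
# Illustrative usage (not part of the original file): a minimal sketch of
# seeding a PassiveObserver with precomputed qparams; it assumes megengine is
# imported as mge and that create_qparams accepts the same arguments used in
# MinMaxObserver._calculate_qparams above.
#
#   from megengine.quantization.observer import PassiveObserver
#   from megengine.quantization.utils import QuantMode, create_qparams
#
#   obs = PassiveObserver("qint8")
#   obs.set_qparams(
#       create_qparams(QuantMode.SYMMERTIC, obs.dtype, scale=mge.Tensor(0.1))
#   )
#   obs.scale = 0.05  # the property setter overwrites the stored scale in place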

The MegEngine installation package bundles the CUDA environment needed to run code on GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU device and that the driver is installed. If you would like to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.