You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

observer.py 20 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import math
  9. from abc import abstractmethod
  10. from copy import deepcopy
  11. from typing import Union
  12. import numpy as np
  13. from .. import functional as F
  14. from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
  15. from ..distributed import WORLD, get_rank, is_distributed
  16. from ..functional.distributed import all_reduce_max, all_reduce_min
  17. from ..logger import get_logger
  18. from ..module import Module
  19. from ..tensor import Tensor
  20. from .utils import QParams, QParamsModuleMixin, QuantMode, create_qparams
# Module-level logger for this file; Observer.__init__ uses it to warn when an
# unsupported ``narrow_range`` keyword argument is dropped.
logger = get_logger(__name__)
  22. class Observer(Module, QParamsModuleMixin):
  23. r"""
  24. A base class for Observer Module. Used to record input tensor's statistics for
  25. quantization.
  26. :param dtype: a string indicating which dtype to collect scale and zero_point of.
  27. """
  28. def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
  29. super().__init__()
  30. if isinstance(dtype, str):
  31. if not dtype in _builtin_quant_dtypes:
  32. raise ValueError(
  33. "unknown dtype: {}, only support {}".format(
  34. dtype, _builtin_quant_dtypes.keys()
  35. )
  36. )
  37. dtype = _builtin_quant_dtypes[dtype]
  38. if "narrow_range" in kwargs:
  39. del kwargs["narrow_range"]
  40. logger.warning(
  41. "FakeQuantize currently has no narrow_range param "
  42. "so it is ignored here",
  43. exc_info=DeprecationWarning,
  44. )
  45. self.dtype = dtype
  46. self.qmin = dtype.qmin
  47. self.qmax = dtype.qmax
  48. self.enabled = True
  49. def enable(self):
  50. self.enabled = True
  51. def disable(self):
  52. self.enabled = False
  53. def train(self, mode: bool = True, recursive: bool = True) -> None:
  54. super().train(mode, recursive)
  55. if mode:
  56. self.enable()
  57. else:
  58. self.disable()
  59. @abstractmethod
  60. def forward(self, x):
  61. pass
  62. class MinMaxObserver(Observer):
  63. r"""
  64. A Observer Module records input tensor's running min and max values to calc scale.
  65. :param mode: set quantization mode.
  66. :param eps: a initial maximum value to avoid division by zero problem.
  67. :param dtype: a string indicating which dtype to collect scale and zero_point of.
  68. """
  69. def __init__(
  70. self,
  71. mode: QuantMode = QuantMode.SYMMERTIC,
  72. eps: float = 0.00001,
  73. dtype: Union[str, QuantDtypeMeta] = "qint8",
  74. **kwargs
  75. ):
  76. super().__init__(dtype, **kwargs)
  77. self.mode = mode
  78. self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32)
  79. self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32)
  80. self.scale_limit = eps
  81. def _calculate_qparams(self, inp_min_val, inp_max_val):
  82. min_val = F.minimum(0.0, inp_min_val)
  83. max_val = F.maximum(0.0, inp_max_val)
  84. if self.mode == QuantMode.SYMMERTIC:
  85. symmetric_max_vals = F.maximum(-min_val, max_val)
  86. # use maximun to avoid scale too small at the begin
  87. scale = F.maximum(
  88. symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
  89. )
  90. zero_point = None
  91. else:
  92. # use maximun to avoid scale too small at the begin
  93. scale = F.maximum(
  94. (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit
  95. )
  96. # caculate zero_point
  97. zero_point = self.qmin - F.round((min_val / scale))
  98. return create_qparams(self.mode, self.dtype, scale=scale, zero_point=zero_point)
  99. def get_qparams(self):
  100. return self._calculate_qparams(self.min_val, self.max_val)
  101. def forward(self, x_orig):
  102. if self.enabled:
  103. # stop gradient
  104. x = x_orig.detach()
  105. # find max and min
  106. self.min_val[...] = F.minimum(self.min_val, x.min())
  107. self.max_val[...] = F.maximum(self.max_val, x.max())
  108. return x_orig
  109. class SyncMinMaxObserver(MinMaxObserver):
  110. r"""
  111. A distributed version of :class:`~.MinMaxObserver`.
  112. :param mode: set quantization mode.
  113. :param eps: a initial maximum value to avoid division by zero problem.
  114. :param dtype: a string indicating which dtype to collect scale and zero_point of.
  115. """
  116. def forward(self, x_orig):
  117. if self.enable:
  118. x = x_orig.detach()
  119. if is_distributed():
  120. min_x = all_reduce_min(x.min(), WORLD)
  121. max_x = all_reduce_max(x.max(), WORLD)
  122. else:
  123. min_x = x.min()
  124. max_x = x.max()
  125. self.min_val[...] = F.minimum(self.min_val, min_x)
  126. self.max_val[...] = F.maximum(self.max_val, max_x)
  127. return x_orig
  128. class ExponentialMovingAverageObserver(MinMaxObserver):
  129. r"""
  130. A :class:`~.MinMaxObserver` with momentum support for min/max updating.
  131. :param momentum: momentum ratio for min/max updating.
  132. :param mode: set quantization mode.
  133. :param eps: a initial maximum value to avoid division by zero problem.
  134. :param dtype: a string indicating which dtype to collect scale and zero_point of.
  135. """
  136. def __init__(
  137. self,
  138. momentum: float = 0.9,
  139. mode: QuantMode = QuantMode.SYMMERTIC,
  140. eps: float = 0.00001,
  141. dtype: Union[str, QuantDtypeMeta] = "qint8",
  142. **kwargs
  143. ):
  144. super().__init__(mode, eps, dtype, **kwargs)
  145. self.momentum = Tensor(momentum, dtype="float32")
  146. # used to avoid if-clauses in the first forward which is not supported
  147. # in trace mode.
  148. self.runtime_momentum = Tensor(0.0)
  149. def set_momentum(self, momentum):
  150. self.momentum = Tensor(momentum, dtype="float32")
  151. def forward(self, x_orig):
  152. if self.enabled:
  153. # stop gradient
  154. x = x_orig.detach()
  155. # Exponential Moving Average
  156. self.min_val[...] = (
  157. self.min_val * self.runtime_momentum
  158. + (1 - self.runtime_momentum) * x.min()
  159. )
  160. self.max_val[...] = (
  161. self.max_val * self.runtime_momentum
  162. + (1 - self.runtime_momentum) * x.max()
  163. )
  164. self.runtime_momentum[...] = self.momentum
  165. return x_orig
  166. class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
  167. r"""
  168. A distributed version of :class:`~.ExponentialMovingAverageObserver`.
  169. :param momentum: momentum ratio for min/max updating.
  170. :param mode: set quantization mode.
  171. :param eps: a initial maximum value to avoid division by zero problem.
  172. :param dtype: a string indicating which dtype to collect scale and zero_point of.
  173. """
  174. def forward(self, x_orig):
  175. if self.enabled:
  176. x = x_orig.detach()
  177. if is_distributed:
  178. min_x = all_reduce_min(x.min(), WORLD)
  179. max_x = all_reduce_max(x.max(), WORLD)
  180. else:
  181. min_x = x.min()
  182. max_x = x.max()
  183. self.min_val[...] = (
  184. self.min_val * self.runtime_momentum
  185. + (1 - self.runtime_momentum) * min_x
  186. )
  187. self.max_val[...] = (
  188. self.max_val * self.runtime_momentum
  189. + (1 - self.runtime_momentum) * max_x
  190. )
  191. self.runtime_momentum[...] = self.momentum
  192. return x_orig
  193. class HistogramObserver(MinMaxObserver):
  194. r"""
  195. A :class:`~.MinMaxObserver` using running histogram of tensor values
  196. for min/max updating. Usually used for calibration quantization.
  197. :param bins: number of bins to use for the histogram.
  198. :param upsample_rate: which ratio to interpolate histograms in.
  199. :param mode: set quantization mode.
  200. :param eps: a initial maximum value to avoid division by zero problem.
  201. :param dtype: a string indicating which dtype to collect scale and zero_point of.
  202. """
  203. def __init__(
  204. self,
  205. bins: int = 2048,
  206. upsample_rate: int = 128,
  207. mode: QuantMode = QuantMode.SYMMERTIC,
  208. eps: float = 0.00001,
  209. dtype: Union[str, QuantDtypeMeta] = "qint8",
  210. **kwargs
  211. ):
  212. super().__init__(mode, eps, dtype, **kwargs)
  213. self.bins = bins
  214. self.upsample_rate = upsample_rate
  215. self.dst_nbins = (
  216. _builtin_quant_dtypes[dtype].qmax - _builtin_quant_dtypes[dtype].qmin + 1
  217. )
  218. self.histogram = Tensor([-1] + [0.0] * (bins - 1), dtype="float32")
  219. def _non_linear_param_search(self):
  220. r"""
  221. Non-linear parameter search.
  222. An approximation for L2 error minimization for selecting min/max.
  223. By selecting new min/max, we filter out outliers in input distribution.
  224. """
  225. np_min_val = self.min_val.numpy()
  226. np_max_val = self.max_val.numpy()
  227. np_histogram = self.histogram.numpy()
  228. assert len(np_histogram) == self.bins, "bins mistmatch"
  229. bin_width = (np_max_val - np_min_val) / self.bins
  230. def _get_norm(delta_begin, delta_end, density, norm_type):
  231. r"""
  232. Compute the norm of the values uniformaly distributed between
  233. delta_begin and delta_end.
  234. norm = density * (integral_{begin, end} x^2)
  235. = density * (end^3 - begin^3) / 3
  236. """
  237. assert norm_type == "L2", "Only L2 norms are currently supported"
  238. norm = 0.0
  239. if norm_type == "L2":
  240. norm = (
  241. delta_end * delta_end * delta_end
  242. - delta_begin * delta_begin * delta_begin
  243. ) / 3
  244. return density * norm
  245. def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
  246. r"""
  247. Compute the quantization error if we use start_bin to end_bin as the
  248. min and max to do the quantization.
  249. """
  250. norm = 0.0
  251. dst_bin_width = (
  252. bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
  253. )
  254. if dst_bin_width == 0.0:
  255. return 0.0
  256. for src_bin in range(self.bins):
  257. # distances from the beginning of first dst_bin to the beginning and
  258. # end of src_bin
  259. src_bin_begin = (src_bin - next_start_bin) * bin_width
  260. src_bin_end = src_bin_begin + bin_width
  261. # which dst_bins the beginning and end of src_bin belong to?
  262. dst_bin_of_begin = min(
  263. self.dst_nbins - 1,
  264. max(0.0, math.floor(src_bin_begin / dst_bin_width)),
  265. )
  266. dst_bin_of_end = min(
  267. self.dst_nbins - 1,
  268. max(0.0, math.floor(src_bin_end / dst_bin_width)),
  269. )
  270. dst_bin_of_begin_center = (
  271. dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
  272. )
  273. density = np_histogram[src_bin] / bin_width
  274. if dst_bin_of_begin == dst_bin_of_end:
  275. # if src_bin is entirely within 1 dst_bin
  276. delta_begin = src_bin_begin - dst_bin_of_begin_center
  277. delta_end = src_bin_end - dst_bin_of_begin_center
  278. norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
  279. else:
  280. delta_begin = src_bin_begin - dst_bin_of_begin_center
  281. delta_end = dst_bin_width / 2
  282. norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
  283. norm = norm + (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
  284. -dst_bin_width / 2, dst_bin_width / 2, density, norm_type
  285. )
  286. dst_bin_of_end_center = (
  287. dst_bin_of_end * dst_bin_width + dst_bin_width / 2
  288. )
  289. delta_begin = -dst_bin_width / 2
  290. delta_end = src_bin_end - dst_bin_of_end_center
  291. norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
  292. return norm
  293. # cumulative sum
  294. total = sum(np_histogram)
  295. cSum = np.cumsum(np_histogram, axis=0)
  296. stepsize = 1e-5 # granularity
  297. alpha = 0.0 # lower bound
  298. beta = 1.0 # upper bound
  299. start_bin = 0
  300. end_bin = self.bins - 1
  301. norm_min = float("inf")
  302. while alpha < beta:
  303. # Find the next step
  304. next_alpha = alpha + stepsize
  305. next_beta = beta - stepsize
  306. # find the left and right bins between the quantile bounds
  307. l = start_bin
  308. r = end_bin
  309. while l < end_bin and cSum[l] < next_alpha * total:
  310. l = l + 1
  311. while r > start_bin and cSum[r] > next_beta * total:
  312. r = r - 1
  313. # decide the next move
  314. next_start_bin = start_bin
  315. next_end_bin = end_bin
  316. if (l - start_bin) > (end_bin - r):
  317. # move the start bin
  318. next_start_bin = l
  319. alpha = next_alpha
  320. else:
  321. # move the end bin
  322. next_end_bin = r
  323. beta = next_beta
  324. if next_start_bin == start_bin and next_end_bin == end_bin:
  325. continue
  326. # calculate the quantization error using next_start_bin and next_end_bin
  327. norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")
  328. if norm > norm_min:
  329. break
  330. norm_min = norm
  331. start_bin = next_start_bin
  332. end_bin = next_end_bin
  333. new_min = self.min_val + Tensor(bin_width * start_bin, dtype=np.float32)
  334. new_max = self.min_val + Tensor(bin_width * (end_bin + 1), dtype=np.float32)
  335. return new_min, new_max
  336. def get_qparams(self):
  337. new_min, new_max = self._non_linear_param_search()
  338. return self._calculate_qparams(new_min, new_max)
  339. def _combine_histograms(
  340. self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins
  341. ):
  342. # First up-sample the histogram with new data by a factor of L
  343. # This creates an approximate probability density thats piecwise constant
  344. upsampled_histogram = new_hist.repeat(upsample_rate)
  345. # Now insert the upsampled histogram into the output
  346. # histogram, which is initialized with zeros.
  347. # The offset at which the histogram is introduced is determined
  348. # by the start index as the output histogram can cover a wider range
  349. histogram_with_output_range = np.zeros((Nbins * downsample_rate))
  350. histogram_with_output_range[
  351. start_idx : Nbins * upsample_rate + start_idx
  352. ] = upsampled_histogram
  353. # Compute integral histogram, double precision is needed to ensure
  354. # that there are no overflows
  355. integral_histogram = np.cumsum(histogram_with_output_range, 0)[
  356. downsample_rate - 1 :: downsample_rate
  357. ]
  358. # Finally perform interpolation
  359. shifted_integral_histogram = np.zeros((Nbins))
  360. shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1]
  361. interpolated_histogram = (
  362. integral_histogram - shifted_integral_histogram
  363. ) / upsample_rate
  364. orig_hist = orig_hist + interpolated_histogram
  365. return orig_hist
  366. def _adjust_min_max(self, combined_min, combined_max, upsample_rate):
  367. # We ensure that:
  368. # (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
  369. # This allows us to have a common grid of resolution s, where we can align
  370. # the input histogram
  371. # start_idx maps min_val to the histogram bin index.
  372. np_min_val = self.min_val.numpy()
  373. np_max_val = self.max_val.numpy()
  374. hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate)
  375. downsample_rate = int(
  376. np.ceil((combined_max - combined_min) / (self.bins * hist_bin_width))
  377. )
  378. e = downsample_rate * (self.bins * hist_bin_width) - (
  379. combined_max - combined_min
  380. )
  381. combined_max = combined_max + e / 2
  382. combined_min = combined_min - e / 2
  383. start_idx = int(np.round((np_min_val - combined_min) / hist_bin_width))
  384. return combined_min, combined_max, downsample_rate, start_idx
  385. def sideeffect_forward(self, x_orig):
  386. x = x_orig.numpy()
  387. min_val = self.min_val.numpy()
  388. max_val = self.max_val.numpy()
  389. histogram = self.histogram.numpy()
  390. new_min = x.min()
  391. new_max = x.max()
  392. if histogram[0] == -1:
  393. new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
  394. else:
  395. new_min = min(new_min, min_val)
  396. new_max = max(new_max, max_val)
  397. # combine the existing histogram and new histogram into 1 histogram
  398. # We do this by first upsampling the histogram to a dense grid
  399. # and then downsampling the histogram efficiently
  400. (new_min, new_max, downsample_rate, start_idx) = self._adjust_min_max(
  401. new_min, new_max, self.upsample_rate
  402. )
  403. new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
  404. new_histogram = new_histogram.astype(np.float64)
  405. if new_min == min_val and new_max == max_val:
  406. new_histogram += histogram
  407. else:
  408. new_histogram = self._combine_histograms(
  409. new_histogram,
  410. histogram,
  411. self.upsample_rate,
  412. downsample_rate,
  413. start_idx,
  414. self.bins,
  415. )
  416. self.histogram = Tensor(new_histogram, dtype="float32")
  417. self.min_val = Tensor(new_min, dtype="float32")
  418. self.max_val = Tensor(new_max, dtype="float32")
  419. def forward(self, x_orig):
  420. self.sideeffect_forward(x_orig)
  421. return x_orig
  422. class PassiveObserver(Observer):
  423. r"""
  424. An Observer that supports setting :attr:`scale` directly.
  425. """
  426. def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
  427. super().__init__(dtype, **kwargs)
  428. self.qparams = None
  429. self.orig_scale = None
  430. @property
  431. def scale(self):
  432. return self.qparams.scale
  433. @scale.setter
  434. def scale(self, value: np.ndarray):
  435. assert np.all(value > 0)
  436. self.qparams.scale[...] = Tensor(value)
  437. def get_qparams(self):
  438. return self.qparams
  439. def set_qparams(self, qparams: QParams):
  440. """
  441. :param qparams: used to set initial scale.
  442. """
  443. self.qparams = deepcopy(qparams)
  444. if qparams.scale is None:
  445. raise AssertionError("Can not get an initialized scale")
  446. if qparams.dtype_meta is None:
  447. qparams.dtype_meta = self.dtype
  448. else:
  449. assert (
  450. qparams.dtype_meta is self.dtype
  451. ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
  452. qparams.dtype_meta, self.dtype
  453. )
  454. self.orig_scale = qparams.scale.numpy()
  455. def forward(self, x):
  456. r"""
  457. Just return input because :attr:`qparams` is set by :func:`~.apply_easy_quant`.
  458. """
  459. return x

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。