
observer.py 20 kB

import math
from abc import abstractmethod
from copy import deepcopy
from typing import Union

import numpy as np

from .. import functional as F
from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
from ..distributed import WORLD, get_rank, is_distributed
from ..functional.distributed import all_reduce_max, all_reduce_min
from ..logger import get_logger
from ..module import Module
from ..tensor import Tensor
from .utils import QParams, QParamsModuleMixin, QuantMode, create_qparams

logger = get_logger(__name__)


class Observer(Module, QParamsModuleMixin):
    r"""A base class for Observer modules, used to record an input tensor's
    statistics for quantization.

    Args:
        dtype: a string indicating which dtype to collect scale and zero_point of.
    """

    def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
        super().__init__()
        if isinstance(dtype, str):
            if dtype not in _builtin_quant_dtypes:
                raise ValueError(
                    "unknown dtype: {}, only support {}".format(
                        dtype, _builtin_quant_dtypes.keys()
                    )
                )
            dtype = _builtin_quant_dtypes[dtype]
        if "narrow_range" in kwargs:
            del kwargs["narrow_range"]
            logger.warning(
                "FakeQuantize currently has no narrow_range param "
                "so it is ignored here",
                exc_info=DeprecationWarning,
            )
        self.dtype = dtype
        self.qmin = dtype.qmin
        self.qmax = dtype.qmax
        self.enabled = True

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False

    def train(self, mode: bool = True, recursive: bool = True) -> None:
        super().train(mode, recursive)
        if mode:
            self.enable()
        else:
            self.disable()

    @abstractmethod
    def forward(self, x):
        pass
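

# A minimal sketch (not part of this file) of how the enable/disable switches
# interact with train()/eval(): any Observer subclass records statistics only
# while ``enabled`` is True, and train() toggles the flag automatically.
#
#     obs = MinMaxObserver(dtype="qint8")   # a concrete subclass, defined below
#     obs.eval()    # calls disable(); forward() becomes a pass-through
#     obs.train()   # calls enable(); forward() records statistics again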


class MinMaxObserver(Observer):
    r"""An Observer module that records the input tensor's running min and max
    values to calculate the scale.

    Args:
        mode: set quantization mode.
        eps: a lower bound for the scale to avoid division by zero.
        dtype: a string indicating which dtype to collect scale and zero_point of.
    """

    def __init__(
        self,
        mode: QuantMode = QuantMode.SYMMERTIC,
        eps: float = 0.00001,
        dtype: Union[str, QuantDtypeMeta] = "qint8",
        **kwargs
    ):
        super().__init__(dtype, **kwargs)
        self.mode = mode
        self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32)
        self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32)
        self.scale_limit = eps

    def _calculate_qparams(self, inp_min_val, inp_max_val):
        min_val = F.minimum(0.0, inp_min_val)
        max_val = F.maximum(0.0, inp_max_val)
        if self.mode == QuantMode.SYMMERTIC:
            symmetric_max_vals = F.maximum(-min_val, max_val)
            # use maximum to avoid the scale being too small at the beginning
            scale = F.maximum(
                symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
            )
            zero_point = None
        else:
            # use maximum to avoid the scale being too small at the beginning
            scale = F.maximum(
                (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit
            )
            # calculate zero_point
            zero_point = self.qmin - F.round((min_val / scale))
        return create_qparams(self.mode, self.dtype, scale=scale, zero_point=zero_point)

    def get_qparams(self):
        return self._calculate_qparams(self.min_val, self.max_val)

    def forward(self, x_orig):
        if self.enabled:
            # stop gradient
            x = x_orig.detach()
            # find max and min
            self.min_val[...] = F.minimum(self.min_val, x.min())
            self.max_val[...] = F.maximum(self.max_val, x.max())
        return x_orig
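

# Usage sketch (an assumption about the typical calibration flow, not code from
# this file): feed activations through the observer, then read back the qparams.
#
#     import numpy as np
#     from megengine import Tensor
#
#     obs = MinMaxObserver(dtype="qint8")
#     for _ in range(10):
#         obs(Tensor(np.random.randn(16, 32).astype("float32")))
#     qparams = obs.get_qparams()  # scale maps [min, max] onto [qmin, qmax]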


class SyncMinMaxObserver(MinMaxObserver):
    r"""A distributed version of :class:`~.MinMaxObserver`.

    Args:
        mode: set quantization mode.
        eps: a lower bound for the scale to avoid division by zero.
        dtype: a string indicating which dtype to collect scale and zero_point of.
    """

    def forward(self, x_orig):
        # check the ``enabled`` flag, not the ``enable`` method, which is always truthy
        if self.enabled:
            x = x_orig.detach()
            if is_distributed():
                min_x = all_reduce_min(x.min(), WORLD)
                max_x = all_reduce_max(x.max(), WORLD)
            else:
                min_x = x.min()
                max_x = x.max()
            self.min_val[...] = F.minimum(self.min_val, min_x)
            self.max_val[...] = F.maximum(self.max_val, max_x)
        return x_orig


class ExponentialMovingAverageObserver(MinMaxObserver):
    r"""A :class:`~.MinMaxObserver` with momentum support for min/max updating.

    Args:
        momentum: momentum ratio for min/max updating.
        mode: set quantization mode.
        eps: a lower bound for the scale to avoid division by zero.
        dtype: a string indicating which dtype to collect scale and zero_point of.
    """

    def __init__(
        self,
        momentum: float = 0.9,
        mode: QuantMode = QuantMode.SYMMERTIC,
        eps: float = 0.00001,
        dtype: Union[str, QuantDtypeMeta] = "qint8",
        **kwargs
    ):
        super().__init__(mode, eps, dtype, **kwargs)
        self.momentum = Tensor(momentum, dtype="float32")
        # used to avoid if-clauses in the first forward, which are not supported
        # in trace mode
        self.runtime_momentum = Tensor(0.0)

    def set_momentum(self, momentum):
        self.momentum = Tensor(momentum, dtype="float32")

    def forward(self, x_orig):
        if self.enabled:
            # stop gradient
            x = x_orig.detach()
            # exponential moving average
            self.min_val[...] = (
                self.min_val * self.runtime_momentum
                + (1 - self.runtime_momentum) * x.min()
            )
            self.max_val[...] = (
                self.max_val * self.runtime_momentum
                + (1 - self.runtime_momentum) * x.max()
            )
            self.runtime_momentum[...] = self.momentum
        return x_orig
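

# Sketch of the EMA behavior (assuming the same calibration loop as above):
# the first forward uses runtime_momentum == 0, so min/max come directly from
# the first batch; every later batch contributes a (1 - momentum) fraction.
#
#     obs = ExponentialMovingAverageObserver(momentum=0.99, dtype="qint8")
#     obs.set_momentum(0.9)  # momentum may be adjusted between steps
#     obs(Tensor(np.random.randn(16, 32).astype("float32")))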


class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
    r"""A distributed version of :class:`~.ExponentialMovingAverageObserver`.

    Args:
        momentum: momentum ratio for min/max updating.
        mode: set quantization mode.
        eps: a lower bound for the scale to avoid division by zero.
        dtype: a string indicating which dtype to collect scale and zero_point of.
    """

    def forward(self, x_orig):
        if self.enabled:
            x = x_orig.detach()
            if is_distributed():
                min_x = all_reduce_min(x.min(), WORLD)
                max_x = all_reduce_max(x.max(), WORLD)
            else:
                min_x = x.min()
                max_x = x.max()
            self.min_val[...] = (
                self.min_val * self.runtime_momentum
                + (1 - self.runtime_momentum) * min_x
            )
            self.max_val[...] = (
                self.max_val * self.runtime_momentum
                + (1 - self.runtime_momentum) * max_x
            )
            self.runtime_momentum[...] = self.momentum
        return x_orig
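

# Distributed usage sketch (an assumption, not from this file): the Sync*
# variants are meant to run inside a data-parallel worker, e.g. one started
# with megengine.distributed.launcher, where all_reduce_min/all_reduce_max
# aggregate per-rank statistics so every rank derives identical qparams.
#
#     import megengine.distributed as dist
#
#     @dist.launcher
#     def worker():
#         obs = SyncMinMaxObserver(dtype="qint8")
#         obs(local_batch)  # ``local_batch`` is hypothetical per-rank data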


class HistogramObserver(MinMaxObserver):
    r"""A :class:`~.MinMaxObserver` that uses a running histogram of tensor values
    for min/max updating. Usually used for calibration quantization.

    Args:
        bins: number of bins to use for the histogram.
        upsample_rate: ratio at which to upsample the histogram when interpolating.
        mode: set quantization mode.
        eps: a lower bound for the scale to avoid division by zero.
        dtype: a string indicating which dtype to collect scale and zero_point of.
    """

    def __init__(
        self,
        bins: int = 2048,
        upsample_rate: int = 128,
        mode: QuantMode = QuantMode.SYMMERTIC,
        eps: float = 0.00001,
        dtype: Union[str, QuantDtypeMeta] = "qint8",
        **kwargs
    ):
        super().__init__(mode, eps, dtype, **kwargs)
        self.bins = bins
        self.upsample_rate = upsample_rate
        # use the qmin/qmax resolved by the base class, so both str and
        # QuantDtypeMeta dtypes work here
        self.dst_nbins = self.qmax - self.qmin + 1
        # histogram[0] == -1 marks an uninitialized histogram
        self.histogram = Tensor([-1] + [0.0] * (bins - 1), dtype="float32")

    def _non_linear_param_search(self):
        r"""Non-linear parameter search.

        An approximation of L2 error minimization for selecting min/max.
        By selecting new min/max, we filter out outliers in the input distribution.
        """
        np_min_val = self.min_val.numpy()
        np_max_val = self.max_val.numpy()
        np_histogram = self.histogram.numpy()
        assert len(np_histogram) == self.bins, "bins mismatch"
        bin_width = (np_max_val - np_min_val) / self.bins

        def _get_norm(delta_begin, delta_end, density, norm_type):
            r"""Compute the norm of the values uniformly distributed between
            delta_begin and delta_end.

            norm = density * (integral_{begin, end} x^2)
                 = density * (end^3 - begin^3) / 3
            """
            assert norm_type == "L2", "Only L2 norms are currently supported"
            norm = 0.0
            if norm_type == "L2":
                norm = (
                    delta_end * delta_end * delta_end
                    - delta_begin * delta_begin * delta_begin
                ) / 3
            return density * norm

        def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
            r"""Compute the quantization error if we use start_bin to end_bin as the
            min and max to do the quantization.
            """
            norm = 0.0
            dst_bin_width = (
                bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
            )
            if dst_bin_width == 0.0:
                return 0.0
            for src_bin in range(self.bins):
                # distances from the beginning of the first dst_bin to the
                # beginning and end of src_bin
                src_bin_begin = (src_bin - next_start_bin) * bin_width
                src_bin_end = src_bin_begin + bin_width
                # which dst_bins do the beginning and end of src_bin belong to?
                dst_bin_of_begin = min(
                    self.dst_nbins - 1,
                    max(0.0, math.floor(src_bin_begin / dst_bin_width)),
                )
                dst_bin_of_end = min(
                    self.dst_nbins - 1,
                    max(0.0, math.floor(src_bin_end / dst_bin_width)),
                )
                dst_bin_of_begin_center = (
                    dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
                )
                density = np_histogram[src_bin] / bin_width
                if dst_bin_of_begin == dst_bin_of_end:
                    # src_bin is entirely within one dst_bin
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = src_bin_end - dst_bin_of_begin_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
                else:
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = dst_bin_width / 2
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
                    norm = norm + (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
                        -dst_bin_width / 2, dst_bin_width / 2, density, norm_type
                    )
                    dst_bin_of_end_center = (
                        dst_bin_of_end * dst_bin_width + dst_bin_width / 2
                    )
                    delta_begin = -dst_bin_width / 2
                    delta_end = src_bin_end - dst_bin_of_end_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
            return norm

        # cumulative sum
        total = sum(np_histogram)
        cSum = np.cumsum(np_histogram, axis=0)
        stepsize = 1e-5  # granularity
        alpha = 0.0  # lower bound
        beta = 1.0  # upper bound
        start_bin = 0
        end_bin = self.bins - 1
        norm_min = float("inf")
        while alpha < beta:
            # find the next step
            next_alpha = alpha + stepsize
            next_beta = beta - stepsize
            # find the left and right bins between the quantile bounds
            l = start_bin
            r = end_bin
            while l < end_bin and cSum[l] < next_alpha * total:
                l = l + 1
            while r > start_bin and cSum[r] > next_beta * total:
                r = r - 1
            # decide the next move
            next_start_bin = start_bin
            next_end_bin = end_bin
            if (l - start_bin) > (end_bin - r):
                # move the start bin
                next_start_bin = l
                alpha = next_alpha
            else:
                # move the end bin
                next_end_bin = r
                beta = next_beta
            if next_start_bin == start_bin and next_end_bin == end_bin:
                continue
            # calculate the quantization error using next_start_bin and next_end_bin
            norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")
            if norm > norm_min:
                break
            norm_min = norm
            start_bin = next_start_bin
            end_bin = next_end_bin
        new_min = self.min_val + Tensor(bin_width * start_bin, dtype=np.float32)
        new_max = self.min_val + Tensor(bin_width * (end_bin + 1), dtype=np.float32)
        return new_min, new_max

    def get_qparams(self):
        new_min, new_max = self._non_linear_param_search()
        return self._calculate_qparams(new_min, new_max)

    def _combine_histograms(
        self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins
    ):
        # First upsample the histogram with the new data by a factor of L.
        # This creates an approximate probability density that is piecewise constant.
        upsampled_histogram = new_hist.repeat(upsample_rate)
        # Now insert the upsampled histogram into the output histogram, which is
        # initialized with zeros. The offset at which the histogram is introduced
        # is determined by the start index, as the output histogram can cover a
        # wider range.
        histogram_with_output_range = np.zeros((Nbins * downsample_rate))
        histogram_with_output_range[
            start_idx : Nbins * upsample_rate + start_idx
        ] = upsampled_histogram
        # Compute the integral histogram; double precision is needed to ensure
        # that there are no overflows.
        integral_histogram = np.cumsum(histogram_with_output_range, 0)[
            downsample_rate - 1 :: downsample_rate
        ]
        # Finally perform interpolation.
        shifted_integral_histogram = np.zeros((Nbins))
        shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1]
        interpolated_histogram = (
            integral_histogram - shifted_integral_histogram
        ) / upsample_rate
        orig_hist = orig_hist + interpolated_histogram
        return orig_hist

    def _adjust_min_max(self, combined_min, combined_max, upsample_rate):
        # We ensure that:
        # (combined_max - combined_min) / (downsample_rate * Nbins)
        #     = (max - min) / (upsample_rate * Nbins)
        # This allows us to have a common grid of resolution s, where we can
        # align the input histogram.
        # start_idx maps min_val to the histogram bin index.
        np_min_val = self.min_val.numpy()
        np_max_val = self.max_val.numpy()
        hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate)
        downsample_rate = int(
            np.ceil((combined_max - combined_min) / (self.bins * hist_bin_width))
        )
        e = downsample_rate * (self.bins * hist_bin_width) - (
            combined_max - combined_min
        )
        combined_max = combined_max + e / 2
        combined_min = combined_min - e / 2
        start_idx = int(np.round((np_min_val - combined_min) / hist_bin_width))
        return combined_min, combined_max, downsample_rate, start_idx

    def sideeffect_forward(self, x_orig):
        x = x_orig.numpy()
        min_val = self.min_val.numpy()
        max_val = self.max_val.numpy()
        histogram = self.histogram.numpy()
        new_min = x.min()
        new_max = x.max()
        if histogram[0] == -1:
            new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
        else:
            new_min = min(new_min, min_val)
            new_max = max(new_max, max_val)
            # Combine the existing histogram and the new histogram into one.
            # We do this by first upsampling the histogram to a dense grid and
            # then downsampling it efficiently.
            (new_min, new_max, downsample_rate, start_idx) = self._adjust_min_max(
                new_min, new_max, self.upsample_rate
            )
            new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
            new_histogram = new_histogram.astype(np.float64)
            if new_min == min_val and new_max == max_val:
                new_histogram += histogram
            else:
                new_histogram = self._combine_histograms(
                    new_histogram,
                    histogram,
                    self.upsample_rate,
                    downsample_rate,
                    start_idx,
                    self.bins,
                )
        self.histogram = Tensor(new_histogram, dtype="float32")
        self.min_val = Tensor(new_min, dtype="float32")
        self.max_val = Tensor(new_max, dtype="float32")

    def forward(self, x_orig):
        self.sideeffect_forward(x_orig)
        return x_orig
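

# Calibration sketch for HistogramObserver (an assumption about typical use):
# run representative data through the observer to accumulate the histogram;
# get_qparams() then performs the L2 min/max search before computing the scale.
#
#     obs = HistogramObserver(bins=2048, dtype="qint8")
#     for _ in range(100):
#         obs(Tensor(np.random.randn(16, 32).astype("float32")))
#     qparams = obs.get_qparams()  # runs _non_linear_param_search internally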


class PassiveObserver(Observer):
    r"""An Observer that supports setting :attr:`scale` directly."""

    def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
        super().__init__(dtype, **kwargs)
        self.qparams = None
        self.orig_scale = None

    @property
    def scale(self):
        return self.qparams.scale

    @scale.setter
    def scale(self, value: np.ndarray):
        assert np.all(value > 0)
        self.qparams.scale[...] = Tensor(value)

    def get_qparams(self):
        return self.qparams

    def set_qparams(self, qparams: QParams):
        r"""Set the ``qparams``.

        Args:
            qparams: used to set the initial scale.
        """
        self.qparams = deepcopy(qparams)
        if qparams.scale is None:
            raise AssertionError("Cannot get an initialized scale")
        if qparams.dtype_meta is None:
            qparams.dtype_meta = self.dtype
        else:
            assert (
                qparams.dtype_meta is self.dtype
            ), "input qparams' dtype is not equal to self.dtype\nqparams.dtype_meta={}\nself.dtype={}".format(
                qparams.dtype_meta, self.dtype
            )
        self.orig_scale = qparams.scale.numpy()

    def forward(self, x):
        r"""Just return the input, because :attr:`qparams` is set by
        :func:`~.apply_easy_quant`."""
        return x
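

# Sketch (an assumption about intended use): PassiveObserver never updates
# statistics itself; its qparams are injected, e.g. copied from another
# observer or tuned externally, and the scale can then be overridden in place.
#
#     src = MinMaxObserver(dtype="qint8")
#     src(Tensor(np.random.randn(16, 32).astype("float32")))
#     passive = PassiveObserver(dtype="qint8")
#     passive.set_qparams(src.get_qparams())
#     passive.scale = np.array(0.05, dtype="float32")  # direct scale override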