You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

observer.py 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import math
  9. from abc import abstractmethod
  10. from copy import deepcopy
  11. from typing import Union
  12. import numpy as np
  13. from .. import functional as F
  14. from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
  15. from ..distributed import WORLD, get_rank, is_distributed
  16. from ..functional.distributed import all_reduce_max, all_reduce_min
  17. from ..logger import get_logger
  18. from ..module import Module
  19. from ..tensor import Tensor
  20. from .utils import QParams, QParamsModuleMixin, QuantMode, create_qparams
# Module-level logger; Observer.__init__ uses it to warn that ``narrow_range``
# is ignored.
logger = get_logger(__name__)
  22. class Observer(Module, QParamsModuleMixin):
  23. r"""
  24. A base class for Observer Module.
  25. :param dtype: a string indicating to collect scale and zero_point of which dtype.
  26. """
  27. def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
  28. super().__init__()
  29. if isinstance(dtype, str):
  30. if not dtype in _builtin_quant_dtypes:
  31. raise ValueError(
  32. "unknown dtype: {}, only support {}".format(
  33. dtype, _builtin_quant_dtypes.keys()
  34. )
  35. )
  36. dtype = _builtin_quant_dtypes[dtype]
  37. if "narrow_range" in kwargs:
  38. del kwargs["narrow_range"]
  39. logger.warning(
  40. "FakeQuantize currently has no narrow_range param "
  41. "so it is ignored here",
  42. exc_info=DeprecationWarning,
  43. )
  44. self.dtype = dtype
  45. self.qmin = dtype.qmin
  46. self.qmax = dtype.qmax
  47. self.enabled = True
  48. def enable(self):
  49. self.enabled = True
  50. def disable(self):
  51. self.enabled = False
  52. def train(self, mode: bool = True, recursive: bool = True) -> None:
  53. super().train(mode, recursive)
  54. if mode:
  55. self.enable()
  56. else:
  57. self.disable()
  58. @abstractmethod
  59. def forward(self, x):
  60. pass
  61. class MinMaxObserver(Observer):
  62. def __init__(
  63. self,
  64. mode: QuantMode = QuantMode.SYMMERTIC,
  65. eps: float = 0.00001,
  66. dtype: Union[str, QuantDtypeMeta] = "qint8",
  67. **kwargs
  68. ):
  69. super().__init__(dtype, **kwargs)
  70. self.mode = mode
  71. self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32)
  72. self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32)
  73. self.scale_limit = eps
  74. def _calculate_qparams(self, inp_min_val, inp_max_val):
  75. min_val = F.minimum(0.0, inp_min_val)
  76. max_val = F.maximum(0.0, inp_max_val)
  77. if self.mode == QuantMode.SYMMERTIC:
  78. symmetric_max_vals = F.maximum(-min_val, max_val)
  79. # use maximun to avoid scale too small at the begin
  80. scale = F.maximum(
  81. symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
  82. )
  83. zero_point = None
  84. else:
  85. # use maximun to avoid scale too small at the begin
  86. scale = F.maximum(
  87. (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit
  88. )
  89. # caculate zero_point
  90. zero_point = self.qmin - F.round((min_val / scale))
  91. return create_qparams(self.mode, self.dtype, scale=scale, zero_point=zero_point)
  92. def get_qparams(self):
  93. return self._calculate_qparams(self.min_val, self.max_val)
  94. def forward(self, x_orig):
  95. if self.enabled:
  96. # stop gradient
  97. x = x_orig.detach()
  98. # find max and min
  99. self.min_val[...] = F.minimum(self.min_val, x.min())
  100. self.max_val[...] = F.maximum(self.max_val, x.max())
  101. return x_orig
  102. class SyncMinMaxObserver(MinMaxObserver):
  103. def forward(self, x_orig):
  104. if self.enable:
  105. x = x_orig.detach()
  106. if is_distributed():
  107. min_x = all_reduce_min(x.min(), WORLD)
  108. max_x = all_reduce_max(x.max(), WORLD)
  109. else:
  110. min_x = x.min()
  111. max_x = x.max()
  112. self.min_val[...] = F.minimum(self.min_val, min_x)
  113. self.max_val[...] = F.maximum(self.max_val, max_x)
  114. return x_orig
  115. class ExponentialMovingAverageObserver(MinMaxObserver):
  116. def __init__(
  117. self,
  118. momentum: float = 0.9,
  119. mode: QuantMode = QuantMode.SYMMERTIC,
  120. eps: float = 0.00001,
  121. dtype: Union[str, QuantDtypeMeta] = "qint8",
  122. **kwargs
  123. ):
  124. super().__init__(mode, eps, dtype, **kwargs)
  125. self.momentum = Tensor(momentum, dtype="float32")
  126. # used to avoid if-clauses in the first forward which is not supported
  127. # in trace mode.
  128. self.runtime_momentum = Tensor(0.0)
  129. def set_momentum(self, momentum):
  130. self.momentum = Tensor(momentum, dtype="float32")
  131. def forward(self, x_orig):
  132. if self.enabled:
  133. # stop gradient
  134. x = x_orig.detach()
  135. # Exponential Moving Average
  136. self.min_val[...] = (
  137. self.min_val * self.runtime_momentum
  138. + (1 - self.runtime_momentum) * x.min()
  139. )
  140. self.max_val[...] = (
  141. self.max_val * self.runtime_momentum
  142. + (1 - self.runtime_momentum) * x.max()
  143. )
  144. self.runtime_momentum[...] = self.momentum
  145. return x_orig
  146. class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
  147. def forward(self, x_orig):
  148. if self.enabled:
  149. x = x_orig.detach()
  150. if is_distributed:
  151. min_x = all_reduce_min(x.min(), WORLD)
  152. max_x = all_reduce_max(x.max(), WORLD)
  153. else:
  154. min_x = x.min()
  155. max_x = x.max()
  156. self.min_val[...] = (
  157. self.min_val * self.runtime_momentum
  158. + (1 - self.runtime_momentum) * min_x
  159. )
  160. self.max_val[...] = (
  161. self.max_val * self.runtime_momentum
  162. + (1 - self.runtime_momentum) * max_x
  163. )
  164. self.runtime_momentum[...] = self.momentum
  165. return x_orig
  166. class HistogramObserver(MinMaxObserver):
  167. def __init__(
  168. self,
  169. bins: int = 2048,
  170. upsample_rate: int = 128,
  171. mode: QuantMode = QuantMode.SYMMERTIC,
  172. eps: float = 0.00001,
  173. dtype: Union[str, QuantDtypeMeta] = "qint8",
  174. **kwargs
  175. ):
  176. super().__init__(mode, eps, dtype, **kwargs)
  177. self.bins = bins
  178. self.upsample_rate = upsample_rate
  179. self.dst_nbins = (
  180. _builtin_quant_dtypes[dtype].qmax - _builtin_quant_dtypes[dtype].qmin + 1
  181. )
  182. self.histogram = Tensor([-1] + [0.0] * (bins - 1), dtype="float32")
  183. def _non_linear_param_search(self):
  184. r"""
  185. Non-linear parameter search.
  186. An approximation for L2 error minimization for selecting min/max.
  187. By selecting new min/max, we filter out outliers in input distribution.
  188. """
  189. np_min_val = self.min_val.numpy()
  190. np_max_val = self.max_val.numpy()
  191. np_histogram = self.histogram.numpy()
  192. assert len(np_histogram) == self.bins, "bins mistmatch"
  193. bin_width = (np_max_val - np_min_val) / self.bins
  194. def _get_norm(delta_begin, delta_end, density, norm_type):
  195. r"""
  196. Compute the norm of the values uniformaly distributed between
  197. delta_begin and delta_end.
  198. norm = density * (integral_{begin, end} x^2)
  199. = density * (end^3 - begin^3) / 3
  200. """
  201. assert norm_type == "L2", "Only L2 norms are currently supported"
  202. norm = 0.0
  203. if norm_type == "L2":
  204. norm = (
  205. delta_end * delta_end * delta_end
  206. - delta_begin * delta_begin * delta_begin
  207. ) / 3
  208. return density * norm
  209. def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
  210. r"""
  211. Compute the quantization error if we use start_bin to end_bin as the
  212. min and max to do the quantization.
  213. """
  214. norm = 0.0
  215. dst_bin_width = (
  216. bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
  217. )
  218. if dst_bin_width == 0.0:
  219. return 0.0
  220. for src_bin in range(self.bins):
  221. # distances from the beginning of first dst_bin to the beginning and
  222. # end of src_bin
  223. src_bin_begin = (src_bin - next_start_bin) * bin_width
  224. src_bin_end = src_bin_begin + bin_width
  225. # which dst_bins the beginning and end of src_bin belong to?
  226. dst_bin_of_begin = min(
  227. self.dst_nbins - 1,
  228. max(0.0, math.floor(src_bin_begin / dst_bin_width)),
  229. )
  230. dst_bin_of_end = min(
  231. self.dst_nbins - 1,
  232. max(0.0, math.floor(src_bin_end / dst_bin_width)),
  233. )
  234. dst_bin_of_begin_center = (
  235. dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
  236. )
  237. density = np_histogram[src_bin] / bin_width
  238. if dst_bin_of_begin == dst_bin_of_end:
  239. # if src_bin is entirely within 1 dst_bin
  240. delta_begin = src_bin_begin - dst_bin_of_begin_center
  241. delta_end = src_bin_end - dst_bin_of_begin_center
  242. norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
  243. else:
  244. delta_begin = src_bin_begin - dst_bin_of_begin_center
  245. delta_end = dst_bin_width / 2
  246. norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
  247. norm = norm + (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
  248. -dst_bin_width / 2, dst_bin_width / 2, density, norm_type
  249. )
  250. dst_bin_of_end_center = (
  251. dst_bin_of_end * dst_bin_width + dst_bin_width / 2
  252. )
  253. delta_begin = -dst_bin_width / 2
  254. delta_end = src_bin_end - dst_bin_of_end_center
  255. norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
  256. return norm
  257. # cumulative sum
  258. total = sum(np_histogram)
  259. cSum = np.cumsum(np_histogram, axis=0)
  260. stepsize = 1e-5 # granularity
  261. alpha = 0.0 # lower bound
  262. beta = 1.0 # upper bound
  263. start_bin = 0
  264. end_bin = self.bins - 1
  265. norm_min = float("inf")
  266. while alpha < beta:
  267. # Find the next step
  268. next_alpha = alpha + stepsize
  269. next_beta = beta - stepsize
  270. # find the left and right bins between the quantile bounds
  271. l = start_bin
  272. r = end_bin
  273. while l < end_bin and cSum[l] < next_alpha * total:
  274. l = l + 1
  275. while r > start_bin and cSum[r] > next_beta * total:
  276. r = r - 1
  277. # decide the next move
  278. next_start_bin = start_bin
  279. next_end_bin = end_bin
  280. if (l - start_bin) > (end_bin - r):
  281. # move the start bin
  282. next_start_bin = l
  283. alpha = next_alpha
  284. else:
  285. # move the end bin
  286. next_end_bin = r
  287. beta = next_beta
  288. if next_start_bin == start_bin and next_end_bin == end_bin:
  289. continue
  290. # calculate the quantization error using next_start_bin and next_end_bin
  291. norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")
  292. if norm > norm_min:
  293. break
  294. norm_min = norm
  295. start_bin = next_start_bin
  296. end_bin = next_end_bin
  297. new_min = self.min_val + Tensor(bin_width * start_bin, dtype=np.float32)
  298. new_max = self.min_val + Tensor(bin_width * (end_bin + 1), dtype=np.float32)
  299. return new_min, new_max
  300. def get_qparams(self):
  301. new_min, new_max = self._non_linear_param_search()
  302. return self._calculate_qparams(new_min, new_max)
  303. def _combine_histograms(
  304. self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins
  305. ):
  306. # First up-sample the histogram with new data by a factor of L
  307. # This creates an approximate probability density thats piecwise constant
  308. upsampled_histogram = new_hist.repeat(upsample_rate)
  309. # Now insert the upsampled histogram into the output
  310. # histogram, which is initialized with zeros.
  311. # The offset at which the histogram is introduced is determined
  312. # by the start index as the output histogram can cover a wider range
  313. histogram_with_output_range = np.zeros((Nbins * downsample_rate))
  314. histogram_with_output_range[
  315. start_idx : Nbins * upsample_rate + start_idx
  316. ] = upsampled_histogram
  317. # Compute integral histogram, double precision is needed to ensure
  318. # that there are no overflows
  319. integral_histogram = np.cumsum(histogram_with_output_range, 0)[
  320. downsample_rate - 1 :: downsample_rate
  321. ]
  322. # Finally perform interpolation
  323. shifted_integral_histogram = np.zeros((Nbins))
  324. shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1]
  325. interpolated_histogram = (
  326. integral_histogram - shifted_integral_histogram
  327. ) / upsample_rate
  328. orig_hist = orig_hist + interpolated_histogram
  329. return orig_hist
  330. def _adjust_min_max(self, combined_min, combined_max, upsample_rate):
  331. # We ensure that:
  332. # (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
  333. # This allows us to have a common grid of resolution s, where we can align
  334. # the input histogram
  335. # start_idx maps min_val to the histogram bin index.
  336. np_min_val = self.min_val.numpy()
  337. np_max_val = self.max_val.numpy()
  338. hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate)
  339. downsample_rate = int(
  340. np.ceil((combined_max - combined_min) / (self.bins * hist_bin_width))
  341. )
  342. e = downsample_rate * (self.bins * hist_bin_width) - (
  343. combined_max - combined_min
  344. )
  345. combined_max = combined_max + e / 2
  346. combined_min = combined_min - e / 2
  347. start_idx = int(np.round((np_min_val - combined_min) / hist_bin_width))
  348. return combined_min, combined_max, downsample_rate, start_idx
  349. def sideeffect_forward(self, x_orig):
  350. x = x_orig.numpy()
  351. min_val = self.min_val.numpy()
  352. max_val = self.max_val.numpy()
  353. histogram = self.histogram.numpy()
  354. new_min = x.min()
  355. new_max = x.max()
  356. if histogram[0] == -1:
  357. new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
  358. else:
  359. new_min = min(new_min, min_val)
  360. new_max = max(new_max, max_val)
  361. # combine the existing histogram and new histogram into 1 histogram
  362. # We do this by first upsampling the histogram to a dense grid
  363. # and then downsampling the histogram efficiently
  364. (new_min, new_max, downsample_rate, start_idx) = self._adjust_min_max(
  365. new_min, new_max, self.upsample_rate
  366. )
  367. new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
  368. new_histogram = new_histogram.astype(np.float64)
  369. if new_min == min_val and new_max == max_val:
  370. new_histogram += histogram
  371. else:
  372. new_histogram = self._combine_histograms(
  373. new_histogram,
  374. histogram,
  375. self.upsample_rate,
  376. downsample_rate,
  377. start_idx,
  378. self.bins,
  379. )
  380. self.histogram = Tensor(new_histogram, dtype="float32")
  381. self.min_val = Tensor(new_min, dtype="float32")
  382. self.max_val = Tensor(new_max, dtype="float32")
  383. def forward(self, x_orig):
  384. self.sideeffect_forward(x_orig)
  385. return x_orig
  386. class PassiveObserver(Observer):
  387. r"""
  388. An Observer that supports setting :attr:`scale` directly.
  389. """
  390. def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
  391. super().__init__(dtype, **kwargs)
  392. self.qparams = None
  393. self.orig_scale = None
  394. @property
  395. def scale(self):
  396. return self.qparams.scale
  397. @scale.setter
  398. def scale(self, value: np.ndarray):
  399. assert np.all(value > 0)
  400. self.qparams.scale[...] = Tensor(value)
  401. def get_qparams(self):
  402. return self.qparams
  403. def set_qparams(self, qparams: QParams):
  404. """
  405. :param qparams: used to set initial scale.
  406. """
  407. self.qparams = deepcopy(qparams)
  408. if qparams.scale is None:
  409. raise AssertionError("Can not get an initialized scale")
  410. if qparams.dtype_meta is None:
  411. qparams.dtype_meta = self.dtype
  412. else:
  413. assert (
  414. qparams.dtype_meta is self.dtype
  415. ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
  416. qparams.dtype_meta, self.dtype
  417. )
  418. self.orig_scale = qparams.scale.numpy()
  419. def forward(self, x):
  420. r"""
  421. Just return input because :attr:`qparams` is set by :func:`~.apply_easy_quant`.
  422. """
  423. return x

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。