You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

observer.py 18 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import math
  9. from abc import abstractmethod
  10. from copy import deepcopy
  11. import numpy as np
  12. from .. import functional as F
  13. from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
  14. from ..distributed import WORLD, get_rank, is_distributed
  15. from ..functional.distributed import all_reduce_max, all_reduce_min
  16. from ..module import Module
  17. from ..tensor import Tensor
  18. from .utils import QuantMode, Round, get_qparam_dict
  19. class Observer(Module):
  20. r"""
  21. A base class for Observer Module.
  22. :param dtype: a string indicating to collect scale and zero_point of which dtype.
  23. :param narrow_range: whether the absolute value of ``qmin`` is the same as ``qmax``,
  24. instead of 1 greater. Usually True for weight and False for activation.
  25. """
  26. def __init__(self, dtype: str, narrow_range: bool = False, **kwargs):
  27. super().__init__()
  28. if dtype not in _metadata_dict.keys():
  29. raise ValueError(
  30. "unknown dtype: {}, only support {}".format(
  31. dtype, _metadata_dict.keys()
  32. )
  33. )
  34. self.dtype = dtype
  35. self.narrow_range = narrow_range
  36. self.qmin = (
  37. -_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin
  38. )
  39. self.qmax = _metadata_dict[dtype].qmax
  40. self.enabled = True
  41. def get_dtype(self):
  42. q_dict = self.get_qparams()
  43. numpy_scale = None if "scale" not in q_dict else q_dict["scale"].numpy()
  44. numpy_zero_point = (
  45. None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()
  46. )
  47. return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)
  48. def enable(self):
  49. self.enabled = True
  50. def disable(self):
  51. self.enabled = False
  52. def train(self, mode: bool = True, recursive: bool = True) -> None:
  53. super().train(mode, recursive)
  54. if mode:
  55. self.enable()
  56. else:
  57. self.disable()
  58. @abstractmethod
  59. def forward(self, x):
  60. pass
  61. @abstractmethod
  62. def get_qparams(self, **kwargs):
  63. pass
  64. class MinMaxObserver(Observer):
  65. def __init__(
  66. self,
  67. mode=QuantMode.SYMMERTIC,
  68. eps=0.00001,
  69. dtype="qint8",
  70. narrow_range: bool = False,
  71. **kwargs
  72. ):
  73. super().__init__(dtype, narrow_range, **kwargs)
  74. self.mode = mode
  75. self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32)
  76. self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32)
  77. self.scale_limit = eps
  78. def _calculate_qparams(self, inp_min_val, inp_max_val):
  79. min_val = F.minimum(0.0, inp_min_val)
  80. max_val = F.maximum(0.0, inp_max_val)
  81. q_dict = get_qparam_dict(self.mode)
  82. q_dict["min_val"] = inp_min_val
  83. q_dict["max_val"] = inp_max_val
  84. q_dict["enable_observer"] = self.enable
  85. if self.mode == QuantMode.SYMMERTIC:
  86. symmetric_max_vals = F.maximum(-min_val, max_val)
  87. # use maximun to avoid scale too small at the begin
  88. q_dict["scale"] = F.maximum(
  89. symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
  90. )
  91. # zero_point = self.zero_point
  92. else:
  93. # use maximun to avoid scale too small at the begin
  94. q_dict["scale"] = F.maximum(
  95. (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit
  96. )
  97. # caculate zero_point
  98. q_dict["zero_point"] = self.qmin - Round()((min_val / q_dict["scale"]))
  99. return q_dict
  100. def get_qparams(self):
  101. return self._calculate_qparams(self.min_val, self.max_val)
  102. def forward(self, x_orig):
  103. if self.enabled:
  104. # stop gradient
  105. x = x_orig.detach()
  106. # find max and min
  107. self.min_val._reset(F.minimum(self.min_val, x.min()))
  108. self.max_val._reset(F.maximum(self.max_val, x.max()))
  109. return x_orig
  110. class SyncMinMaxObserver(MinMaxObserver):
  111. def forward(self, x_orig):
  112. if self.enable:
  113. x = x_orig.detach()
  114. if is_distributed():
  115. min_x = all_reduce_min(x.min(), WORLD)
  116. max_x = all_reduce_max(x.max(), WORLD)
  117. else:
  118. min_x = x.min()
  119. max_x = x.max()
  120. self.min_val._reset(F.minimum(self.min_val, min_x))
  121. self.max_val._reset(F.maximum(self.max_val, max_x))
  122. return x_orig
  123. class ExponentialMovingAverageObserver(MinMaxObserver):
  124. def __init__(
  125. self,
  126. momentum=0.9,
  127. mode=QuantMode.SYMMERTIC,
  128. eps=0.00001,
  129. dtype="qint8",
  130. narrow_range: bool = False,
  131. **kwargs
  132. ):
  133. super().__init__(mode, eps, dtype, narrow_range, **kwargs)
  134. self.momentum = Tensor(momentum, dtype="float32")
  135. self.runtime_momentum = Tensor(0.0)
  136. def set_momentum(self, momentum):
  137. self.momentum = Tenosr(momentum, dtype="float32")
  138. def forward(self, x_orig):
  139. if self.enabled:
  140. # stop gradient
  141. x = x_orig.detach()
  142. # Exponential Moving Average
  143. self.min_val._reset(
  144. self.min_val * self.runtime_momentum
  145. + (1 - self.runtime_momentum) * x.min()
  146. )
  147. self.max_val._reset(
  148. self.max_val * self.runtime_momentum
  149. + (1 - self.runtime_momentum) * x.max()
  150. )
  151. self.runtime_momentum = self.momentum
  152. return x_orig
  153. class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
  154. def forward(self, x_orig):
  155. if self.enabled:
  156. x = x_orig.detach()
  157. if is_distributed:
  158. min_x = all_reduce_min(x.min(), WORLD)
  159. max_x = all_reduce_max(x.max(), WORLD)
  160. else:
  161. min_x = x.min()
  162. max_x = x.max()
  163. self.min_val._reset(
  164. self.min_val * self.runtime_momentum
  165. + (1 - self.runtime_momentum) * min_x
  166. )
  167. self.max_val._reset(
  168. self.max_val * self.runtime_momentum
  169. + (1 - self.runtime_momentum) * max_x
  170. )
  171. self.runtime_momentum = self.momentum
  172. return x_orig
class HistogramObserver(MinMaxObserver):
    r"""
    Observer that accumulates a histogram of the observed values and picks the
    quantization range by searching for the min/max that minimizes an
    approximate L2 quantization error.

    :param bins: number of histogram bins.
    :param upsample_rate: factor by which the stored histogram is upsampled
        when it has to be merged onto a wider range.
    :param mode: forwarded to :class:`MinMaxObserver`.
    :param eps: forwarded to :class:`MinMaxObserver`.
    :param dtype: forwarded to :class:`MinMaxObserver`; also determines ``dst_nbins``.
    :param narrow_range: forwarded to :class:`MinMaxObserver`.
    """

    def __init__(
        self,
        bins=2048,
        upsample_rate=128,
        mode=QuantMode.SYMMERTIC,
        eps=0.00001,
        dtype="qint8",
        narrow_range: bool = False,
        **kwargs
    ):
        super().__init__(mode, eps, dtype, narrow_range, **kwargs)
        self.bins = bins
        self.upsample_rate = upsample_rate
        # number of representable quantized values of the target dtype
        self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
        # A leading -1 marks "nothing observed yet"; sideeffect_forward checks it.
        self.histogram = Tensor([-1] + [0.0] * (bins - 1), dtype="float32")

    def _non_linear_param_search(self):
        r"""
        Non-linear parameter search.
        An approximation for L2 error minimization for selecting min/max.
        By selecting new min/max, we filter out outliers in input distribution.

        :return: tuple ``(new_min, new_max)`` derived from the selected start/end bins.
        """
        np_min_val = self.min_val.numpy()
        np_max_val = self.max_val.numpy()
        np_histogram = self.histogram.numpy()
        assert len(np_histogram) == self.bins, "bins mistmatch"
        bin_width = (np_max_val - np_min_val) / self.bins

        def _get_norm(delta_begin, delta_end, density, norm_type):
            r"""
            Compute the norm of the values uniformly distributed between
            delta_begin and delta_end.
            norm = density * (integral_{begin, end} x^2)
                 = density * (end^3 - begin^3) / 3
            """
            assert norm_type == "L2", "Only L2 norms are currently supported"
            norm = 0.0
            if norm_type == "L2":
                norm = (
                    delta_end * delta_end * delta_end
                    - delta_begin * delta_begin * delta_begin
                ) / 3
            return density * norm

        def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
            r"""
            Compute the quantization error if we use start_bin to end_bin as the
            min and max to do the quantization.
            """
            norm = 0.0
            dst_bin_width = (
                bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
            )
            # a zero-width destination bin is reported as zero error
            if dst_bin_width == 0.0:
                return 0.0
            for src_bin in range(self.bins):
                # distances from the beginning of first dst_bin to the beginning and
                # end of src_bin
                src_bin_begin = (src_bin - next_start_bin) * bin_width
                src_bin_end = src_bin_begin + bin_width
                # which dst_bins the beginning and end of src_bin belong to?
                # (indices are clamped to [0, dst_nbins - 1])
                dst_bin_of_begin = min(
                    self.dst_nbins - 1,
                    max(0.0, math.floor(src_bin_begin / dst_bin_width)),
                )
                dst_bin_of_end = min(
                    self.dst_nbins - 1,
                    max(0.0, math.floor(src_bin_end / dst_bin_width)),
                )
                dst_bin_of_begin_center = (
                    dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
                )
                density = np_histogram[src_bin] / bin_width
                if dst_bin_of_begin == dst_bin_of_end:
                    # if src_bin is entirely within 1 dst_bin
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = src_bin_end - dst_bin_of_begin_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
                else:
                    # src_bin spans several dst_bins: first the part inside the
                    # dst_bin that contains its beginning
                    delta_begin = src_bin_begin - dst_bin_of_begin_center
                    delta_end = dst_bin_width / 2
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
                    # then every dst_bin that is fully covered
                    norm = norm + (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
                        -dst_bin_width / 2, dst_bin_width / 2, density, norm_type
                    )
                    # finally the part inside the dst_bin that contains its end
                    dst_bin_of_end_center = (
                        dst_bin_of_end * dst_bin_width + dst_bin_width / 2
                    )
                    delta_begin = -dst_bin_width / 2
                    delta_end = src_bin_end - dst_bin_of_end_center
                    norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
            return norm

        # cumulative sum
        total = sum(np_histogram)
        cSum = np.cumsum(np_histogram, axis=0)
        stepsize = 1e-5  # granularity
        alpha = 0.0  # lower bound
        beta = 1.0  # upper bound
        start_bin = 0
        end_bin = self.bins - 1
        norm_min = float("inf")
        while alpha < beta:
            # Find the next step
            next_alpha = alpha + stepsize
            next_beta = beta - stepsize
            # find the left and right bins between the quantile bounds
            l = start_bin
            r = end_bin
            while l < end_bin and cSum[l] < next_alpha * total:
                l = l + 1
            while r > start_bin and cSum[r] > next_beta * total:
                r = r - 1
            # decide the next move: shrink from the side that skips more bins
            next_start_bin = start_bin
            next_end_bin = end_bin
            if (l - start_bin) > (end_bin - r):
                # move the start bin
                next_start_bin = l
                alpha = next_alpha
            else:
                # move the end bin
                next_end_bin = r
                beta = next_beta
            # nothing moved in bin terms; only alpha/beta advanced
            if next_start_bin == start_bin and next_end_bin == end_bin:
                continue
            # calculate the quantization error using next_start_bin and next_end_bin
            norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")
            # stop as soon as the error gets worse than the best one seen
            if norm > norm_min:
                break
            norm_min = norm
            start_bin = next_start_bin
            end_bin = next_end_bin
        new_min = self.min_val + Tensor(bin_width * start_bin, dtype=np.float32)
        new_max = self.min_val + Tensor(bin_width * (end_bin + 1), dtype=np.float32)
        return new_min, new_max

    def get_qparams(self):
        r"""Return qparams computed from the range chosen by the histogram search."""
        new_min, new_max = self._non_linear_param_search()
        return self._calculate_qparams(new_min, new_max)

    def _combine_histograms(
        self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins
    ):
        r"""
        Resample ``new_hist`` onto the output grid and add it to ``orig_hist``.

        :param orig_hist: histogram already expressed on the output range.
        :param new_hist: histogram to be upsampled, shifted and downsampled.
        :param upsample_rate: repeat factor applied to ``new_hist``.
        :param downsample_rate: stride used when sampling the integral histogram.
        :param start_idx: offset of ``new_hist`` inside the upsampled output grid.
        :param Nbins: number of bins of both histograms.
        :return: the combined histogram.
        """
        # First up-sample the histogram with new data by a factor of L
        # This creates an approximate probability density that's piecewise constant
        upsampled_histogram = new_hist.repeat(upsample_rate)
        # Now insert the upsampled histogram into the output
        # histogram, which is initialized with zeros.
        # The offset at which the histogram is introduced is determined
        # by the start index as the output histogram can cover a wider range
        histogram_with_output_range = np.zeros((Nbins * downsample_rate))
        histogram_with_output_range[
            start_idx : Nbins * upsample_rate + start_idx
        ] = upsampled_histogram
        # Compute integral histogram, double precision is needed to ensure
        # that there are no overflows
        integral_histogram = np.cumsum(histogram_with_output_range, 0)[
            downsample_rate - 1 :: downsample_rate
        ]
        # Finally perform interpolation
        shifted_integral_histogram = np.zeros((Nbins))
        shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1]
        interpolated_histogram = (
            integral_histogram - shifted_integral_histogram
        ) / upsample_rate
        orig_hist = orig_hist + interpolated_histogram
        return orig_hist

    def _adjust_min_max(self, combined_min, combined_max, upsample_rate):
        r"""
        Widen ``[combined_min, combined_max]`` so that it aligns with the grid of
        the stored histogram.

        :return: tuple ``(combined_min, combined_max, downsample_rate, start_idx)``.
        """
        # We ensure that:
        # (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
        # This allows us to have a common grid of resolution s, where we can align
        # the input histogram
        # start_idx maps min_val to the histogram bin index.
        np_min_val = self.min_val.numpy()
        np_max_val = self.max_val.numpy()
        # NOTE(review): this is zero when min_val == max_val and is used as a divisor below
        hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate)
        downsample_rate = int(
            np.ceil((combined_max - combined_min) / (self.bins * hist_bin_width))
        )
        # extra width needed to reach a whole multiple of the grid; split evenly
        e = downsample_rate * (self.bins * hist_bin_width) - (
            combined_max - combined_min
        )
        combined_max = combined_max + e / 2
        combined_min = combined_min - e / 2
        start_idx = int(np.round((np_min_val - combined_min) / hist_bin_width))
        return combined_min, combined_max, downsample_rate, start_idx

    def sideeffect_forward(self, x_orig):
        r"""
        Fold ``x_orig`` into the stored histogram and update ``min_val`` / ``max_val``.

        Replaces ``self.histogram``, ``self.min_val`` and ``self.max_val`` with
        new tensors; returns nothing.
        """
        x = x_orig.numpy()
        min_val = self.min_val.numpy()
        max_val = self.max_val.numpy()
        histogram = self.histogram.numpy()
        new_min = x.min()
        new_max = x.max()
        if histogram[0] == -1:
            # first observation: the histogram of this batch becomes the state
            new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
        else:
            new_min = min(new_min, min_val)
            new_max = max(new_max, max_val)
            # combine the existing histogram and new histogram into 1 histogram
            # We do this by first upsampling the histogram to a dense grid
            # and then downsampling the histogram efficiently
            (new_min, new_max, downsample_rate, start_idx) = self._adjust_min_max(
                new_min, new_max, self.upsample_rate
            )
            new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
            new_histogram = new_histogram.astype(np.float64)
            if new_min == min_val and new_max == max_val:
                # same range as before: bins line up, plain addition is enough
                new_histogram += histogram
            else:
                # the stored histogram is the one resampled onto the new range
                new_histogram = self._combine_histograms(
                    new_histogram,
                    histogram,
                    self.upsample_rate,
                    downsample_rate,
                    start_idx,
                    self.bins,
                )
        self.histogram = Tensor(new_histogram, dtype="float32")
        self.min_val = Tensor(new_min, dtype="float32")
        self.max_val = Tensor(new_max, dtype="float32")

    def forward(self, x_orig):
        r"""
        Update the histogram statistics with ``x_orig`` and return it unchanged.

        NOTE(review): unlike the other observers here, this does not check
        ``self.enabled`` -- confirm whether that is intentional.
        """
        self.sideeffect_forward(x_orig)
        return x_orig
  392. class PassiveObserver(Observer):
  393. r"""
  394. This class can be set :attr:`scale` derectly.
  395. """
  396. def __init__(self, q_dict, dtype: str, narrow_range: bool = False, **kwargs):
  397. super().__init__(dtype, narrow_range, **kwargs)
  398. self.q_dict = deepcopy(q_dict)
  399. if "scale" not in q_dict or q_dict["scale"] is None:
  400. raise AssertionError("Can not get an initialized scale")
  401. self.orig_scale = q_dict["scale"].numpy()
  402. @property
  403. def scale(self):
  404. return self.q_dict["scale"]
  405. @scale.setter
  406. def scale(self, value):
  407. assert value > 0
  408. self.q_dict["scale"][...] = Tensor(value)
  409. def get_qparams(self):
  410. return self.q_dict
  411. def forward(self, x):
  412. r"""
  413. Just return input because :attr:`q_dict` is set by :func:`~.apply_easy_quant`.
  414. """
  415. return x

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台