You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

observer.py 17 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import math
  9. from abc import abstractmethod
  10. import numpy as np
  11. from .. import functional as F
  12. from ..core.tensor.dtype import _metadata_dict, get_quantized_dtype
  13. from ..distributed import WORLD, get_rank, is_distributed
  14. from ..functional.distributed import all_reduce_max, all_reduce_min
  15. from ..module import Module
  16. from ..tensor import Tensor
  17. from .utils import QuantMode, Round, get_qparam_dict
  18. class Observer(Module):
  19. r"""
  20. A base class for Observer Module.
  21. :param dtype: a string indicating to collect scale and zero_point of which dtype.
  22. :param narrow_range: whether the absolute value of ``qmin`` is the same as ``qmax``,
  23. instead of 1 greater. Usually True for weight and False for activation.
  24. """
  25. def __init__(self, dtype: str, narrow_range: bool = False):
  26. super().__init__()
  27. if dtype not in _metadata_dict.keys():
  28. raise ValueError(
  29. "unknown dtype: {}, only support {}".format(
  30. dtype, _metadata_dict.keys()
  31. )
  32. )
  33. self.dtype = dtype
  34. self.narrow_range = narrow_range
  35. self.qmin = (
  36. -_metadata_dict[dtype].qmax if narrow_range else _metadata_dict[dtype].qmin
  37. )
  38. self.qmax = _metadata_dict[dtype].qmax
  39. self.enabled = True
  40. def get_dtype(self):
  41. q_dict = self.get_qparams()
  42. numpy_scale = None if "scale" not in q_dict else q_dict["scale"].numpy()[0]
  43. numpy_zero_point = (
  44. None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()[0]
  45. )
  46. return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)
  47. def enable(self):
  48. self.enabled = True
  49. def disable(self):
  50. self.enabled = False
  51. def train(self, mode: bool = True, recursive: bool = True) -> None:
  52. super().train(mode, recursive)
  53. if mode:
  54. self.enable()
  55. else:
  56. self.disable()
  57. @abstractmethod
  58. def forward(self, x):
  59. pass
  60. @abstractmethod
  61. def get_qparams(self, **kwargs):
  62. pass
  63. class MinMaxObserver(Observer):
  64. def __init__(
  65. self,
  66. mode=QuantMode.SYMMERTIC,
  67. eps=0.00001,
  68. dtype="qint8",
  69. narrow_range: bool = False,
  70. ):
  71. super().__init__(dtype, narrow_range)
  72. self.mode = mode
  73. self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32)
  74. self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32)
  75. self.scale_limit = eps
  76. def _calculate_qparams(self, inp_min_val, inp_max_val):
  77. min_val = F.minimum(0.0, inp_min_val)
  78. max_val = F.maximum(0.0, inp_max_val)
  79. q_dict = get_qparam_dict(self.mode)
  80. q_dict["min_val"] = inp_min_val
  81. q_dict["max_val"] = inp_max_val
  82. q_dict["enable_observer"] = self.enable
  83. if self.mode == QuantMode.SYMMERTIC:
  84. symmetric_max_vals = F.maximum(-min_val, max_val)
  85. # use maximun to avoid scale too small at the begin
  86. q_dict["scale"] = F.maximum(
  87. symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
  88. )
  89. # zero_point = self.zero_point
  90. else:
  91. # use maximun to avoid scale too small at the begin
  92. q_dict["scale"] = F.maximum(
  93. (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit,
  94. )
  95. # caculate zero_point
  96. q_dict["zero_point"] = self.qmin - Round()((min_val / q_dict["scale"]))
  97. return q_dict
  98. def get_qparams(self):
  99. return self._calculate_qparams(self.min_val, self.max_val)
  100. def forward(self, x_orig):
  101. if self.enabled:
  102. # stop gradient
  103. x = x_orig.detach()
  104. # find max and min
  105. self.min_val._reset(F.minimum(self.min_val, x.min()))
  106. self.max_val._reset(F.maximum(self.max_val, x.max()))
  107. return x_orig
  108. class SyncMinMaxObserver(MinMaxObserver):
  109. def forward(self, x_orig):
  110. if self.enable:
  111. x = x_orig.detach()
  112. if is_distributed():
  113. min_x = all_reduce_min(x.min(), WORLD)
  114. max_x = all_reduce_max(x.max(), WORLD)
  115. else:
  116. min_x = x.min()
  117. max_x = x.max()
  118. self.min_val._reset(F.minimum(self.min_val, min_x))
  119. self.max_val._reset(F.maximum(self.max_val, max_x))
  120. return x_orig
  121. class ExponentialMovingAverageObserver(MinMaxObserver):
  122. def __init__(
  123. self,
  124. momentum=0.9,
  125. mode=QuantMode.SYMMERTIC,
  126. eps=0.00001,
  127. dtype="qint8",
  128. narrow_range: bool = False,
  129. ):
  130. super().__init__(mode, eps, dtype, narrow_range)
  131. self.momentum = Tensor(momentum)
  132. self.runtime_momentum = Tensor(0.0)
  133. def set_momentum(self, momentum):
  134. self.momentum._reset(momentum)
  135. def forward(self, x_orig):
  136. if self.enabled:
  137. # stop gradient
  138. x = x_orig.detach()
  139. # Exponential Moving Average
  140. self.min_val._reset(
  141. self.min_val * self.runtime_momentum
  142. + (1 - self.runtime_momentum) * x.min()
  143. )
  144. self.max_val._reset(
  145. self.max_val * self.runtime_momentum
  146. + (1 - self.runtime_momentum) * x.max()
  147. )
  148. self.runtime_momentum = self.momentum
  149. return x_orig
  150. class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
  151. def forward(self, x_orig):
  152. if self.enabled:
  153. x = x_orig.detach()
  154. if is_distributed:
  155. min_x = all_reduce_min(x.min(), WORLD)
  156. max_x = all_reduce_max(x.max(), WORLD)
  157. else:
  158. min_x = x.min()
  159. max_x = x.max()
  160. self.min_val._reset(
  161. self.min_val * self.runtime_momentum
  162. + (1 - self.runtime_momentum) * min_x
  163. )
  164. self.max_val._reset(
  165. self.max_val * self.runtime_momentum
  166. + (1 - self.runtime_momentum) * max_x
  167. )
  168. self.runtime_momentum = self.momentum
  169. return x_orig
  170. class HistogramObserver(MinMaxObserver):
  171. def __init__(
  172. self,
  173. bins=2048,
  174. upsample_rate=128,
  175. mode=QuantMode.SYMMERTIC,
  176. eps=0.00001,
  177. dtype="qint8",
  178. narrow_range: bool = False,
  179. ):
  180. super().__init__(mode, eps, dtype, narrow_range)
  181. self.bins = bins
  182. self.upsample_rate = upsample_rate
  183. self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
  184. self.histogram = Tensor([-1] + [0.0] * (bins - 1), dtype="float32")
  185. def _non_linear_param_search(self):
  186. r"""
  187. Non-linear parameter search.
  188. An approximation for L2 error minimization for selecting min/max.
  189. By selecting new min/max, we filter out outliers in input distribution.
  190. """
  191. np_min_val = self.min_val.numpy()[0]
  192. np_max_val = self.max_val.numpy()[0]
  193. np_histogram = self.histogram.numpy()
  194. assert len(np_histogram) == self.bins, "bins mistmatch"
  195. bin_width = (np_max_val - np_min_val) / self.bins
  196. def _get_norm(delta_begin, delta_end, density, norm_type):
  197. r"""
  198. Compute the norm of the values uniformaly distributed between
  199. delta_begin and delta_end.
  200. norm = density * (integral_{begin, end} x^2)
  201. = density * (end^3 - begin^3) / 3
  202. """
  203. assert norm_type == "L2", "Only L2 norms are currently supported"
  204. norm = 0.0
  205. if norm_type == "L2":
  206. norm = (
  207. delta_end * delta_end * delta_end
  208. - delta_begin * delta_begin * delta_begin
  209. ) / 3
  210. return density * norm
  211. def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
  212. r"""
  213. Compute the quantization error if we use start_bin to end_bin as the
  214. min and max to do the quantization.
  215. """
  216. norm = 0.0
  217. dst_bin_width = (
  218. bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
  219. )
  220. if dst_bin_width == 0.0:
  221. return 0.0
  222. for src_bin in range(self.bins):
  223. # distances from the beginning of first dst_bin to the beginning and
  224. # end of src_bin
  225. src_bin_begin = (src_bin - next_start_bin) * bin_width
  226. src_bin_end = src_bin_begin + bin_width
  227. # which dst_bins the beginning and end of src_bin belong to?
  228. dst_bin_of_begin = min(
  229. self.dst_nbins - 1,
  230. max(0.0, math.floor(src_bin_begin / dst_bin_width)),
  231. )
  232. dst_bin_of_end = min(
  233. self.dst_nbins - 1,
  234. max(0.0, math.floor(src_bin_end / dst_bin_width)),
  235. )
  236. dst_bin_of_begin_center = (
  237. dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
  238. )
  239. density = np_histogram[src_bin] / bin_width
  240. if dst_bin_of_begin == dst_bin_of_end:
  241. # if src_bin is entirely within 1 dst_bin
  242. delta_begin = src_bin_begin - dst_bin_of_begin_center
  243. delta_end = src_bin_end - dst_bin_of_begin_center
  244. norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
  245. else:
  246. delta_begin = src_bin_begin - dst_bin_of_begin_center
  247. delta_end = dst_bin_width / 2
  248. norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
  249. norm = norm + (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
  250. -dst_bin_width / 2, dst_bin_width / 2, density, norm_type
  251. )
  252. dst_bin_of_end_center = (
  253. dst_bin_of_end * dst_bin_width + dst_bin_width / 2
  254. )
  255. delta_begin = -dst_bin_width / 2
  256. delta_end = src_bin_end - dst_bin_of_end_center
  257. norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
  258. return norm
  259. # cumulative sum
  260. total = sum(np_histogram)
  261. cSum = np.cumsum(np_histogram, axis=0)
  262. stepsize = 1e-5 # granularity
  263. alpha = 0.0 # lower bound
  264. beta = 1.0 # upper bound
  265. start_bin = 0
  266. end_bin = self.bins - 1
  267. norm_min = float("inf")
  268. while alpha < beta:
  269. # Find the next step
  270. next_alpha = alpha + stepsize
  271. next_beta = beta - stepsize
  272. # find the left and right bins between the quantile bounds
  273. l = start_bin
  274. r = end_bin
  275. while l < end_bin and cSum[l] < next_alpha * total:
  276. l = l + 1
  277. while r > start_bin and cSum[r] > next_beta * total:
  278. r = r - 1
  279. # decide the next move
  280. next_start_bin = start_bin
  281. next_end_bin = end_bin
  282. if (l - start_bin) > (end_bin - r):
  283. # move the start bin
  284. next_start_bin = l
  285. alpha = next_alpha
  286. else:
  287. # move the end bin
  288. next_end_bin = r
  289. beta = next_beta
  290. if next_start_bin == start_bin and next_end_bin == end_bin:
  291. continue
  292. # calculate the quantization error using next_start_bin and next_end_bin
  293. norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")
  294. if norm > norm_min:
  295. break
  296. norm_min = norm
  297. start_bin = next_start_bin
  298. end_bin = next_end_bin
  299. new_min = self.min_val + Tensor(bin_width * start_bin, dtype=np.float32)
  300. new_max = self.min_val + Tensor(bin_width * (end_bin + 1), dtype=np.float32)
  301. return new_min, new_max
  302. def get_qparams(self):
  303. new_min, new_max = self._non_linear_param_search()
  304. return self._calculate_qparams(new_min, new_max)
  305. def _combine_histograms(
  306. self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins
  307. ):
  308. # First up-sample the histogram with new data by a factor of L
  309. # This creates an approximate probability density thats piecwise constant
  310. upsampled_histogram = new_hist.repeat(upsample_rate)
  311. # Now insert the upsampled histogram into the output
  312. # histogram, which is initialized with zeros.
  313. # The offset at which the histogram is introduced is determined
  314. # by the start index as the output histogram can cover a wider range
  315. histogram_with_output_range = np.zeros((Nbins * downsample_rate))
  316. histogram_with_output_range[
  317. start_idx : Nbins * upsample_rate + start_idx
  318. ] = upsampled_histogram
  319. # Compute integral histogram, double precision is needed to ensure
  320. # that there are no overflows
  321. integral_histogram = np.cumsum(histogram_with_output_range, 0)[
  322. downsample_rate - 1 :: downsample_rate
  323. ]
  324. # Finally perform interpolation
  325. shifted_integral_histogram = np.zeros((Nbins))
  326. shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1]
  327. interpolated_histogram = (
  328. integral_histogram - shifted_integral_histogram
  329. ) / upsample_rate
  330. orig_hist = orig_hist + interpolated_histogram
  331. return orig_hist
  332. def _adjust_min_max(self, combined_min, combined_max, upsample_rate):
  333. # We ensure that:
  334. # (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
  335. # This allows us to have a common grid of resolution s, where we can align
  336. # the input histogram
  337. # start_idx maps min_val to the histogram bin index.
  338. np_min_val = self.min_val.numpy()[0]
  339. np_max_val = self.max_val.numpy()[0]
  340. hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate)
  341. downsample_rate = int(
  342. np.ceil((combined_max - combined_min) / (self.bins * hist_bin_width))
  343. )
  344. e = downsample_rate * (self.bins * hist_bin_width) - (
  345. combined_max - combined_min
  346. )
  347. combined_max = combined_max + e / 2
  348. combined_min = combined_min - e / 2
  349. start_idx = int(np.round((np_min_val - combined_min) / hist_bin_width))
  350. return combined_min, combined_max, downsample_rate, start_idx
  351. def sideeffect_forward(self, x_orig):
  352. x = x_orig.numpy()
  353. min_val = self.min_val.numpy()[0]
  354. max_val = self.max_val.numpy()[0]
  355. histogram = self.histogram.numpy()
  356. new_min = x.min()
  357. new_max = x.max()
  358. if histogram[0] == -1:
  359. new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
  360. else:
  361. new_min = min(new_min, min_val)
  362. new_max = max(new_max, max_val)
  363. # combine the existing histogram and new histogram into 1 histogram
  364. # We do this by first upsampling the histogram to a dense grid
  365. # and then downsampling the histogram efficiently
  366. (new_min, new_max, downsample_rate, start_idx,) = self._adjust_min_max(
  367. new_min, new_max, self.upsample_rate
  368. )
  369. new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
  370. new_histogram = new_histogram.astype(np.float64)
  371. if new_min == min_val and new_max == max_val:
  372. new_histogram += histogram
  373. else:
  374. new_histogram = self._combine_histograms(
  375. new_histogram,
  376. histogram,
  377. self.upsample_rate,
  378. downsample_rate,
  379. start_idx,
  380. self.bins,
  381. )
  382. self.histogram._reset(new_histogram)
  383. self.min_val._reset(new_min)
  384. self.max_val._reset(new_max)
  385. def forward(self, x_orig):
  386. self.sideeffect_forward(x_orig)
  387. return x_orig

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台