|
- import math
- from abc import abstractmethod
- from copy import deepcopy
- from typing import Union
-
- import numpy as np
-
- from .. import functional as F
- from ..core.tensor.dtype import QuantDtypeMeta, _builtin_quant_dtypes
- from ..distributed import WORLD, get_rank, is_distributed
- from ..functional.distributed import all_reduce_max, all_reduce_min
- from ..logger import get_logger
- from ..module import Module
- from ..tensor import Tensor
- from .utils import QParams, QParamsModuleMixin, QuantMode, create_qparams
-
- logger = get_logger(__name__)
-
-
- class Observer(Module, QParamsModuleMixin):
- r"""A base class for Observer Module. Used to record input tensor's statistics for
- quantization.
-
- Args:
- dtype: a string indicating which dtype to collect scale and zero_point of.
- """
-
- def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
- super().__init__()
- if isinstance(dtype, str):
- if not dtype in _builtin_quant_dtypes:
- raise ValueError(
- "unknown dtype: {}, only support {}".format(
- dtype, _builtin_quant_dtypes.keys()
- )
- )
- dtype = _builtin_quant_dtypes[dtype]
- if "narrow_range" in kwargs:
- del kwargs["narrow_range"]
- logger.warning(
- "FakeQuantize currently has no narrow_range param "
- "so it is ignored here",
- exc_info=DeprecationWarning,
- )
- self.dtype = dtype
- self.qmin = dtype.qmin
- self.qmax = dtype.qmax
- self.enabled = True
-
- def enable(self):
- self.enabled = True
-
- def disable(self):
- self.enabled = False
-
- def train(self, mode: bool = True, recursive: bool = True) -> None:
- super().train(mode, recursive)
- if mode:
- self.enable()
- else:
- self.disable()
-
- @abstractmethod
- def forward(self, x):
- pass
-
-
- class MinMaxObserver(Observer):
- r"""A Observer Module records input tensor's running min and max values to calc scale.
-
- Args:
- mode: set quantization mode.
- eps: a initial maximum value to avoid division by zero problem.
- dtype: a string indicating which dtype to collect scale and zero_point of.
- """
-
- def __init__(
- self,
- mode: QuantMode = QuantMode.SYMMERTIC,
- eps: float = 0.00001,
- dtype: Union[str, QuantDtypeMeta] = "qint8",
- **kwargs
- ):
- super().__init__(dtype, **kwargs)
- self.mode = mode
- self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32)
- self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32)
- self.scale_limit = eps
-
- def _calculate_qparams(self, inp_min_val, inp_max_val):
- min_val = F.minimum(0.0, inp_min_val)
- max_val = F.maximum(0.0, inp_max_val)
- if self.mode == QuantMode.SYMMERTIC:
- symmetric_max_vals = F.maximum(-min_val, max_val)
- # use maximun to avoid scale too small at the begin
- scale = F.maximum(
- symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
- )
- zero_point = None
- else:
- # use maximun to avoid scale too small at the begin
- scale = F.maximum(
- (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit
- )
- # caculate zero_point
- zero_point = self.qmin - F.round((min_val / scale))
-
- return create_qparams(self.mode, self.dtype, scale=scale, zero_point=zero_point)
-
- def get_qparams(self):
- return self._calculate_qparams(self.min_val, self.max_val)
-
- def forward(self, x_orig):
- if self.enabled:
- # stop gradient
- x = x_orig.detach()
- # find max and min
- self.min_val[...] = F.minimum(self.min_val, x.min())
- self.max_val[...] = F.maximum(self.max_val, x.max())
- return x_orig
-
-
- class SyncMinMaxObserver(MinMaxObserver):
- r"""A distributed version of :class:`~.MinMaxObserver`.
-
- Args:
- mode: set quantization mode.
- eps: a initial maximum value to avoid division by zero problem.
- dtype: a string indicating which dtype to collect scale and zero_point of.
- """
-
- def forward(self, x_orig):
- if self.enable:
- x = x_orig.detach()
- if is_distributed():
- min_x = all_reduce_min(x.min(), WORLD)
- max_x = all_reduce_max(x.max(), WORLD)
- else:
- min_x = x.min()
- max_x = x.max()
- self.min_val[...] = F.minimum(self.min_val, min_x)
- self.max_val[...] = F.maximum(self.max_val, max_x)
- return x_orig
-
-
- class ExponentialMovingAverageObserver(MinMaxObserver):
- r"""A :class:`~.MinMaxObserver` with momentum support for min/max updating.
-
- Args:
- momentum: momentum ratio for min/max updating.
- mode: set quantization mode.
- eps: a initial maximum value to avoid division by zero problem.
- dtype: a string indicating which dtype to collect scale and zero_point of.
- """
-
- def __init__(
- self,
- momentum: float = 0.9,
- mode: QuantMode = QuantMode.SYMMERTIC,
- eps: float = 0.00001,
- dtype: Union[str, QuantDtypeMeta] = "qint8",
- **kwargs
- ):
- super().__init__(mode, eps, dtype, **kwargs)
- self.momentum = Tensor(momentum, dtype="float32")
- # used to avoid if-clauses in the first forward which is not supported
- # in trace mode.
- self.runtime_momentum = Tensor(0.0)
-
- def set_momentum(self, momentum):
- self.momentum = Tensor(momentum, dtype="float32")
-
- def forward(self, x_orig):
- if self.enabled:
- # stop gradient
- x = x_orig.detach()
- # Exponential Moving Average
- self.min_val[...] = (
- self.min_val * self.runtime_momentum
- + (1 - self.runtime_momentum) * x.min()
- )
- self.max_val[...] = (
- self.max_val * self.runtime_momentum
- + (1 - self.runtime_momentum) * x.max()
- )
- self.runtime_momentum[...] = self.momentum
-
- return x_orig
-
-
- class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
- r"""A distributed version of :class:`~.ExponentialMovingAverageObserver`.
-
- Args:
- momentum: momentum ratio for min/max updating.
- mode: set quantization mode.
- eps: a initial maximum value to avoid division by zero problem.
- dtype: a string indicating which dtype to collect scale and zero_point of.
- """
-
- def forward(self, x_orig):
- if self.enabled:
- x = x_orig.detach()
- if is_distributed():
- min_x = all_reduce_min(x.min(), WORLD)
- max_x = all_reduce_max(x.max(), WORLD)
- else:
- min_x = x.min()
- max_x = x.max()
- self.min_val[...] = (
- self.min_val * self.runtime_momentum
- + (1 - self.runtime_momentum) * min_x
- )
- self.max_val[...] = (
- self.max_val * self.runtime_momentum
- + (1 - self.runtime_momentum) * max_x
- )
- self.runtime_momentum[...] = self.momentum
- return x_orig
-
-
- class HistogramObserver(MinMaxObserver):
- r"""A :class:`~.MinMaxObserver` using running histogram of tensor values
- for min/max updating. Usually used for calibration quantization.
-
- Args:
- bins: number of bins to use for the histogram.
- upsample_rate: which ratio to interpolate histograms in.
- mode: set quantization mode.
- eps: a initial maximum value to avoid division by zero problem.
- dtype: a string indicating which dtype to collect scale and zero_point of.
- """
-
- def __init__(
- self,
- bins: int = 2048,
- upsample_rate: int = 128,
- mode: QuantMode = QuantMode.SYMMERTIC,
- eps: float = 0.00001,
- dtype: Union[str, QuantDtypeMeta] = "qint8",
- **kwargs
- ):
- super().__init__(mode, eps, dtype, **kwargs)
- self.bins = bins
- self.upsample_rate = upsample_rate
- self.dst_nbins = (
- _builtin_quant_dtypes[dtype].qmax - _builtin_quant_dtypes[dtype].qmin + 1
- )
- self.histogram = Tensor([-1] + [0.0] * (bins - 1), dtype="float32")
-
- def _non_linear_param_search(self):
- r"""Non-linear parameter search.
- An approximation for L2 error minimization for selecting min/max.
- By selecting new min/max, we filter out outliers in input distribution.
- """
-
- np_min_val = self.min_val.numpy()
- np_max_val = self.max_val.numpy()
- np_histogram = self.histogram.numpy()
- assert len(np_histogram) == self.bins, "bins mistmatch"
- bin_width = (np_max_val - np_min_val) / self.bins
-
- def _get_norm(delta_begin, delta_end, density, norm_type):
- r"""Compute the norm of the values uniformaly distributed between
- delta_begin and delta_end.
- norm = density * (integral_{begin, end} x^2)
- = density * (end^3 - begin^3) / 3
- """
- assert norm_type == "L2", "Only L2 norms are currently supported"
- norm = 0.0
- if norm_type == "L2":
- norm = (
- delta_end * delta_end * delta_end
- - delta_begin * delta_begin * delta_begin
- ) / 3
- return density * norm
-
- def _compute_quantization_error(next_start_bin, next_end_bin, norm_type):
- r"""Compute the quantization error if we use start_bin to end_bin as the
- min and max to do the quantization.
- """
-
- norm = 0.0
- dst_bin_width = (
- bin_width * (next_end_bin - next_start_bin + 1) / self.dst_nbins
- )
- if dst_bin_width == 0.0:
- return 0.0
- for src_bin in range(self.bins):
- # distances from the beginning of first dst_bin to the beginning and
- # end of src_bin
- src_bin_begin = (src_bin - next_start_bin) * bin_width
- src_bin_end = src_bin_begin + bin_width
-
- # which dst_bins the beginning and end of src_bin belong to?
- dst_bin_of_begin = min(
- self.dst_nbins - 1,
- max(0.0, math.floor(src_bin_begin / dst_bin_width)),
- )
- dst_bin_of_end = min(
- self.dst_nbins - 1,
- max(0.0, math.floor(src_bin_end / dst_bin_width)),
- )
- dst_bin_of_begin_center = (
- dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
- )
-
- density = np_histogram[src_bin] / bin_width
- if dst_bin_of_begin == dst_bin_of_end:
- # if src_bin is entirely within 1 dst_bin
- delta_begin = src_bin_begin - dst_bin_of_begin_center
- delta_end = src_bin_end - dst_bin_of_begin_center
- norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
- else:
- delta_begin = src_bin_begin - dst_bin_of_begin_center
- delta_end = dst_bin_width / 2
- norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
-
- norm = norm + (dst_bin_of_end - dst_bin_of_begin - 1) * _get_norm(
- -dst_bin_width / 2, dst_bin_width / 2, density, norm_type
- )
-
- dst_bin_of_end_center = (
- dst_bin_of_end * dst_bin_width + dst_bin_width / 2
- )
-
- delta_begin = -dst_bin_width / 2
- delta_end = src_bin_end - dst_bin_of_end_center
- norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
- return norm
-
- # cumulative sum
- total = sum(np_histogram)
- cSum = np.cumsum(np_histogram, axis=0)
-
- stepsize = 1e-5 # granularity
- alpha = 0.0 # lower bound
- beta = 1.0 # upper bound
- start_bin = 0
- end_bin = self.bins - 1
- norm_min = float("inf")
-
- while alpha < beta:
- # Find the next step
- next_alpha = alpha + stepsize
- next_beta = beta - stepsize
-
- # find the left and right bins between the quantile bounds
- l = start_bin
- r = end_bin
- while l < end_bin and cSum[l] < next_alpha * total:
- l = l + 1
- while r > start_bin and cSum[r] > next_beta * total:
- r = r - 1
-
- # decide the next move
- next_start_bin = start_bin
- next_end_bin = end_bin
- if (l - start_bin) > (end_bin - r):
- # move the start bin
- next_start_bin = l
- alpha = next_alpha
- else:
- # move the end bin
- next_end_bin = r
- beta = next_beta
-
- if next_start_bin == start_bin and next_end_bin == end_bin:
- continue
-
- # calculate the quantization error using next_start_bin and next_end_bin
- norm = _compute_quantization_error(next_start_bin, next_end_bin, "L2")
-
- if norm > norm_min:
- break
- norm_min = norm
- start_bin = next_start_bin
- end_bin = next_end_bin
-
- new_min = self.min_val + Tensor(bin_width * start_bin, dtype=np.float32)
- new_max = self.min_val + Tensor(bin_width * (end_bin + 1), dtype=np.float32)
- return new_min, new_max
-
- def get_qparams(self):
- new_min, new_max = self._non_linear_param_search()
- return self._calculate_qparams(new_min, new_max)
-
- def _combine_histograms(
- self, orig_hist, new_hist, upsample_rate, downsample_rate, start_idx, Nbins
- ):
- # First up-sample the histogram with new data by a factor of L
- # This creates an approximate probability density thats piecwise constant
- upsampled_histogram = new_hist.repeat(upsample_rate)
- # Now insert the upsampled histogram into the output
- # histogram, which is initialized with zeros.
- # The offset at which the histogram is introduced is determined
- # by the start index as the output histogram can cover a wider range
- histogram_with_output_range = np.zeros((Nbins * downsample_rate))
- histogram_with_output_range[
- start_idx : Nbins * upsample_rate + start_idx
- ] = upsampled_histogram
- # Compute integral histogram, double precision is needed to ensure
- # that there are no overflows
- integral_histogram = np.cumsum(histogram_with_output_range, 0)[
- downsample_rate - 1 :: downsample_rate
- ]
- # Finally perform interpolation
- shifted_integral_histogram = np.zeros((Nbins))
- shifted_integral_histogram[1:Nbins] = integral_histogram[0:-1]
- interpolated_histogram = (
- integral_histogram - shifted_integral_histogram
- ) / upsample_rate
- orig_hist = orig_hist + interpolated_histogram
- return orig_hist
-
- def _adjust_min_max(self, combined_min, combined_max, upsample_rate):
- # We ensure that:
- # (combined_max - combined_min)/(downsample_rate*Nbins) = (max - min)/(upsample_rate*Nbins)
- # This allows us to have a common grid of resolution s, where we can align
- # the input histogram
- # start_idx maps min_val to the histogram bin index.
- np_min_val = self.min_val.numpy()
- np_max_val = self.max_val.numpy()
-
- hist_bin_width = (np_max_val - np_min_val) / (self.bins * upsample_rate)
- downsample_rate = int(
- np.ceil((combined_max - combined_min) / (self.bins * hist_bin_width))
- )
- e = downsample_rate * (self.bins * hist_bin_width) - (
- combined_max - combined_min
- )
- combined_max = combined_max + e / 2
- combined_min = combined_min - e / 2
- start_idx = int(np.round((np_min_val - combined_min) / hist_bin_width))
-
- return combined_min, combined_max, downsample_rate, start_idx
-
- def sideeffect_forward(self, x_orig):
- x = x_orig.numpy()
- min_val = self.min_val.numpy()
- max_val = self.max_val.numpy()
- histogram = self.histogram.numpy()
- new_min = x.min()
- new_max = x.max()
- if histogram[0] == -1:
- new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
- else:
- new_min = min(new_min, min_val)
- new_max = max(new_max, max_val)
- # combine the existing histogram and new histogram into 1 histogram
- # We do this by first upsampling the histogram to a dense grid
- # and then downsampling the histogram efficiently
- (new_min, new_max, downsample_rate, start_idx) = self._adjust_min_max(
- new_min, new_max, self.upsample_rate
- )
-
- new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
- new_histogram = new_histogram.astype(np.float64)
- if new_min == min_val and new_max == max_val:
- new_histogram += histogram
- else:
- new_histogram = self._combine_histograms(
- new_histogram,
- histogram,
- self.upsample_rate,
- downsample_rate,
- start_idx,
- self.bins,
- )
-
- self.histogram = Tensor(new_histogram, dtype="float32")
- self.min_val = Tensor(new_min, dtype="float32")
- self.max_val = Tensor(new_max, dtype="float32")
-
- def forward(self, x_orig):
- self.sideeffect_forward(x_orig)
- return x_orig
-
-
- class PassiveObserver(Observer):
- r"""An Observer that supports setting :attr:`scale` directly."""
-
- def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
- super().__init__(dtype, **kwargs)
- self.qparams = None
- self.orig_scale = None
-
- @property
- def scale(self):
- return self.qparams.scale
-
- @scale.setter
- def scale(self, value: np.ndarray):
- assert np.all(value > 0)
- self.qparams.scale[...] = Tensor(value)
-
- def get_qparams(self):
- return self.qparams
-
- def set_qparams(self, qparams: QParams):
- r"""set the ``qparams``.
-
- Args:
- qparams: used to set initial scale.
- """
- self.qparams = deepcopy(qparams)
- if qparams.scale is None:
- raise AssertionError("Can not get an initialized scale")
- if qparams.dtype_meta is None:
- qparams.dtype_meta = self.dtype
- else:
- assert (
- qparams.dtype_meta is self.dtype
- ), "input qparams' dtype is not equal to self.dtype.\nqparams.dtype_meta={}\nself.dtype={}".format(
- qparams.dtype_meta, self.dtype
- )
- self.orig_scale = qparams.scale.numpy()
-
- def forward(self, x):
- r"""Just return input because :attr:`qparams` is set by :func:`~.apply_easy_quant`."""
- return x
|