Browse Source

docs(mge/quantization): add docstring for Observer

GitOrigin-RevId: 043be3886d
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
992a90bbad
1 changed files with 48 additions and 2 deletions
  1. +48
    -2
      imperative/python/megengine/quantization/observer.py

+ 48
- 2
imperative/python/megengine/quantization/observer.py View File

@@ -26,9 +26,10 @@ logger = get_logger(__name__)

class Observer(Module, QParamsModuleMixin):
r"""
A base class for Observer Module.
A base class for Observer Module. Used to record input tensor's statistics for
quantization.

:param dtype: a string indicating to collect scale and zero_point of which dtype.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""

def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
@@ -72,6 +73,14 @@ class Observer(Module, QParamsModuleMixin):


class MinMaxObserver(Observer):
r"""
A Observer Module records input tensor's running min and max values to calc scale.

:param mode: set quantization mode.
:param eps: a initial maximum value to avoid division by zero problem.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""

def __init__(
self,
mode: QuantMode = QuantMode.SYMMERTIC,
@@ -119,6 +128,14 @@ class MinMaxObserver(Observer):


class SyncMinMaxObserver(MinMaxObserver):
r"""
A distributed version of :class:`~.MinMaxObserver`.

:param mode: set quantization mode.
:param eps: a initial maximum value to avoid division by zero problem.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""

def forward(self, x_orig):
if self.enable:
x = x_orig.detach()
@@ -134,6 +151,15 @@ class SyncMinMaxObserver(MinMaxObserver):


class ExponentialMovingAverageObserver(MinMaxObserver):
r"""
A :class:`~.MinMaxObserver` with momentum support for min/max updating.

:param momentum: momentum ratio for min/max updating.
:param mode: set quantization mode.
:param eps: a initial maximum value to avoid division by zero problem.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""

def __init__(
self,
momentum: float = 0.9,
@@ -170,6 +196,15 @@ class ExponentialMovingAverageObserver(MinMaxObserver):


class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
r"""
A distributed version of :class:`~.ExponentialMovingAverageObserver`.

:param momentum: momentum ratio for min/max updating.
:param mode: set quantization mode.
:param eps: a initial maximum value to avoid division by zero problem.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""

def forward(self, x_orig):
if self.enabled:
x = x_orig.detach()
@@ -192,6 +227,17 @@ class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):


class HistogramObserver(MinMaxObserver):
r"""
A :class:`~.MinMaxObserver` using running histogram of tensor values
for min/max updating. Usually used for calibration quantization.

:param bins: number of bins to use for the histogram.
:param upsample_rate: which ratio to interpolate histograms in.
:param mode: set quantization mode.
:param eps: a initial maximum value to avoid division by zero problem.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""

def __init__(
self,
bins: int = 2048,


Loading…
Cancel
Save