Browse Source

docs(mge/quantization): add docstring for Observer

GitOrigin-RevId: 043be3886d
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
992a90bbad
1 changed file with 48 additions and 2 deletions
  1. +48
    -2
      imperative/python/megengine/quantization/observer.py

+ 48
- 2
imperative/python/megengine/quantization/observer.py View File

@@ -26,9 +26,10 @@ logger = get_logger(__name__)


class Observer(Module, QParamsModuleMixin): class Observer(Module, QParamsModuleMixin):
r""" r"""
A base class for Observer Module.
A base class for Observer Module. Used to record input tensor's statistics for
quantization.


:param dtype: a string indicating to collect scale and zero_point of which dtype.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
""" """


def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs): def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs):
@@ -72,6 +73,14 @@ class Observer(Module, QParamsModuleMixin):




class MinMaxObserver(Observer): class MinMaxObserver(Observer):
r"""
An Observer Module that records input tensor's running min and max values to calculate scale.

:param mode: set quantization mode.
:param eps: an initial maximum value to avoid the division-by-zero problem.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""

def __init__( def __init__(
self, self,
mode: QuantMode = QuantMode.SYMMERTIC, mode: QuantMode = QuantMode.SYMMERTIC,
@@ -119,6 +128,14 @@ class MinMaxObserver(Observer):




class SyncMinMaxObserver(MinMaxObserver): class SyncMinMaxObserver(MinMaxObserver):
r"""
A distributed version of :class:`~.MinMaxObserver`.

:param mode: set quantization mode.
:param eps: an initial maximum value to avoid the division-by-zero problem.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""

def forward(self, x_orig): def forward(self, x_orig):
if self.enable: if self.enable:
x = x_orig.detach() x = x_orig.detach()
@@ -134,6 +151,15 @@ class SyncMinMaxObserver(MinMaxObserver):




class ExponentialMovingAverageObserver(MinMaxObserver): class ExponentialMovingAverageObserver(MinMaxObserver):
r"""
A :class:`~.MinMaxObserver` with momentum support for min/max updating.

:param momentum: momentum ratio for min/max updating.
:param mode: set quantization mode.
:param eps: an initial maximum value to avoid the division-by-zero problem.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""

def __init__( def __init__(
self, self,
momentum: float = 0.9, momentum: float = 0.9,
@@ -170,6 +196,15 @@ class ExponentialMovingAverageObserver(MinMaxObserver):




class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver): class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):
r"""
A distributed version of :class:`~.ExponentialMovingAverageObserver`.

:param momentum: momentum ratio for min/max updating.
:param mode: set quantization mode.
:param eps: an initial maximum value to avoid the division-by-zero problem.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""

def forward(self, x_orig): def forward(self, x_orig):
if self.enabled: if self.enabled:
x = x_orig.detach() x = x_orig.detach()
@@ -192,6 +227,17 @@ class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver):




class HistogramObserver(MinMaxObserver): class HistogramObserver(MinMaxObserver):
r"""
A :class:`~.MinMaxObserver` using running histogram of tensor values
for min/max updating. Usually used for calibration quantization.

:param bins: number of bins to use for the histogram.
:param upsample_rate: the ratio at which to interpolate histograms.
:param mode: set quantization mode.
:param eps: an initial maximum value to avoid the division-by-zero problem.
:param dtype: a string indicating which dtype to collect scale and zero_point of.
"""

def __init__( def __init__(
self, self,
bins: int = 2048, bins: int = 2048,


Loading…
Cancel
Save