|
|
@@ -26,9 +26,10 @@ logger = get_logger(__name__) |
|
|
|
|
|
|
|
class Observer(Module, QParamsModuleMixin): |
|
|
|
r""" |
|
|
|
A base class for Observer Module. |
|
|
|
A base class for Observer Module. Used to record input tensor's statistics for |
|
|
|
quantization. |
|
|
|
|
|
|
|
:param dtype: a string indicating to collect scale and zero_point of which dtype. |
|
|
|
:param dtype: a string indicating which dtype to collect scale and zero_point of. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, dtype: Union[str, QuantDtypeMeta], **kwargs): |
|
|
@@ -72,6 +73,14 @@ class Observer(Module, QParamsModuleMixin): |
|
|
|
|
|
|
|
|
|
|
|
class MinMaxObserver(Observer): |
|
|
|
r""" |
|
|
|
An Observer Module that records the input tensor's running min and max values to calc scale. |
|
|
|
|
|
|
|
:param mode: set quantization mode. |
|
|
|
:param eps: an initial maximum value to avoid the division by zero problem. |
|
|
|
:param dtype: a string indicating which dtype to collect scale and zero_point of. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
mode: QuantMode = QuantMode.SYMMERTIC, |
|
|
@@ -119,6 +128,14 @@ class MinMaxObserver(Observer): |
|
|
|
|
|
|
|
|
|
|
|
class SyncMinMaxObserver(MinMaxObserver): |
|
|
|
r""" |
|
|
|
A distributed version of :class:`~.MinMaxObserver`. |
|
|
|
|
|
|
|
:param mode: set quantization mode. |
|
|
|
:param eps: an initial maximum value to avoid the division by zero problem. |
|
|
|
:param dtype: a string indicating which dtype to collect scale and zero_point of. |
|
|
|
""" |
|
|
|
|
|
|
|
def forward(self, x_orig): |
|
|
|
if self.enable: |
|
|
|
x = x_orig.detach() |
|
|
@@ -134,6 +151,15 @@ class SyncMinMaxObserver(MinMaxObserver): |
|
|
|
|
|
|
|
|
|
|
|
class ExponentialMovingAverageObserver(MinMaxObserver): |
|
|
|
r""" |
|
|
|
A :class:`~.MinMaxObserver` with momentum support for min/max updating. |
|
|
|
|
|
|
|
:param momentum: momentum ratio for min/max updating. |
|
|
|
:param mode: set quantization mode. |
|
|
|
:param eps: an initial maximum value to avoid the division by zero problem. |
|
|
|
:param dtype: a string indicating which dtype to collect scale and zero_point of. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
momentum: float = 0.9, |
|
|
@@ -170,6 +196,15 @@ class ExponentialMovingAverageObserver(MinMaxObserver): |
|
|
|
|
|
|
|
|
|
|
|
class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver): |
|
|
|
r""" |
|
|
|
A distributed version of :class:`~.ExponentialMovingAverageObserver`. |
|
|
|
|
|
|
|
:param momentum: momentum ratio for min/max updating. |
|
|
|
:param mode: set quantization mode. |
|
|
|
:param eps: an initial maximum value to avoid the division by zero problem. |
|
|
|
:param dtype: a string indicating which dtype to collect scale and zero_point of. |
|
|
|
""" |
|
|
|
|
|
|
|
def forward(self, x_orig): |
|
|
|
if self.enabled: |
|
|
|
x = x_orig.detach() |
|
|
@@ -192,6 +227,17 @@ class SyncExponentialMovingAverageObserver(ExponentialMovingAverageObserver): |
|
|
|
|
|
|
|
|
|
|
|
class HistogramObserver(MinMaxObserver): |
|
|
|
r""" |
|
|
|
A :class:`~.MinMaxObserver` using running histogram of tensor values |
|
|
|
for min/max updating. Usually used for calibration quantization. |
|
|
|
|
|
|
|
:param bins: number of bins to use for the histogram. |
|
|
|
:param upsample_rate: the ratio by which to interpolate histograms. |
|
|
|
:param mode: set quantization mode. |
|
|
|
:param eps: an initial maximum value to avoid the division by zero problem. |
|
|
|
:param dtype: a string indicating which dtype to collect scale and zero_point of. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
bins: int = 2048, |
|
|
|