|
|
@@ -7,6 +7,7 @@ |
|
|
|
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
import math |
|
|
|
from abc import abstractmethod |
|
|
|
from enum import Enum |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
@@ -46,9 +47,11 @@ class Observer(Module): |
|
|
|
self.enabled = True |
|
|
|
|
|
|
|
def get_dtype(self):
    """Build the quantized dtype described by this observer's current qparams.

    Reads the dict returned by ``get_qparams()`` and converts its ``"scale"``
    and ``"zero_point"`` entries to scalars with ``.numpy()[0]``. An entry that
    is absent or ``None`` is passed on as ``None``.

    :return: the result of ``get_quantized_dtype(self.dtype, scale, zero_point)``.
    """
    q_dict = self.get_qparams()

    # ``.get`` covers both a missing key and a key still holding ``None``, so
    # ``.numpy()`` is never called on ``None``.
    scale = q_dict.get("scale")
    zero_point = q_dict.get("zero_point")

    numpy_scale = None if scale is None else scale.numpy()[0]
    numpy_zero_point = None if zero_point is None else zero_point.numpy()[0]
    return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)
|
|
|
|
|
|
|
def enable(self): |
|
|
@@ -73,13 +76,29 @@ class Observer(Module): |
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
class ObserverMode(Enum):
    """Quantization scheme selected for an observer.

    The historical member names are misspelled; they remain the canonical
    members so existing callers keep working. The correctly spelled names are
    enum aliases: they resolve to the same member objects and do not appear
    when iterating over the enum.
    """

    SYMMERTIC = 1
    ASYMMERTIC = 2

    # Correctly spelled aliases (same values, hence the same members).
    SYMMETRIC = 1
    ASYMMETRIC = 2
|
|
|
|
|
|
|
|
|
|
|
def create_observer_dict(mode):
    """Return an empty qparams dict for the given observer mode.

    :param mode: compared against ``ObserverMode.SYMMERTIC``; any other value
        yields the asymmetric layout.
    :return: ``{"mode", "scale"}`` for the symmetric mode, otherwise
        ``{"mode", "scale", "zero_point"}``. All parameter slots start as
        ``None``.
    """
    is_symmetric = mode == ObserverMode.SYMMERTIC
    q_dict = {
        "mode": ObserverMode.SYMMERTIC if is_symmetric else ObserverMode.ASYMMERTIC,
        "scale": None,
    }
    if not is_symmetric:
        # Only the asymmetric layout carries a zero-point slot.
        q_dict["zero_point"] = None
    return q_dict
|
|
|
|
|
|
|
|
|
|
|
class MinMaxObserver(Observer): |
|
|
|
def __init__(self, symmetric=True, eps=0.00001, *args, **kwargs): |
|
|
|
super().__init__(*args, **kwargs) |
|
|
|
self.symmetric = symmetric |
|
|
|
if self.symmetric: |
|
|
|
# assert qmin + qmax == -1, 'when reduce_range, qmin + qmax should equal -1' |
|
|
|
self.zero_point = tensor((self.qmin + self.qmax + 1) // 2) |
|
|
|
def __init__(self, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"): |
|
|
|
super().__init__(dtype) |
|
|
|
self.mode = mode |
|
|
|
|
|
|
|
self.min_val = Buffer(0.0, dtype=np.float32) |
|
|
|
self.max_val = Buffer(0.0, dtype=np.float32) |
|
|
@@ -99,22 +118,23 @@ class MinMaxObserver(Observer): |
|
|
|
def _calculate_qparams(self, inp_min_val, inp_max_val):
    """Compute quantization parameters from an observed value range.

    :param inp_min_val: observed minimum.
    :param inp_max_val: observed maximum.
    :return: the dict made by ``create_observer_dict(self.mode)`` with
        ``"scale"`` filled in, plus ``"zero_point"`` when the mode is not
        ``ObserverMode.SYMMERTIC``.
    """
    # Combine each bound with 0.0 before deriving the scale.
    min_val = F.minimum(0.0, inp_min_val)
    max_val = F.maximum(0.0, inp_max_val)

    q_dict = create_observer_dict(self.mode)
    if self.mode == ObserverMode.SYMMERTIC:
        symmetric_max_vals = F.maximum(-min_val, max_val)
        # Use maximum with ``scale_limit`` to avoid a too-small scale at the
        # beginning. NOTE(review): ``scale_limit`` is assigned outside the
        # visible code; confirm it is set in ``__init__``.
        q_dict["scale"] = F.maximum(
            symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
        )
    else:
        # Use maximum with ``scale_limit`` to avoid a too-small scale at the
        # beginning.
        scale = F.maximum(
            (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit
        )
        q_dict["scale"] = scale
        # Calculate zero_point from the scale computed just above.
        q_dict["zero_point"] = self.qmin - Round()(min_val / scale)

    return q_dict
|
|
|
|
|
|
|
def get_qparams(self):
    """Return the qparams computed from the stored ``min_val`` and ``max_val``.

    :return: whatever ``_calculate_qparams`` returns for the stored range.
    """
    compute = self._calculate_qparams
    return compute(self.min_val, self.max_val)
|
|
@@ -135,8 +155,10 @@ class MinMaxObserver(Observer): |
|
|
|
|
|
|
|
|
|
|
|
class ExponentialMovingAverageObserver(MinMaxObserver): |
|
|
|
def __init__(
    self, momentum=0.9, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"
):
    """Initialize the observer and store its momentum.

    :param momentum: value wrapped in a ``Buffer`` and kept as ``self.momentum``.
    :param mode: forwarded to the parent initializer.
    :param eps: forwarded to the parent initializer.
    :param dtype: forwarded to the parent initializer.
    """
    super().__init__(mode, eps, dtype)
    self.momentum = Buffer(momentum)
|
|
|
|
|
|
|
def set_momentum(self, momentum): |
|
|
@@ -170,8 +192,15 @@ class ExponentialMovingAverageObserver(MinMaxObserver): |
|
|
|
|
|
|
|
|
|
|
|
class HistogramObserver(MinMaxObserver): |
|
|
|
def __init__(self, bins=2048, upsample_rate=128, dtype="qint8", *args, **kwargs): |
|
|
|
super().__init__(dtype=dtype, *args, **kwargs) |
|
|
|
def __init__( |
|
|
|
self, |
|
|
|
bins=2048, |
|
|
|
upsample_rate=128, |
|
|
|
dtype="qint8", |
|
|
|
mode=ObserverMode.SYMMERTIC, |
|
|
|
eps=0.00001, |
|
|
|
): |
|
|
|
super().__init__(mode, eps, dtype) |
|
|
|
self.bins = bins |
|
|
|
self.upsample_rate = upsample_rate |
|
|
|
self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 |
|
|
|