|
|
@@ -41,7 +41,6 @@ class Observer(Module): |
|
|
|
self.dtype = dtype |
|
|
|
self.qmin = _metadata_dict[dtype].qmin |
|
|
|
self.qmax = _metadata_dict[dtype].qmax |
|
|
|
self.zero_point, self.scale = None, None |
|
|
|
self.enabled = True |
|
|
|
|
|
|
|
def get_dtype(self): |
|
|
@@ -72,23 +71,6 @@ class Observer(Module): |
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
class IdentityObserver(Observer): |
|
|
|
r""" |
|
|
|
An test Observer that always return scale:1 and zero_point:0. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, *args, **kwargs): |
|
|
|
super().__init__(*args, **kwargs) |
|
|
|
self.zero_point = ones((1), dtype="float32") |
|
|
|
self.scale = zeros((1), dtype="float32") |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
return x |
|
|
|
|
|
|
|
def get_qparams(self): |
|
|
|
return self.scale, self.zero_point |
|
|
|
|
|
|
|
|
|
|
|
class MinMaxObserver(Observer): |
|
|
|
def __init__(self, symmetric=True, eps=0.00001, *args, **kwargs): |
|
|
|
super().__init__(*args, **kwargs) |
|
|
@@ -108,47 +90,28 @@ class MinMaxObserver(Observer): |
|
|
|
# FIXME: cond_take will destory shape, use reshape to reset shape |
|
|
|
tmp_min = tmp_min.reshape(1) |
|
|
|
tmp_max = tmp_max.reshape(1) |
|
|
|
if self.training: |
|
|
|
F.zero_grad( |
|
|
|
F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0) |
|
|
|
) |
|
|
|
F.zero_grad( |
|
|
|
F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0) |
|
|
|
) |
|
|
|
F.zero_grad( |
|
|
|
F.add_update( |
|
|
|
self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0 |
|
|
|
) |
|
|
|
) |
|
|
|
F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0) |
|
|
|
F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0) |
|
|
|
F.add_update(self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0) |
|
|
|
|
|
|
|
# FIXME: add_update is applied after the whole trace procedure in `symbolic=True` |
|
|
|
# mode. So use tmp_min/tmp_max to calc and save scale/zero_point for further |
|
|
|
# calculation in FakeQuant. |
|
|
|
self.set_scale_zero_point(tmp_min, tmp_max) |
|
|
|
|
|
|
|
def set_scale_zero_point(self, tmp_min, tmp_max): |
|
|
|
def get_qparams(self): |
|
|
|
if self.symmetric: |
|
|
|
symmetric_max_vals = F.maximum(-tmp_min, tmp_max) |
|
|
|
symmetric_max_vals = F.maximum(-self.min_val, self.max_val) |
|
|
|
# use maximun to avoid scale too small at the begin |
|
|
|
self.scale = F.maximum( |
|
|
|
scale = F.maximum( |
|
|
|
symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit |
|
|
|
) |
|
|
|
# zero_point = self.zero_point |
|
|
|
zero_point = self.zero_point |
|
|
|
else: |
|
|
|
# use maximun to avoid scale too small at the begin |
|
|
|
self.scale = F.maximum( |
|
|
|
(tmp_max - tmp_min) / (self.qmax - self.qmin), self.scale_limit |
|
|
|
scale = F.maximum( |
|
|
|
(self.max_val - self.min_val) / (self.qmax - self.qmin), |
|
|
|
self.scale_limit, |
|
|
|
) |
|
|
|
# caculate zero_point |
|
|
|
self.zero_point = self.qmin - Round()((tmp_min / self.scale)) |
|
|
|
|
|
|
|
def get_qparams(self): |
|
|
|
# scale and zero_point is runtime tensor rather than Buffer, |
|
|
|
# so need to re-calc if min_val and max_val are loaded. |
|
|
|
if self.scale is None: |
|
|
|
self.set_scale_zero_point(self.min_val, self.max_val) |
|
|
|
zero_point = self.qmin - Round()((self.min_val / scale)) |
|
|
|
|
|
|
|
return self.scale, self.zero_point |
|
|
|
return scale, zero_point |
|
|
|
|
|
|
|
def forward(self, x_orig): |
|
|
|
if self.enabled: |
|
|
|