|
|
@@ -99,21 +99,9 @@ class MinMaxObserver(Observer): |
|
|
|
def __init__(self, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"):
    """Observer that tracks the running min/max of observed tensors.

    :param mode: observer quantization mode (an ``ObserverMode`` member).
    :param eps: lower bound applied to the computed scale (stored as
        ``scale_limit``) to avoid degenerate/zero scales.
    :param dtype: name of the target quantized dtype, forwarded to the
        base ``Observer``.
    """
    super().__init__(dtype)
    self.mode = mode
    # Initialize min to +float32-max and max to float32-min sentinels so
    # the first observed batch always replaces them.  (The earlier
    # duplicate Buffer(0.0) assignments were dead code — each buffer was
    # immediately overwritten below — and have been removed.)
    self.min_val = Buffer(np.finfo(np.float32).max, dtype=np.float32)
    self.max_val = Buffer(np.finfo(np.float32).min, dtype=np.float32)
    self.scale_limit = eps
    # flag is used by cond_take, first time will be first flag, and after will be set as not_flag
    self.first_flag = Buffer(np.array([1, 0], dtype=np.int32))
    self.not_flag = Buffer(np.array([0, 1], dtype=np.int32))
|
|
|
|
|
|
|
def set_min_max(self, tmp_min, tmp_max):
    """Assign *tmp_min*/*tmp_max* into the running min/max buffers and
    flip ``first_flag`` to ``not_flag`` so subsequent observations take
    the non-first branch of ``cond_take``.
    """
    # FIXME: cond_take destroys the shape, so reshape restores it to (1,)
    updates = (
        (self.min_val, tmp_min.reshape(1)),
        (self.max_val, tmp_max.reshape(1)),
        (self.first_flag, self.not_flag),
    )
    # alpha=0.0, beta=1.0, bias=0.0 makes add_update a plain in-place
    # assignment of the new value into each buffer.
    for buf, new_val in updates:
        F.add_update(buf, new_val, alpha=0.0, beta=1.0, bias=0.0)
|
|
|
|
|
|
|
def _calculate_qparams(self, inp_min_val, inp_max_val): |
|
|
|
min_val = F.minimum(0.0, inp_min_val) |
|
|
@@ -144,13 +132,20 @@ class MinMaxObserver(Observer): |
|
|
|
# stop gradient |
|
|
|
x = F.zero_grad(x_orig) |
|
|
|
# find max and min |
|
|
|
tmp_min, _ = F.cond_take( |
|
|
|
self.first_flag, F.concat([x.min(), F.minimum(self.min_val, x.min())]) |
|
|
|
F.add_update( |
|
|
|
self.min_val, |
|
|
|
F.minimum(self.min_val, x.min()), |
|
|
|
alpha=0.0, |
|
|
|
beta=1.0, |
|
|
|
bias=0.0, |
|
|
|
) |
|
|
|
tmp_max, _ = F.cond_take( |
|
|
|
self.first_flag, F.concat([x.max(), F.maximum(self.max_val, x.max())]) |
|
|
|
F.add_update( |
|
|
|
self.max_val, |
|
|
|
F.maximum(self.max_val, x.max()), |
|
|
|
alpha=0.0, |
|
|
|
beta=1.0, |
|
|
|
bias=0.0, |
|
|
|
) |
|
|
|
self.set_min_max(tmp_min, tmp_max) |
|
|
|
return x_orig |
|
|
|
|
|
|
|
|
|
|
@@ -160,6 +155,7 @@ class ExponentialMovingAverageObserver(MinMaxObserver): |
|
|
|
): |
|
|
|
super().__init__(mode, eps, dtype) |
|
|
|
self.momentum = Buffer(momentum) |
|
|
|
self.runtime_momentum = Buffer(0.0) |
|
|
|
|
|
|
|
def set_momentum(self, momentum):
    """Overwrite the EMA momentum buffer in place.

    :param momentum: new momentum value written into the ``momentum``
        ``Buffer`` via ``set_value``.
    """
    self.momentum.set_value(momentum)
|
|
@@ -169,25 +165,19 @@ class ExponentialMovingAverageObserver(MinMaxObserver): |
|
|
|
# stop gradient |
|
|
|
x = F.zero_grad(x_orig) |
|
|
|
# Exponential Moving Average |
|
|
|
tmp_min, _ = F.cond_take( |
|
|
|
self.first_flag, |
|
|
|
F.concat( |
|
|
|
[ |
|
|
|
x.min(), |
|
|
|
self.momentum * self.min_val + (1 - self.momentum) * x.min(), |
|
|
|
] |
|
|
|
), |
|
|
|
tmp_min = ( |
|
|
|
self.min_val * self.runtime_momentum |
|
|
|
+ (1 - self.runtime_momentum) * x.min() |
|
|
|
) |
|
|
|
tmp_max = ( |
|
|
|
self.max_val * self.runtime_momentum |
|
|
|
+ (1 - self.runtime_momentum) * x.max() |
|
|
|
) |
|
|
|
tmp_max, _ = F.cond_take( |
|
|
|
self.first_flag, |
|
|
|
F.concat( |
|
|
|
[ |
|
|
|
x.max(), |
|
|
|
self.momentum * self.max_val + (1 - self.momentum) * x.max(), |
|
|
|
] |
|
|
|
), |
|
|
|
F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0) |
|
|
|
F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0) |
|
|
|
F.add_update( |
|
|
|
self.runtime_momentum, self.momentum, alpha=0.0, beta=1.0, bias=0.0 |
|
|
|
) |
|
|
|
self.set_min_max(tmp_min, tmp_max) |
|
|
|
return x_orig |
|
|
|
|
|
|
|
|
|
|
|