GitOrigin-RevId: ed6af9b98d
tags/v1.3.0
@@ -153,11 +153,11 @@ class ExponentialMovingAverageObserver(MinMaxObserver): | |||||
**kwargs | **kwargs | ||||
): | ): | ||||
super().__init__(mode, eps, dtype, narrow_range, **kwargs) | super().__init__(mode, eps, dtype, narrow_range, **kwargs) | ||||
self.momentum = Tensor(momentum) | |||||
self.momentum = Tensor(momentum, dtype="float32") | |||||
self.runtime_momentum = Tensor(0.0) | self.runtime_momentum = Tensor(0.0) | ||||
def set_momentum(self, momentum): | def set_momentum(self, momentum): | ||||
self.momentum._reset(momentum) | |||||
self.momentum = Tenosr(momentum, dtype="float32") | |||||
def forward(self, x_orig): | def forward(self, x_orig): | ||||
if self.enabled: | if self.enabled: | ||||
@@ -439,9 +439,9 @@ class HistogramObserver(MinMaxObserver): | |||||
self.bins, | self.bins, | ||||
) | ) | ||||
self.histogram._reset(new_histogram) | |||||
self.min_val._reset(new_min) | |||||
self.max_val._reset(new_max) | |||||
self.histogram = Tensor(new_histogram, dtype="float32") | |||||
self.min_val = Tensor(new_min, dtype="float32") | |||||
self.max_val = Tensor(new_max, dtype="float32") | |||||
def forward(self, x_orig): | def forward(self, x_orig): | ||||
self.sideeffect_forward(x_orig) | self.sideeffect_forward(x_orig) | ||||
@@ -8,6 +8,7 @@ import megengine.distributed as dist | |||||
from megengine.distributed.helper import get_device_count_by_fork | from megengine.distributed.helper import get_device_count_by_fork | ||||
from megengine.quantization.observer import ( | from megengine.quantization.observer import ( | ||||
ExponentialMovingAverageObserver, | ExponentialMovingAverageObserver, | ||||
HistogramObserver, | |||||
MinMaxObserver, | MinMaxObserver, | ||||
Observer, | Observer, | ||||
PassiveObserver, | PassiveObserver, | ||||
@@ -44,6 +45,16 @@ def test_exponential_moving_average_observer(): | |||||
np.testing.assert_allclose(m.max_val.numpy(), expected_max) | np.testing.assert_allclose(m.max_val.numpy(), expected_max) | ||||
def test_histogram_observer(): | |||||
x = np.random.rand(3, 3, 3, 3).astype("float32") | |||||
np_min, np_max = x.min(), x.max() | |||||
x = mge.tensor(x) | |||||
m = HistogramObserver() | |||||
m(x) | |||||
np.testing.assert_allclose(m.min_val.numpy(), np_min) | |||||
np.testing.assert_allclose(m.max_val.numpy(), np_max) | |||||
def test_passive_observer(): | def test_passive_observer(): | ||||
q_dict = {"scale": mge.tensor(1.0)} | q_dict = {"scale": mge.tensor(1.0)} | ||||
m = PassiveObserver(q_dict, "qint8") | m = PassiveObserver(q_dict, "qint8") | ||||