Browse Source

fix(mge/quantization): replace `_reset` with "=" in observer

GitOrigin-RevId: ed6af9b98d
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
b7ed0cb850
2 changed files with 16 additions and 5 deletions
  1. +5
    -5
      imperative/python/megengine/quantization/observer.py
  2. +11
    -0
      imperative/python/test/unit/quantization/test_observer.py

+ 5
- 5
imperative/python/megengine/quantization/observer.py View File

@@ -153,11 +153,11 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
**kwargs **kwargs
): ):
super().__init__(mode, eps, dtype, narrow_range, **kwargs) super().__init__(mode, eps, dtype, narrow_range, **kwargs)
self.momentum = Tensor(momentum)
self.momentum = Tensor(momentum, dtype="float32")
self.runtime_momentum = Tensor(0.0) self.runtime_momentum = Tensor(0.0)


def set_momentum(self, momentum): def set_momentum(self, momentum):
self.momentum._reset(momentum)
self.momentum = Tenosr(momentum, dtype="float32")


def forward(self, x_orig): def forward(self, x_orig):
if self.enabled: if self.enabled:
@@ -439,9 +439,9 @@ class HistogramObserver(MinMaxObserver):
self.bins, self.bins,
) )


self.histogram._reset(new_histogram)
self.min_val._reset(new_min)
self.max_val._reset(new_max)
self.histogram = Tensor(new_histogram, dtype="float32")
self.min_val = Tensor(new_min, dtype="float32")
self.max_val = Tensor(new_max, dtype="float32")


def forward(self, x_orig): def forward(self, x_orig):
self.sideeffect_forward(x_orig) self.sideeffect_forward(x_orig)


+ 11
- 0
imperative/python/test/unit/quantization/test_observer.py View File

@@ -8,6 +8,7 @@ import megengine.distributed as dist
from megengine.distributed.helper import get_device_count_by_fork from megengine.distributed.helper import get_device_count_by_fork
from megengine.quantization.observer import ( from megengine.quantization.observer import (
ExponentialMovingAverageObserver, ExponentialMovingAverageObserver,
HistogramObserver,
MinMaxObserver, MinMaxObserver,
Observer, Observer,
PassiveObserver, PassiveObserver,
@@ -44,6 +45,16 @@ def test_exponential_moving_average_observer():
np.testing.assert_allclose(m.max_val.numpy(), expected_max) np.testing.assert_allclose(m.max_val.numpy(), expected_max)




def test_histogram_observer():
x = np.random.rand(3, 3, 3, 3).astype("float32")
np_min, np_max = x.min(), x.max()
x = mge.tensor(x)
m = HistogramObserver()
m(x)
np.testing.assert_allclose(m.min_val.numpy(), np_min)
np.testing.assert_allclose(m.max_val.numpy(), np_max)


def test_passive_observer(): def test_passive_observer():
q_dict = {"scale": mge.tensor(1.0)} q_dict = {"scale": mge.tensor(1.0)}
m = PassiveObserver(q_dict, "qint8") m = PassiveObserver(q_dict, "qint8")


Loading…
Cancel
Save