|
|
@@ -153,11 +153,11 @@ class ExponentialMovingAverageObserver(MinMaxObserver): |
|
|
|
**kwargs |
|
|
|
): |
|
|
|
super().__init__(mode, eps, dtype, narrow_range, **kwargs) |
|
|
|
self.momentum = Tensor(momentum) |
|
|
|
self.momentum = Tensor(momentum, dtype="float32") |
|
|
|
self.runtime_momentum = Tensor(0.0) |
|
|
|
|
|
|
|
def set_momentum(self, momentum): |
|
|
|
self.momentum._reset(momentum) |
|
|
|
self.momentum = Tenosr(momentum, dtype="float32") |
|
|
|
|
|
|
|
def forward(self, x_orig): |
|
|
|
if self.enabled: |
|
|
@@ -439,9 +439,9 @@ class HistogramObserver(MinMaxObserver): |
|
|
|
self.bins, |
|
|
|
) |
|
|
|
|
|
|
|
self.histogram._reset(new_histogram) |
|
|
|
self.min_val._reset(new_min) |
|
|
|
self.max_val._reset(new_max) |
|
|
|
self.histogram = Tensor(new_histogram, dtype="float32") |
|
|
|
self.min_val = Tensor(new_min, dtype="float32") |
|
|
|
self.max_val = Tensor(new_max, dtype="float32") |
|
|
|
|
|
|
|
def forward(self, x_orig): |
|
|
|
self.sideeffect_forward(x_orig) |
|
|
|