diff --git a/imperative/python/megengine/quantization/observer.py b/imperative/python/megengine/quantization/observer.py index 26d5465d..b1c59905 100644 --- a/imperative/python/megengine/quantization/observer.py +++ b/imperative/python/megengine/quantization/observer.py @@ -171,7 +171,7 @@ class HistogramObserver(MinMaxObserver): self.bins = bins self.upsample_rate = upsample_rate self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 - self.histogram = Tensor([-1] + [0.0] * (bins - 1)) + self.histogram = Tensor([-1] + [0.0] * (bins - 1), dtype="float32") def _non_linear_param_search(self): r"""Non-linear parameter search. @@ -304,8 +304,8 @@ class HistogramObserver(MinMaxObserver): start_bin = next_start_bin end_bin = next_end_bin - new_min = self.min_val + bin_width * start_bin - new_max = self.min_val + bin_width * (end_bin + 1) + new_min = self.min_val + Tensor(bin_width * start_bin, dtype=np.float32) + new_max = self.min_val + Tensor(bin_width * (end_bin + 1), dtype=np.float32) return new_min, new_max def get_qparams(self):