From d8ac6c70a190f1df396c82f65c21ecb73537a168 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 21 Oct 2020 16:45:37 +0800 Subject: [PATCH] fix(quantize): fix quantize calibration dtype issue GitOrigin-RevId: 667c99f469054134efd006aa8f54fca22c4c85b6 --- imperative/python/megengine/quantization/observer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/imperative/python/megengine/quantization/observer.py b/imperative/python/megengine/quantization/observer.py index 26d5465d..b1c59905 100644 --- a/imperative/python/megengine/quantization/observer.py +++ b/imperative/python/megengine/quantization/observer.py @@ -171,7 +171,7 @@ class HistogramObserver(MinMaxObserver): self.bins = bins self.upsample_rate = upsample_rate self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 - self.histogram = Tensor([-1] + [0.0] * (bins - 1)) + self.histogram = Tensor([-1] + [0.0] * (bins - 1), dtype="float32") def _non_linear_param_search(self): r"""Non-linear parameter search. @@ -304,8 +304,8 @@ class HistogramObserver(MinMaxObserver): start_bin = next_start_bin end_bin = next_end_bin - new_min = self.min_val + bin_width * start_bin - new_max = self.min_val + bin_width * (end_bin + 1) + new_min = self.min_val + Tensor(bin_width * start_bin, dtype=np.float32) + new_max = self.min_val + Tensor(bin_width * (end_bin + 1), dtype=np.float32) return new_min, new_max def get_qparams(self):