Browse Source

fix(quantize): fix quantize calibration dtype issue

GitOrigin-RevId: 667c99f469
release-1.1
Megvii Engine Team 4 years ago
parent
commit
d8ac6c70a1
1 changed files with 3 additions and 3 deletions
  1. +3
    -3
      imperative/python/megengine/quantization/observer.py

+ 3
- 3
imperative/python/megengine/quantization/observer.py View File

@@ -171,7 +171,7 @@ class HistogramObserver(MinMaxObserver):
self.bins = bins
self.upsample_rate = upsample_rate
self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
self.histogram = Tensor([-1] + [0.0] * (bins - 1))
self.histogram = Tensor([-1] + [0.0] * (bins - 1), dtype="float32")

def _non_linear_param_search(self):
r"""Non-linear parameter search.
@@ -304,8 +304,8 @@ class HistogramObserver(MinMaxObserver):
start_bin = next_start_bin
end_bin = next_end_bin

new_min = self.min_val + bin_width * start_bin
new_max = self.min_val + bin_width * (end_bin + 1)
new_min = self.min_val + Tensor(bin_width * start_bin, dtype=np.float32)
new_max = self.min_val + Tensor(bin_width * (end_bin + 1), dtype=np.float32)
return new_min, new_max

def get_qparams(self):


Loading…
Cancel
Save