|
|
@@ -171,7 +171,7 @@ class HistogramObserver(MinMaxObserver): |
|
|
|
self.bins = bins |
|
|
|
self.upsample_rate = upsample_rate |
|
|
|
self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 |
|
|
|
self.histogram = Tensor([-1] + [0.0] * (bins - 1)) |
|
|
|
self.histogram = Tensor([-1] + [0.0] * (bins - 1), dtype="float32") |
|
|
|
|
|
|
|
def _non_linear_param_search(self): |
|
|
|
r"""Non-linear parameter search. |
|
|
@@ -304,8 +304,8 @@ class HistogramObserver(MinMaxObserver): |
|
|
|
start_bin = next_start_bin |
|
|
|
end_bin = next_end_bin |
|
|
|
|
|
|
|
new_min = self.min_val + bin_width * start_bin |
|
|
|
new_max = self.min_val + bin_width * (end_bin + 1) |
|
|
|
new_min = self.min_val + Tensor(bin_width * start_bin, dtype=np.float32) |
|
|
|
new_max = self.min_val + Tensor(bin_width * (end_bin + 1), dtype=np.float32) |
|
|
|
return new_min, new_max |
|
|
|
|
|
|
|
def get_qparams(self): |
|
|
|