Browse Source

fix(mge/quantization): fix histogram observer load and store issue

GitOrigin-RevId: b0a2b476e4
tags/v0.5.0
Megvii Engine Team Xu Xinran 5 years ago
parent
commit
7b2c5a7383
2 changed files with 33 additions and 40 deletions
  1. +1
    -1
      python_module/megengine/quantization/__init__.py
  2. +32
    -39
      python_module/megengine/quantization/observer.py

+ 1
- 1
python_module/megengine/quantization/__init__.py View File

@@ -6,7 +6,7 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .fake_quant import FakeQuantize
from .observer import HistogramObserver, Observer
from .observer import HistogramObserver, Observer, ObserverMode
from .qconfig import (
QConfig,
calibration_qconfig,


+ 32
- 39
python_module/megengine/quantization/observer.py View File

@@ -132,7 +132,7 @@ class MinMaxObserver(Observer):
(max_val - min_val) / (self.qmax - self.qmin), self.scale_limit,
)
# caculate zero_point
q_dict["zero_point"] = self.qmin - Round()((min_val / scale))
q_dict["zero_point"] = self.qmin - Round()((min_val / q_dict["scale"]))

return q_dict

@@ -204,7 +204,7 @@ class HistogramObserver(MinMaxObserver):
self.bins = bins
self.upsample_rate = upsample_rate
self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1
self.histogram = None
self.histogram = Buffer([0.0] * bins)

def _non_linear_param_search(self):
r"""Non-linear parameter search.
@@ -212,6 +212,12 @@ class HistogramObserver(MinMaxObserver):
By selecting new min/max, we filter out outliers in input distribution.
"""

np_min_val = self.min_val.numpy()[0]
np_max_val = self.max_val.numpy()[0]
np_histogram = self.histogram.numpy()
assert len(np_histogram) == self.bins, "bins mistmatch"
bin_width = (np_max_val - np_min_val) / self.bins

def _get_norm(delta_begin, delta_end, density, norm_type):
r"""
Compute the norm of the values uniformaly distributed between
@@ -233,9 +239,6 @@ class HistogramObserver(MinMaxObserver):
Compute the quantization error if we use start_bin to end_bin as the
min and max to do the quantization.
"""
np_min_val = self.min_val.numpy()[0]
np_max_val = self.max_val.numpy()[0]
bin_width = (np_max_val - np_min_val) / self.bins

norm = 0.0
dst_bin_width = (
@@ -262,7 +265,7 @@ class HistogramObserver(MinMaxObserver):
dst_bin_of_begin * dst_bin_width + dst_bin_width / 2
)

density = self.histogram[src_bin] / bin_width
density = np_histogram[src_bin] / bin_width
if dst_bin_of_begin == dst_bin_of_end:
# if src_bin is entirely within 1 dst_bin
delta_begin = src_bin_begin - dst_bin_of_begin_center
@@ -286,12 +289,9 @@ class HistogramObserver(MinMaxObserver):
norm = norm + _get_norm(delta_begin, delta_end, density, norm_type)
return norm

assert len(self.histogram) == self.bins, "bins mistmatch"
bin_width = (self.max_val - self.min_val) / self.bins

# cumulative sum
total = sum(self.histogram)
cSum = np.cumsum(self.histogram, axis=0)
total = sum(np_histogram)
cSum = np.cumsum(np_histogram, axis=0)

stepsize = 1e-5 # granularity
alpha = 0.0 # lower bound
@@ -400,46 +400,39 @@ class HistogramObserver(MinMaxObserver):
x = x_orig.numpy()
min_val = self.min_val.numpy()[0]
max_val = self.max_val.numpy()[0]
histogram = self.histogram.numpy()
new_min = x.min()
new_max = x.max()

if min_val == 0 or max_val == 0:
min_val = x.min()
max_val = x.max()
self.min_val.set_value(min_val)
self.max_val.set_value(max_val)
self.histogram, _ = np.histogram(x, self.bins, (min_val, max_val))
self.histogram = self.histogram.astype(np.float64)
new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
else:
new_min = x.min()
new_max = x.max()
combined_min = min(new_min, min_val)
combined_max = max(new_max, max_val)
new_min = min(new_min, min_val)
new_max = max(new_max, max_val)
# combine the existing histogram and new histogram into 1 histogram
# We do this by first upsampling the histogram to a dense grid
# and then downsampling the histogram efficiently
(
combined_min,
combined_max,
downsample_rate,
start_idx,
) = self._adjust_min_max(combined_min, combined_max, self.upsample_rate)

combined_histogram, _ = np.histogram(
x, self.bins, (combined_min, combined_max)
(new_min, new_max, downsample_rate, start_idx,) = self._adjust_min_max(
new_min, new_max, self.upsample_rate
)
combined_histogram = combined_histogram.astype(np.float64)
if combined_min == min_val and combined_max == max_val:
combined_histogram += self.histogram

new_histogram, _ = np.histogram(x, self.bins, (new_min, new_max))
new_histogram = new_histogram.astype(np.float64)
if new_min == min_val and new_max == max_val:
new_histogram += histogram
else:
combined_histogram = self._combine_histograms(
combined_histogram,
self.histogram,
new_histogram = self._combine_histograms(
new_histogram,
histogram,
self.upsample_rate,
downsample_rate,
start_idx,
self.bins,
)
self.histogram = combined_histogram
self.min_val.set_value(combined_min)
self.max_val.set_value(combined_max)

self.histogram.set_value(new_histogram)
self.min_val.set_value(new_min)
self.max_val.set_value(new_max)

def forward(self, x_orig):
self.sideeffect_forward(x_orig)


Loading…
Cancel
Save