Browse Source

fix(mge/quantization): fix get scale issue

GitOrigin-RevId: 99068d7422
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
de75dae84c
2 changed files with 14 additions and 50 deletions
  1. +2
    -1
      python_module/megengine/module/module.py
  2. +12
    -49
      python_module/megengine/quantization/observer.py

+ 2
- 1
python_module/megengine/module/module.py View File

@@ -500,7 +500,8 @@ class QATModule(Module):
self, target: Tensor, fq: "FakeQuantize", obs: "Observer" self, target: Tensor, fq: "FakeQuantize", obs: "Observer"
): ):
oup = self.apply_observer(target, obs) oup = self.apply_observer(target, obs)
return fq(oup, obs.scale, obs.zero_point)
scale, zero_point = obs.get_qparams()
return fq(oup, scale, zero_point)


def set_qat_mode(self, mode: QATMode): def set_qat_mode(self, mode: QATMode):
r""" r"""


+ 12
- 49
python_module/megengine/quantization/observer.py View File

@@ -41,7 +41,6 @@ class Observer(Module):
self.dtype = dtype self.dtype = dtype
self.qmin = _metadata_dict[dtype].qmin self.qmin = _metadata_dict[dtype].qmin
self.qmax = _metadata_dict[dtype].qmax self.qmax = _metadata_dict[dtype].qmax
self.zero_point, self.scale = None, None
self.enabled = True self.enabled = True


def get_dtype(self): def get_dtype(self):
@@ -72,23 +71,6 @@ class Observer(Module):
pass pass




class IdentityObserver(Observer):
    r"""
    A test Observer that always returns scale:1 and zero_point:0.

    The input is passed through untouched and no statistics are collected;
    the quantization parameters are fixed at construction time.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Identity quantization parameters: scale must be 1 and zero_point 0.
        # (The two constructors were previously swapped, yielding scale == 0.)
        # NOTE(review): `ones` / `zeros` are assumed to build constant-filled
        # tensors of the given shape -- confirm against the module's imports.
        self.scale = ones((1), dtype="float32")
        self.zero_point = zeros((1), dtype="float32")

    def forward(self, x):
        # Nothing to observe: return the input unchanged.
        return x

    def get_qparams(self):
        # Returned in (scale, zero_point) order.
        return self.scale, self.zero_point


class MinMaxObserver(Observer): class MinMaxObserver(Observer):
def __init__(self, symmetric=True, eps=0.00001, *args, **kwargs): def __init__(self, symmetric=True, eps=0.00001, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@@ -108,47 +90,28 @@ class MinMaxObserver(Observer):
# FIXME: cond_take will destroy shape, use reshape to reset shape # FIXME: cond_take will destroy shape, use reshape to reset shape
tmp_min = tmp_min.reshape(1) tmp_min = tmp_min.reshape(1)
tmp_max = tmp_max.reshape(1) tmp_max = tmp_max.reshape(1)
if self.training:
F.zero_grad(
F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0)
)
F.zero_grad(
F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0)
)
F.zero_grad(
F.add_update(
self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0
)
)
F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0)
F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0)
F.add_update(self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0)


# FIXME: add_update is applied after the whole trace procedure in `symbolic=True`
# mode. So use tmp_min/tmp_max to calc and save scale/zero_point for further
# calculation in FakeQuant.
self.set_scale_zero_point(tmp_min, tmp_max)

def set_scale_zero_point(self, tmp_min, tmp_max):
def get_qparams(self):
if self.symmetric: if self.symmetric:
symmetric_max_vals = F.maximum(-tmp_min, tmp_max)
symmetric_max_vals = F.maximum(-self.min_val, self.max_val)
# use maximum to avoid scale too small at the beginning # use maximum to avoid scale too small at the beginning
self.scale = F.maximum(
scale = F.maximum(
symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
) )
# zero_point = self.zero_point
zero_point = self.zero_point
else: else:
# use maximum to avoid scale too small at the beginning # use maximum to avoid scale too small at the beginning
self.scale = F.maximum(
(tmp_max - tmp_min) / (self.qmax - self.qmin), self.scale_limit
scale = F.maximum(
(self.max_val - self.min_val) / (self.qmax - self.qmin),
self.scale_limit,
) )
# calculate zero_point # calculate zero_point
self.zero_point = self.qmin - Round()((tmp_min / self.scale))

def get_qparams(self):
# scale and zero_point are runtime tensors rather than Buffers,
# so they need to be re-calculated if min_val and max_val are loaded.
if self.scale is None:
self.set_scale_zero_point(self.min_val, self.max_val)
zero_point = self.qmin - Round()((self.min_val / scale))


return self.scale, self.zero_point
return scale, zero_point


def forward(self, x_orig): def forward(self, x_orig):
if self.enabled: if self.enabled:


Loading…
Cancel
Save