Browse Source

fix(mge/quantization): modify observer api

GitOrigin-RevId: 7b9c22be96
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
ff341cb19b
3 changed files with 71 additions and 31 deletions
  1. +2
    -2
      python_module/megengine/module/module.py
  2. +21
    -10
      python_module/megengine/quantization/fake_quant.py
  3. +48
    -19
      python_module/megengine/quantization/observer.py

+ 2
- 2
python_module/megengine/module/module.py View File

@@ -505,8 +505,8 @@ class QATModule(Module):
): ):
oup = self.apply_observer(target, obs) oup = self.apply_observer(target, obs)
if fq is not None: if fq is not None:
scale, zero_point = obs.get_qparams()
oup = fq(oup, scale, zero_point)
q_dict = obs.get_qparams()
oup = fq(oup, q_dict)
return oup return oup


def set_qat_mode(self, mode: QATMode): def set_qat_mode(self, mode: QATMode):


+ 21
- 10
python_module/megengine/quantization/fake_quant.py View File

@@ -8,7 +8,7 @@
from .. import functional as F from .. import functional as F
from .._internal.dtype import _metadata_dict from .._internal.dtype import _metadata_dict
from ..module import Module from ..module import Module
from .observer import Round
from .observer import ObserverMode, Round




class FakeQuantize(Module): class FakeQuantize(Module):
@@ -35,14 +35,25 @@ class FakeQuantize(Module):
def disable(self):
    """Switch fake quantization off by clearing the ``enabled`` flag.

    While the flag is false, ``forward`` hands its input back unchanged.
    """
    self.enabled = False


def forward(self, inp, q_dict):
    """Simulate quantization of ``inp`` using the parameters in ``q_dict``.

    :param inp: tensor to fake-quantize.
    :param q_dict: dict produced by an observer's ``get_qparams()``; must
        hold ``"mode"`` and ``"scale"``, plus ``"zero_point"`` when the mode
        is not ``ObserverMode.SYMMERTIC``.
    :return: ``inp`` itself when the module is disabled, otherwise the
        quantize -> clip -> dequantize round trip of ``inp``.
    """
    if not self.enabled:
        return inp

    symmetric = q_dict["mode"] == ObserverMode.SYMMERTIC
    scale = q_dict["scale"]

    if symmetric:
        # Symmetric mode carries no zero point: round, clip, rescale.
        quantized = Round()(inp / scale)
        clipped = F.minimum(F.maximum(quantized, self.qmin), self.qmax)
        return (clipped) * scale

    zero_point = q_dict["zero_point"]
    # Shift by the zero point before clipping, undo the shift afterwards.
    quantized = Round()(inp / scale) + zero_point
    clipped = F.minimum(F.maximum(quantized, self.qmin), self.qmax)
    return (clipped - zero_point) * scale

+ 48
- 19
python_module/megengine/quantization/observer.py View File

@@ -7,6 +7,7 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import math import math
from abc import abstractmethod from abc import abstractmethod
from enum import Enum


import numpy as np import numpy as np


@@ -46,9 +47,11 @@ class Observer(Module):
self.enabled = True self.enabled = True


def get_dtype(self):
    """Return the quantized dtype described by the current qparams.

    Reads ``scale`` and ``zero_point`` from the dict returned by
    ``get_qparams()`` and converts each one to a NumPy scalar before
    passing them to ``get_quantized_dtype``.

    An entry that is missing *or* still ``None`` is forwarded as ``None``.
    ``create_observer_dict`` fills its dict with ``None`` placeholders, so
    a plain key-membership test would try to call ``.numpy()`` on ``None``.
    """
    q_dict = self.get_qparams()
    scale = q_dict.get("scale")
    zero_point = q_dict.get("zero_point")
    numpy_scale = None if scale is None else scale.numpy()[0]
    numpy_zero_point = None if zero_point is None else zero_point.numpy()[0]
    return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)


def enable(self): def enable(self):
@@ -73,13 +76,29 @@ class Observer(Module):
pass pass




class ObserverMode(Enum):
    """Quantization scheme an observer produces parameters for.

    The member names keep their historical spelling because other modules
    refer to them by name.
    """

    # Parameters consist of a scale only.
    SYMMERTIC = 1
    # Parameters consist of a scale and a zero point.
    ASYMMERTIC = 2


def create_observer_dict(mode):
    """Create an unfilled quantization-parameter dict for ``mode``.

    :param mode: an ``ObserverMode`` member. Anything that does not compare
        equal to ``ObserverMode.SYMMERTIC`` is treated as asymmetric.
    :return: a dict with ``"mode"`` and ``"scale"`` keys, plus a
        ``"zero_point"`` key in the asymmetric case; the parameter values
        start out as ``None``.
    """
    symmetric = mode == ObserverMode.SYMMERTIC
    q_dict = {
        "mode": ObserverMode.SYMMERTIC if symmetric else ObserverMode.ASYMMERTIC,
        "scale": None,
    }
    if not symmetric:
        q_dict["zero_point"] = None
    return q_dict


class MinMaxObserver(Observer): class MinMaxObserver(Observer):
def __init__(self, symmetric=True, eps=0.00001, *args, **kwargs):
super().__init__(*args, **kwargs)
self.symmetric = symmetric
if self.symmetric:
# assert qmin + qmax == -1, 'when reduce_range, qmin + qmax shoule equal -1'
self.zero_point = tensor((self.qmin + self.qmax + 1) // 2)
def __init__(self, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"):
super().__init__(dtype)
self.mode = mode


self.min_val = Buffer(0.0, dtype=np.float32) self.min_val = Buffer(0.0, dtype=np.float32)
self.max_val = Buffer(0.0, dtype=np.float32) self.max_val = Buffer(0.0, dtype=np.float32)
@@ -99,22 +118,23 @@ class MinMaxObserver(Observer):
def _calculate_qparams(self, inp_min_val, inp_max_val):
    """Derive quantization parameters from an observed value range.

    :param inp_min_val: observed minimum.
    :param inp_max_val: observed maximum.
    :return: the dict from ``create_observer_dict(self.mode)`` with
        ``"scale"`` filled in, and ``"zero_point"`` as well when the mode
        is not ``ObserverMode.SYMMERTIC``.
    """
    # Widen the range so that it always contains zero.
    min_val = F.minimum(0.0, inp_min_val)
    max_val = F.maximum(0.0, inp_max_val)
    q_dict = create_observer_dict(self.mode)
    if self.mode == ObserverMode.SYMMERTIC:
        symmetric_max_vals = F.maximum(-min_val, max_val)
        # use maximum to avoid a scale that is too small at the beginning
        q_dict["scale"] = F.maximum(
            symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
        )
    else:
        # use maximum to avoid a scale that is too small at the beginning
        q_dict["scale"] = F.maximum(
            (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit,
        )
        # calculate zero_point; the scale has to be read back from q_dict
        # because there is no local ``scale`` variable in this function.
        q_dict["zero_point"] = self.qmin - Round()((min_val / q_dict["scale"]))

    return q_dict


def get_qparams(self):
    """Return the quantization parameters for the currently stored range.

    Feeds ``self.min_val`` and ``self.max_val`` to ``_calculate_qparams``
    and returns its result unchanged.
    """
    observed_min, observed_max = self.min_val, self.max_val
    return self._calculate_qparams(observed_min, observed_max)
@@ -135,8 +155,10 @@ class MinMaxObserver(Observer):




class ExponentialMovingAverageObserver(MinMaxObserver): class ExponentialMovingAverageObserver(MinMaxObserver):
def __init__(
    self, momentum=0.9, mode=ObserverMode.SYMMERTIC, eps=0.00001, dtype="qint8"
):
    """Set up the observer and remember the averaging momentum.

    :param momentum: momentum value, stored in a ``Buffer``.
    :param mode: ``ObserverMode`` forwarded to the parent observer.
    :param eps: forwarded to the parent observer.
    :param dtype: quantized dtype name forwarded to the parent observer.
    """
    super().__init__(mode=mode, eps=eps, dtype=dtype)
    self.momentum = Buffer(momentum)


def set_momentum(self, momentum): def set_momentum(self, momentum):
@@ -170,8 +192,15 @@ class ExponentialMovingAverageObserver(MinMaxObserver):




class HistogramObserver(MinMaxObserver): class HistogramObserver(MinMaxObserver):
def __init__(self, bins=2048, upsample_rate=128, dtype="qint8", *args, **kwargs):
super().__init__(dtype=dtype, *args, **kwargs)
def __init__(
self,
bins=2048,
upsample_rate=128,
dtype="qint8",
mode=ObserverMode.SYMMERTIC,
eps=0.00001,
):
super().__init__(mode, eps, dtype)
self.bins = bins self.bins = bins
self.upsample_rate = upsample_rate self.upsample_rate = upsample_rate
self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1


Loading…
Cancel
Save