Browse Source

feat(mgb/quantization): add get quantize parameters support

GitOrigin-RevId: 5727f63560
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
4495c0cc86
5 changed files with 63 additions and 28 deletions
  1. +19
    -0
      python_module/megengine/module/qat/module.py
  2. +2
    -1
      python_module/megengine/quantization/__init__.py
  3. +14
    -3
      python_module/megengine/quantization/fake_quant.py
  4. +6
    -24
      python_module/megengine/quantization/observer.py
  5. +22
    -0
      python_module/megengine/quantization/utils.py

+ 19
- 0
python_module/megengine/module/qat/module.py View File

@@ -92,6 +92,25 @@ class QATModule(Module):
else: else:
return self.act_observer.get_dtype() return self.act_observer.get_dtype()


def _get_qparams(self, fake_quant: FakeQuantize, observer: Observer):
    r"""
    Pick the quantization parameters for one tensor role.

    ``fake_quant`` takes precedence whenever it exposes a ``get_qparams``
    method; otherwise ``observer`` is queried. ``None`` is returned when
    ``fake_quant`` has no such method and ``observer`` is ``None``.
    """
    source = fake_quant if hasattr(fake_quant, "get_qparams") else observer
    if source is None:
        return None
    return source.get_qparams()

def get_weight_qparams(self):
    r"""
    Return the quantization parameters that apply to the weight.

    Delegates to ``_get_qparams`` with the weight fake-quantizer and the
    weight observer.
    """
    fake_quant, observer = self.weight_fake_quant, self.weight_observer
    return self._get_qparams(fake_quant, observer)

def get_activation_qparams(self):
    r"""
    Return the quantization parameters that apply to the activation.

    Delegates to ``_get_qparams`` with the activation fake-quantizer and
    the activation observer.
    """
    fake_quant, observer = self.act_fake_quant, self.act_observer
    return self._get_qparams(fake_quant, observer)

@classmethod @classmethod
@abstractmethod @abstractmethod
def from_float_module(cls, float_module: Module): def from_float_module(cls, float_module: Module):


+ 2
- 1
python_module/megengine/quantization/__init__.py View File

@@ -8,7 +8,7 @@


from .fake_quant import FakeQuantize from .fake_quant import FakeQuantize
from .internal_fake_quant import * from .internal_fake_quant import *
from .observer import HistogramObserver, Observer, ObserverMode
from .observer import HistogramObserver, Observer
from .qconfig import ( from .qconfig import (
QConfig, QConfig,
calibration_qconfig, calibration_qconfig,
@@ -16,3 +16,4 @@ from .qconfig import (
min_max_fakequant_qconfig, min_max_fakequant_qconfig,
tqt_quant_qconfig, tqt_quant_qconfig,
) )
from .utils import QuantMode

+ 14
- 3
python_module/megengine/quantization/fake_quant.py View File

@@ -15,7 +15,8 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, Parameter from ..core import Buffer, Function, Parameter
from ..jit import sideeffect from ..jit import sideeffect
from ..module import Module from ..module import Module
from .observer import ObserverMode, Round
from .observer import Round
from .utils import QuantMode, get_qparam_dict




class _FakeQuantize(Module): class _FakeQuantize(Module):
@@ -121,8 +122,18 @@ class TQT(_FakeQuantize):
F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0) F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0)
return inp return inp


def get_qparams(self):
    r"""
    Return this module's quantization parameters as a dict.

    The dict starts from the ``QuantMode.TQT`` entry provided by
    ``get_qparam_dict`` and has its ``"scale"`` entry set to
    ``2 ** self.scale``.
    """
    # Work on a private copy: writing into the dict handed back by
    # get_qparam_dict directly would store this instance's scale in an
    # object that other callers may also receive.
    qdict = dict(get_qparam_dict(QuantMode.TQT))
    qdict["scale"] = 2 ** self.scale
    return qdict

def get_dtype(self): def get_dtype(self):
return get_quantized_dtype(self.dtype, 2 ** self.scale.numpy()[0], None)
q_dict = self.get_qparams()
scale = None if "scale" not in q_dict else q_dict["scale"].numpy()[0]
zero_point = (
None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()[0]
)
return get_quantized_dtype(self.dtype, scale, zero_point)




class FakeQuantize(_FakeQuantize): class FakeQuantize(_FakeQuantize):
@@ -131,7 +142,7 @@ class FakeQuantize(_FakeQuantize):
""" """


def fake_quant_forward(self, inp, q_dict): def fake_quant_forward(self, inp, q_dict):
if q_dict["mode"] == ObserverMode.SYMMERTIC:
if q_dict["mode"] == QuantMode.SYMMERTIC:
scale = q_dict["scale"] scale = q_dict["scale"]
# Quant # Quant
oup = Round()(inp / scale) oup = Round()(inp / scale)


+ 6
- 24
python_module/megengine/quantization/observer.py View File

@@ -16,6 +16,7 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, tensor from ..core import Buffer, Function, tensor
from ..jit import sideeffect from ..jit import sideeffect
from ..module import Module from ..module import Module
from .utils import QuantMode, get_qparam_dict




class Round(Function): class Round(Function):
@@ -81,29 +82,10 @@ class Observer(Module):
pass pass




class ObserverMode(Enum):
SYMMERTIC = 1
ASYMMERTIC = 2


def create_observer_dict(mode):
if mode == ObserverMode.SYMMERTIC:
return {
"mode": ObserverMode.SYMMERTIC,
"scale": None,
}
else:
return {
"mode": ObserverMode.ASYMMERTIC,
"scale": None,
"zero_point": None,
}


class MinMaxObserver(Observer): class MinMaxObserver(Observer):
def __init__( def __init__(
self, self,
mode=ObserverMode.SYMMERTIC,
mode=QuantMode.SYMMERTIC,
eps=0.00001, eps=0.00001,
dtype="qint8", dtype="qint8",
narrow_range: bool = False, narrow_range: bool = False,
@@ -117,10 +99,10 @@ class MinMaxObserver(Observer):
def _calculate_qparams(self, inp_min_val, inp_max_val): def _calculate_qparams(self, inp_min_val, inp_max_val):
min_val = F.minimum(0.0, inp_min_val) min_val = F.minimum(0.0, inp_min_val)
max_val = F.maximum(0.0, inp_max_val) max_val = F.maximum(0.0, inp_max_val)
q_dict = create_observer_dict(self.mode)
q_dict = get_qparam_dict(self.mode)
q_dict["min_val"] = inp_min_val q_dict["min_val"] = inp_min_val
q_dict["max_val"] = inp_max_val q_dict["max_val"] = inp_max_val
if self.mode == ObserverMode.SYMMERTIC:
if self.mode == QuantMode.SYMMERTIC:
symmetric_max_vals = F.maximum(-min_val, max_val) symmetric_max_vals = F.maximum(-min_val, max_val)
# use maximun to avoid scale too small at the begin # use maximun to avoid scale too small at the begin
q_dict["scale"] = F.maximum( q_dict["scale"] = F.maximum(
@@ -166,7 +148,7 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
def __init__( def __init__(
self, self,
momentum=0.9, momentum=0.9,
mode=ObserverMode.SYMMERTIC,
mode=QuantMode.SYMMERTIC,
eps=0.00001, eps=0.00001,
dtype="qint8", dtype="qint8",
narrow_range: bool = False, narrow_range: bool = False,
@@ -204,7 +186,7 @@ class HistogramObserver(MinMaxObserver):
self, self,
bins=2048, bins=2048,
upsample_rate=128, upsample_rate=128,
mode=ObserverMode.SYMMERTIC,
mode=QuantMode.SYMMERTIC,
eps=0.00001, eps=0.00001,
dtype="qint8", dtype="qint8",
narrow_range: bool = False, narrow_range: bool = False,


+ 22
- 0
python_module/megengine/quantization/utils.py View File

@@ -6,6 +6,7 @@
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.


from enum import Enum
from functools import partial, update_wrapper, wraps from functools import partial, update_wrapper, wraps




@@ -21,3 +22,24 @@ def register_method_to_class(cls):
return func return func


return decorator return decorator


class QuantMode(Enum):
    """Quantization modes; each mode has its own set of qparam keys.

    The member spellings (``SYMMERTIC`` / ``ASYMMERTIC``) are kept exactly
    as they are because callers refer to the members by these names.
    """

    SYMMERTIC = 1
    ASYMMERTIC = 2
    TQT = 3


# Per-mode templates listing the keys a quantization-parameter dict starts
# with. These are shared module-level objects and must not be handed out
# directly to code that fills them in; see get_qparam_dict.
qparam_dict = {
    QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None},
    QuantMode.ASYMMERTIC: {
        "mode": QuantMode.ASYMMERTIC,
        "scale": None,
        "zero_point": None,
    },
    QuantMode.TQT: {"mode": QuantMode.TQT, "scale": None},
}


def get_qparam_dict(mode):
    """Return a fresh quantization-parameter dict for ``mode``.

    :param mode: a :class:`QuantMode` member.
    :return: a new dict holding the template keys for ``mode`` (``"mode"``,
        ``"scale"`` and, for ``ASYMMERTIC``, ``"zero_point"``), or ``None``
        when ``mode`` has no template.

    A copy is returned on every call so that callers can store values in
    the result without altering the shared templates or the dicts given
    to other callers.
    """
    template = qparam_dict.get(mode)
    if template is None:
        return None
    # Template values are only enum members and None, so a shallow copy
    # fully isolates the caller.
    return dict(template)

Loading…
Cancel
Save