GitOrigin-RevId: 5727f63560
tags/v1.0.0-rc1
@@ -92,6 +92,25 @@ class QATModule(Module): | |||
else: | |||
return self.act_observer.get_dtype() | |||
def _get_qparams(self, fake_quant: FakeQuantize, observer: Observer):
    r"""
    Pick the source of quantization parameters and query it.

    ``fake_quant`` is used whenever it exposes a ``get_qparams`` attribute;
    otherwise ``observer`` is queried, provided it is not ``None``.

    :param fake_quant: object checked first for a ``get_qparams`` attribute.
    :param observer: fallback whose ``get_qparams()`` is called when
        ``fake_quant`` has no such attribute.
    :return: the result of the selected ``get_qparams()`` call, or ``None``
        when neither source is usable.
    """
    # The fake-quant object takes precedence over the observer.
    if hasattr(fake_quant, "get_qparams"):
        return fake_quant.get_qparams()
    return None if observer is None else observer.get_qparams()
def get_weight_qparams(self):
    r"""
    Return the quantization parameters of the weight.

    Delegates to ``_get_qparams`` with ``weight_fake_quant`` and
    ``weight_observer``.
    """
    fake_quant, observer = self.weight_fake_quant, self.weight_observer
    return self._get_qparams(fake_quant, observer)
def get_activation_qparams(self):
    r"""
    Return the quantization parameters of the activation.

    Delegates to ``_get_qparams`` with ``act_fake_quant`` and
    ``act_observer``.
    """
    fake_quant, observer = self.act_fake_quant, self.act_observer
    return self._get_qparams(fake_quant, observer)
@classmethod | |||
@abstractmethod | |||
def from_float_module(cls, float_module: Module): | |||
@@ -8,7 +8,7 @@ | |||
from .fake_quant import FakeQuantize | |||
from .internal_fake_quant import * | |||
from .observer import HistogramObserver, Observer, ObserverMode | |||
from .observer import HistogramObserver, Observer | |||
from .qconfig import ( | |||
QConfig, | |||
calibration_qconfig, | |||
@@ -16,3 +16,4 @@ from .qconfig import ( | |||
min_max_fakequant_qconfig, | |||
tqt_quant_qconfig, | |||
) | |||
from .utils import QuantMode |
@@ -15,7 +15,8 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype | |||
from ..core import Buffer, Function, Parameter | |||
from ..jit import sideeffect | |||
from ..module import Module | |||
from .observer import ObserverMode, Round | |||
from .observer import Round | |||
from .utils import QuantMode, get_qparam_dict | |||
class _FakeQuantize(Module): | |||
@@ -121,8 +122,18 @@ class TQT(_FakeQuantize): | |||
F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0) | |||
return inp | |||
def get_qparams(self):
    r"""
    Get the quantization parameters of this module.

    :return: a new dict based on the ``QuantMode.TQT`` entry provided by
        ``get_qparam_dict``, with ``"scale"`` set to ``2 ** self.scale``.
    """
    # Copy before writing: the dict handed out by get_qparam_dict may be a
    # template shared by every caller, and storing this module's scale in it
    # would leak the value into other modules' parameters.
    qdict = dict(get_qparam_dict(QuantMode.TQT))
    qdict["scale"] = 2 ** self.scale
    return qdict
def get_dtype(self):
    r"""
    Build the quantized dtype from this module's quantization parameters.

    ``"scale"`` and ``"zero_point"`` are read from ``self.get_qparams()``.
    An entry that is missing or ``None`` is forwarded as ``None``; otherwise
    the first element of its ``numpy()`` value is used.

    :return: the result of ``get_quantized_dtype(self.dtype, scale, zero_point)``.
    """
    q_dict = self.get_qparams()
    scale = q_dict.get("scale")
    zero_point = q_dict.get("zero_point")
    return get_quantized_dtype(
        self.dtype,
        None if scale is None else scale.numpy()[0],
        None if zero_point is None else zero_point.numpy()[0],
    )
class FakeQuantize(_FakeQuantize): | |||
@@ -131,7 +142,7 @@ class FakeQuantize(_FakeQuantize): | |||
""" | |||
def fake_quant_forward(self, inp, q_dict): | |||
if q_dict["mode"] == ObserverMode.SYMMERTIC: | |||
if q_dict["mode"] == QuantMode.SYMMERTIC: | |||
scale = q_dict["scale"] | |||
# Quant | |||
oup = Round()(inp / scale) | |||
@@ -16,6 +16,7 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype | |||
from ..core import Buffer, Function, tensor | |||
from ..jit import sideeffect | |||
from ..module import Module | |||
from .utils import QuantMode, get_qparam_dict | |||
class Round(Function): | |||
@@ -81,29 +82,10 @@ class Observer(Module): | |||
pass | |||
class ObserverMode(Enum):
    r"""
    Quantization modes an observer can work in.
    """

    # Parameters consist of a scale only (see ``create_observer_dict``).
    SYMMERTIC = 1
    # Parameters consist of a scale and a zero point.
    ASYMMERTIC = 2
def create_observer_dict(mode):
    r"""
    Create an empty quantization-parameter dict for ``mode``.

    :param mode: ``ObserverMode.SYMMERTIC`` yields a dict with ``"mode"`` and
        ``"scale"``; any other value yields the asymmetric layout, which
        additionally carries ``"zero_point"``.
    :return: a new dict whose parameter slots are all ``None``.
    """
    symmetric = mode == ObserverMode.SYMMERTIC
    params = {
        "mode": ObserverMode.SYMMERTIC if symmetric else ObserverMode.ASYMMERTIC,
        "scale": None,
    }
    if not symmetric:
        # Only the asymmetric layout has a zero point.
        params["zero_point"] = None
    return params
class MinMaxObserver(Observer): | |||
def __init__( | |||
self, | |||
mode=ObserverMode.SYMMERTIC, | |||
mode=QuantMode.SYMMERTIC, | |||
eps=0.00001, | |||
dtype="qint8", | |||
narrow_range: bool = False, | |||
@@ -117,10 +99,10 @@ class MinMaxObserver(Observer): | |||
def _calculate_qparams(self, inp_min_val, inp_max_val): | |||
min_val = F.minimum(0.0, inp_min_val) | |||
max_val = F.maximum(0.0, inp_max_val) | |||
q_dict = create_observer_dict(self.mode) | |||
q_dict = get_qparam_dict(self.mode) | |||
q_dict["min_val"] = inp_min_val | |||
q_dict["max_val"] = inp_max_val | |||
if self.mode == ObserverMode.SYMMERTIC: | |||
if self.mode == QuantMode.SYMMERTIC: | |||
symmetric_max_vals = F.maximum(-min_val, max_val) | |||
# use maximum to avoid the scale being too small at the beginning | |||
q_dict["scale"] = F.maximum( | |||
@@ -166,7 +148,7 @@ class ExponentialMovingAverageObserver(MinMaxObserver): | |||
def __init__( | |||
self, | |||
momentum=0.9, | |||
mode=ObserverMode.SYMMERTIC, | |||
mode=QuantMode.SYMMERTIC, | |||
eps=0.00001, | |||
dtype="qint8", | |||
narrow_range: bool = False, | |||
@@ -204,7 +186,7 @@ class HistogramObserver(MinMaxObserver): | |||
self, | |||
bins=2048, | |||
upsample_rate=128, | |||
mode=ObserverMode.SYMMERTIC, | |||
mode=QuantMode.SYMMERTIC, | |||
eps=0.00001, | |||
dtype="qint8", | |||
narrow_range: bool = False, | |||
@@ -6,6 +6,7 @@ | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from enum import Enum | |||
from functools import partial, update_wrapper, wraps | |||
@@ -21,3 +22,24 @@ def register_method_to_class(cls): | |||
return func | |||
return decorator | |||
class QuantMode(Enum):
    r"""
    Quantization modes shared by observers and fake-quantize modules.
    """

    SYMMERTIC = 1
    ASYMMERTIC = 2
    TQT = 3


# Template of the quantization parameters carried by each mode. These dicts
# are templates only: hand out copies (see ``get_qparam_dict``) and never
# write into them.
qparam_dict = {
    QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None},
    QuantMode.ASYMMERTIC: {
        "mode": QuantMode.ASYMMERTIC,
        "scale": None,
        "zero_point": None,
    },
    QuantMode.TQT: {"mode": QuantMode.TQT, "scale": None},
}


def get_qparam_dict(mode):
    r"""
    Return a fresh quantization-parameter dict for ``mode``.

    :param mode: key looked up in ``qparam_dict``.
    :return: a new dict copied from the template registered for ``mode``, or
        ``None`` when ``mode`` has no template.
    """
    template = qparam_dict.get(mode, None)
    if template is None:
        return None
    # Callers fill in "scale", "zero_point", "min_val", ... on the result.
    # Returning the template itself would let one caller's values leak into
    # every later lookup, so give each caller its own copy.
    return dict(template)