GitOrigin-RevId: 5727f63560
tags/v1.0.0-rc1
@@ -92,6 +92,25 @@ class QATModule(Module): | |||||
else: | else: | ||||
return self.act_observer.get_dtype() | return self.act_observer.get_dtype() | ||||
def _get_qparams(self, fake_quant: FakeQuantize, observer: Observer): | |||||
if hasattr(fake_quant, "get_qparams"): | |||||
return fake_quant.get_qparams() | |||||
elif observer is not None: | |||||
return observer.get_qparams() | |||||
return None | |||||
def get_weight_qparams(self): | |||||
r""" | |||||
Get weight's quantization parameters. | |||||
""" | |||||
return self._get_qparams(self.weight_fake_quant, self.weight_observer) | |||||
def get_activation_qparams(self): | |||||
r""" | |||||
Get activation's quantization parameters. | |||||
""" | |||||
return self._get_qparams(self.act_fake_quant, self.act_observer) | |||||
@classmethod | @classmethod | ||||
@abstractmethod | @abstractmethod | ||||
def from_float_module(cls, float_module: Module): | def from_float_module(cls, float_module: Module): | ||||
@@ -8,7 +8,7 @@ | |||||
from .fake_quant import FakeQuantize | from .fake_quant import FakeQuantize | ||||
from .internal_fake_quant import * | from .internal_fake_quant import * | ||||
from .observer import HistogramObserver, Observer, ObserverMode | |||||
from .observer import HistogramObserver, Observer | |||||
from .qconfig import ( | from .qconfig import ( | ||||
QConfig, | QConfig, | ||||
calibration_qconfig, | calibration_qconfig, | ||||
@@ -16,3 +16,4 @@ from .qconfig import ( | |||||
min_max_fakequant_qconfig, | min_max_fakequant_qconfig, | ||||
tqt_quant_qconfig, | tqt_quant_qconfig, | ||||
) | ) | ||||
from .utils import QuantMode |
@@ -15,7 +15,8 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype | |||||
from ..core import Buffer, Function, Parameter | from ..core import Buffer, Function, Parameter | ||||
from ..jit import sideeffect | from ..jit import sideeffect | ||||
from ..module import Module | from ..module import Module | ||||
from .observer import ObserverMode, Round | |||||
from .observer import Round | |||||
from .utils import QuantMode, get_qparam_dict | |||||
class _FakeQuantize(Module): | class _FakeQuantize(Module): | ||||
@@ -121,8 +122,18 @@ class TQT(_FakeQuantize): | |||||
F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0) | F.add_update(self.scale, tmp_scale, alpha=0.0, beta=1.0, bias=0.0) | ||||
return inp | return inp | ||||
def get_qparams(self): | |||||
qdict = get_qparam_dict(QuantMode.TQT) | |||||
qdict["scale"] = 2 ** self.scale | |||||
return qdict | |||||
def get_dtype(self): | def get_dtype(self): | ||||
return get_quantized_dtype(self.dtype, 2 ** self.scale.numpy()[0], None) | |||||
q_dict = self.get_qparams() | |||||
scale = None if "scale" not in q_dict else q_dict["scale"].numpy()[0] | |||||
zero_point = ( | |||||
None if "zero_point" not in q_dict else q_dict["zero_point"].numpy()[0] | |||||
) | |||||
return get_quantized_dtype(self.dtype, scale, zero_point) | |||||
class FakeQuantize(_FakeQuantize): | class FakeQuantize(_FakeQuantize): | ||||
@@ -131,7 +142,7 @@ class FakeQuantize(_FakeQuantize): | |||||
""" | """ | ||||
def fake_quant_forward(self, inp, q_dict): | def fake_quant_forward(self, inp, q_dict): | ||||
if q_dict["mode"] == ObserverMode.SYMMERTIC: | |||||
if q_dict["mode"] == QuantMode.SYMMERTIC: | |||||
scale = q_dict["scale"] | scale = q_dict["scale"] | ||||
# Quant | # Quant | ||||
oup = Round()(inp / scale) | oup = Round()(inp / scale) | ||||
@@ -16,6 +16,7 @@ from .._internal.dtype import _metadata_dict, get_quantized_dtype | |||||
from ..core import Buffer, Function, tensor | from ..core import Buffer, Function, tensor | ||||
from ..jit import sideeffect | from ..jit import sideeffect | ||||
from ..module import Module | from ..module import Module | ||||
from .utils import QuantMode, get_qparam_dict | |||||
class Round(Function): | class Round(Function): | ||||
@@ -81,29 +82,10 @@ class Observer(Module): | |||||
pass | pass | ||||
class ObserverMode(Enum): | |||||
SYMMERTIC = 1 | |||||
ASYMMERTIC = 2 | |||||
def create_observer_dict(mode): | |||||
if mode == ObserverMode.SYMMERTIC: | |||||
return { | |||||
"mode": ObserverMode.SYMMERTIC, | |||||
"scale": None, | |||||
} | |||||
else: | |||||
return { | |||||
"mode": ObserverMode.ASYMMERTIC, | |||||
"scale": None, | |||||
"zero_point": None, | |||||
} | |||||
class MinMaxObserver(Observer): | class MinMaxObserver(Observer): | ||||
def __init__( | def __init__( | ||||
self, | self, | ||||
mode=ObserverMode.SYMMERTIC, | |||||
mode=QuantMode.SYMMERTIC, | |||||
eps=0.00001, | eps=0.00001, | ||||
dtype="qint8", | dtype="qint8", | ||||
narrow_range: bool = False, | narrow_range: bool = False, | ||||
@@ -117,10 +99,10 @@ class MinMaxObserver(Observer): | |||||
def _calculate_qparams(self, inp_min_val, inp_max_val): | def _calculate_qparams(self, inp_min_val, inp_max_val): | ||||
min_val = F.minimum(0.0, inp_min_val) | min_val = F.minimum(0.0, inp_min_val) | ||||
max_val = F.maximum(0.0, inp_max_val) | max_val = F.maximum(0.0, inp_max_val) | ||||
q_dict = create_observer_dict(self.mode) | |||||
q_dict = get_qparam_dict(self.mode) | |||||
q_dict["min_val"] = inp_min_val | q_dict["min_val"] = inp_min_val | ||||
q_dict["max_val"] = inp_max_val | q_dict["max_val"] = inp_max_val | ||||
if self.mode == ObserverMode.SYMMERTIC: | |||||
if self.mode == QuantMode.SYMMERTIC: | |||||
symmetric_max_vals = F.maximum(-min_val, max_val) | symmetric_max_vals = F.maximum(-min_val, max_val) | ||||
# use maximun to avoid scale too small at the begin | # use maximun to avoid scale too small at the begin | ||||
q_dict["scale"] = F.maximum( | q_dict["scale"] = F.maximum( | ||||
@@ -166,7 +148,7 @@ class ExponentialMovingAverageObserver(MinMaxObserver): | |||||
def __init__( | def __init__( | ||||
self, | self, | ||||
momentum=0.9, | momentum=0.9, | ||||
mode=ObserverMode.SYMMERTIC, | |||||
mode=QuantMode.SYMMERTIC, | |||||
eps=0.00001, | eps=0.00001, | ||||
dtype="qint8", | dtype="qint8", | ||||
narrow_range: bool = False, | narrow_range: bool = False, | ||||
@@ -204,7 +186,7 @@ class HistogramObserver(MinMaxObserver): | |||||
self, | self, | ||||
bins=2048, | bins=2048, | ||||
upsample_rate=128, | upsample_rate=128, | ||||
mode=ObserverMode.SYMMERTIC, | |||||
mode=QuantMode.SYMMERTIC, | |||||
eps=0.00001, | eps=0.00001, | ||||
dtype="qint8", | dtype="qint8", | ||||
narrow_range: bool = False, | narrow_range: bool = False, | ||||
@@ -6,6 +6,7 @@ | |||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
from enum import Enum | |||||
from functools import partial, update_wrapper, wraps | from functools import partial, update_wrapper, wraps | ||||
@@ -21,3 +22,24 @@ def register_method_to_class(cls): | |||||
return func | return func | ||||
return decorator | return decorator | ||||
class QuantMode(Enum): | |||||
SYMMERTIC = 1 | |||||
ASYMMERTIC = 2 | |||||
TQT = 3 | |||||
qparam_dict = { | |||||
QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None,}, | |||||
QuantMode.ASYMMERTIC: { | |||||
"mode": QuantMode.ASYMMERTIC, | |||||
"scale": None, | |||||
"zero_point": None, | |||||
}, | |||||
QuantMode.TQT: {"mode": QuantMode.TQT, "scale": None,}, | |||||
} | |||||
def get_qparam_dict(mode): | |||||
return qparam_dict.get(mode, None) |