From d4b86b844e84dff4861745273d863d01bd3ca969 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 11 Aug 2020 16:36:49 +0800
Subject: [PATCH] feat(mge/dtype): add int2 lowbit support and example

GitOrigin-RevId: 67c14ac959a9f2725360f79cd3838000aa5e35ea
---
 python_module/megengine/_internal/dtype.py       | 3 +++
 python_module/megengine/quantization/__init__.py | 1 +
 python_module/megengine/quantization/qconfig.py  | 9 +++++++++
 3 files changed, 13 insertions(+)

diff --git a/python_module/megengine/_internal/dtype.py b/python_module/megengine/_internal/dtype.py
index 7d0eb61a..6bb32f86 100644
--- a/python_module/megengine/_internal/dtype.py
+++ b/python_module/megengine/_internal/dtype.py
@@ -25,6 +25,9 @@ _metadata_dict = {
     "qint32": _QuantDtypeMetadata(
         "QuantizedS32", "int32", False, -(2 ** 31), 2 ** 31 - 1,
     ),
+    # NOTE: int2 is not supported for model dump yet
+    "quint2": _QuantDtypeMetadata(None, "uint8", True, 0, 3),
+    "qint2": _QuantDtypeMetadata(None, "int8", False, -2, 1),
 }
 
 
diff --git a/python_module/megengine/quantization/__init__.py b/python_module/megengine/quantization/__init__.py
index 82feced1..9c8a0e0d 100644
--- a/python_module/megengine/quantization/__init__.py
+++ b/python_module/megengine/quantization/__init__.py
@@ -13,6 +13,7 @@ from .qconfig import (
     QConfig,
     calibration_qconfig,
     ema_fakequant_qconfig,
+    ema_lowbit_fakequant_qconfig,
     min_max_fakequant_qconfig,
     tqt_quant_qconfig,
 )
diff --git a/python_module/megengine/quantization/qconfig.py b/python_module/megengine/quantization/qconfig.py
index 4a7b75ec..6606c1a5 100644
--- a/python_module/megengine/quantization/qconfig.py
+++ b/python_module/megengine/quantization/qconfig.py
@@ -92,6 +92,15 @@ ema_fakequant_qconfig = QConfig(
     act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
 )
 
+ema_lowbit_fakequant_qconfig = QConfig(
+    weight_observer=partial(MinMaxObserver, dtype="qint4", narrow_range=False),
+    act_observer=partial(
+        ExponentialMovingAverageObserver, dtype="qint4", narrow_range=False
+    ),
+    weight_fake_quant=partial(FakeQuantize, dtype="qint4", narrow_range=False),
+    act_fake_quant=partial(FakeQuantize, dtype="qint4", narrow_range=False),
+)
+
 calibration_qconfig = QConfig(
     weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
     act_observer=partial(HistogramObserver, dtype="qint8", narrow_range=False),