GitOrigin-RevId: 060d908349
release-1.2
@@ -17,9 +17,7 @@ from .module import QuantizedModule | |||
class Linear(QuantizedModule): | |||
r"""Quantized version of :class:`~.qat.linear.Linear`.""" | |||
def __init__( | |||
self, dtype: np.dtype = None, | |||
): | |||
def __init__(self, dtype: np.dtype = None): | |||
super().__init__() | |||
self.weight = None | |||
self.bias = None | |||
@@ -15,7 +15,8 @@ from .qconfig import ( | |||
ema_fakequant_qconfig, | |||
ema_lowbit_fakequant_qconfig, | |||
min_max_fakequant_qconfig, | |||
passive_qconfig, | |||
sync_ema_fakequant_qconfig, | |||
tqt_quant_qconfig, | |||
tqt_qconfig, | |||
) | |||
from .utils import QuantMode |
@@ -28,7 +28,9 @@ class _FakeQuantize(Module): | |||
:param enable: whether to do ``normal_forward`` or ``fake_quant_forward``. | |||
""" | |||
def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): | |||
def __init__( | |||
self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs | |||
): | |||
super().__init__() | |||
if not dtype in _metadata_dict.keys(): | |||
raise ValueError( | |||
@@ -114,24 +116,28 @@ class TQT(_FakeQuantize): | |||
for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks. | |||
""" | |||
def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True): | |||
super().__init__(dtype, narrow_range, enable) | |||
self.scale = Parameter([0.0], dtype=np.float32) | |||
def __init__( | |||
self, | |||
q_dict, | |||
dtype: str, | |||
narrow_range: bool = False, | |||
enable: bool = True, | |||
**kwargs | |||
): | |||
super().__init__(dtype, narrow_range, enable, **kwargs) | |||
assert ( | |||
q_dict["mode"] == QuantMode.SYMMERTIC | |||
), "only symmetric quantization is supported by TQT" | |||
if "scale" not in q_dict or q_dict["scale"] is None: | |||
raise AssertionError("Can not get an initialized scale") | |||
self.scale = F.log(q_dict["scale"]) / math.log(2) | |||
def fake_quant_forward(self, inp, q_dict=None): | |||
# when enabled, TQT does the fake-quant forward and finetunes the scale | |||
return TQT_Function(self.qmin, self.qmax)(inp, self.scale) | |||
def normal_foward(self, inp, q_dict=None): | |||
if q_dict["enable_observer"]: | |||
# when disabled, TQT does the normal forward and initializes the scale parameter | |||
tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"])) | |||
tmp_scale = F.log(tmp_scale / 127) / math.log(2) | |||
self.scale[...] = tmp_scale | |||
return inp | |||
def get_qparams(self): | |||
q_dict = get_qparam_dict(QuantMode.TQT) | |||
q_dict = get_qparam_dict(QuantMode.SYMMERTIC) | |||
q_dict["scale"] = 2 ** self.scale | |||
return q_dict | |||
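The TQT hunk above round-trips the scale through the log2 domain: ``__init__`` stores ``log(scale) / log(2)`` and ``get_qparams`` recovers the linear scale as ``2 ** self.scale``. A minimal numpy sketch of that bookkeeping (illustrative only; the real module keeps ``scale`` as a MegEngine tensor):

import math

import numpy as np

linear_scale = 0.05                                # e.g. q_dict["scale"] from a calibrated observer
log2_scale = math.log(linear_scale) / math.log(2)  # what __init__ stores in self.scale
recovered = 2 ** log2_scale                        # what get_qparams() writes back into q_dict["scale"]
assert np.isclose(recovered, linear_scale)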
@@ -7,6 +7,7 @@ | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import math | |||
from abc import abstractmethod | |||
from copy import deepcopy | |||
import numpy as np | |||
@@ -28,7 +29,7 @@ class Observer(Module): | |||
instead of 1 greater. Usually True for weight and False for activation. | |||
""" | |||
def __init__(self, dtype: str, narrow_range: bool = False): | |||
def __init__(self, dtype: str, narrow_range: bool = False, **kwargs): | |||
super().__init__() | |||
if dtype not in _metadata_dict.keys(): | |||
raise ValueError( | |||
@@ -81,8 +82,9 @@ class MinMaxObserver(Observer): | |||
eps=0.00001, | |||
dtype="qint8", | |||
narrow_range: bool = False, | |||
**kwargs | |||
): | |||
super().__init__(dtype, narrow_range) | |||
super().__init__(dtype, narrow_range, **kwargs) | |||
self.mode = mode | |||
self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32) | |||
self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32) | |||
@@ -105,7 +107,7 @@ class MinMaxObserver(Observer): | |||
else: | |||
# use maximum to avoid the scale being too small at the beginning | |||
q_dict["scale"] = F.maximum( | |||
(max_val - min_val) / (self.qmax - self.qmin), self.scale_limit, | |||
(max_val - min_val) / (self.qmax - self.qmin), self.scale_limit | |||
) | |||
# calculate zero_point | |||
q_dict["zero_point"] = self.qmin - Round()((min_val / q_dict["scale"])) | |||
@@ -148,8 +150,9 @@ class ExponentialMovingAverageObserver(MinMaxObserver): | |||
eps=0.00001, | |||
dtype="qint8", | |||
narrow_range: bool = False, | |||
**kwargs | |||
): | |||
super().__init__(mode, eps, dtype, narrow_range) | |||
super().__init__(mode, eps, dtype, narrow_range, **kwargs) | |||
self.momentum = Tensor(momentum) | |||
self.runtime_momentum = Tensor(0.0) | |||
@@ -205,8 +208,9 @@ class HistogramObserver(MinMaxObserver): | |||
eps=0.00001, | |||
dtype="qint8", | |||
narrow_range: bool = False, | |||
**kwargs | |||
): | |||
super().__init__(mode, eps, dtype, narrow_range) | |||
super().__init__(mode, eps, dtype, narrow_range, **kwargs) | |||
self.bins = bins | |||
self.upsample_rate = upsample_rate | |||
self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1 | |||
@@ -417,7 +421,7 @@ class HistogramObserver(MinMaxObserver): | |||
# combine the existing histogram and new histogram into 1 histogram | |||
# We do this by first upsampling the histogram to a dense grid | |||
# and then downsampling the histogram efficiently | |||
(new_min, new_max, downsample_rate, start_idx,) = self._adjust_min_max( | |||
(new_min, new_max, downsample_rate, start_idx) = self._adjust_min_max( | |||
new_min, new_max, self.upsample_rate | |||
) | |||
@@ -442,3 +446,34 @@ class HistogramObserver(MinMaxObserver): | |||
def forward(self, x_orig): | |||
self.sideeffect_forward(x_orig) | |||
return x_orig | |||
class PassiveObserver(Observer): | |||
r""" | |||
An observer whose :attr:`scale` can be set directly. | |||
""" | |||
def __init__(self, q_dict, dtype: str, narrow_range: bool = False, **kwargs): | |||
super().__init__(dtype, narrow_range, **kwargs) | |||
self.q_dict = deepcopy(q_dict) | |||
if "scale" not in q_dict or q_dict["scale"] is None: | |||
raise AssertionError("Can not get an initialized scale") | |||
self.orig_scale = q_dict["scale"].numpy() | |||
@property | |||
def scale(self): | |||
return self.q_dict["scale"] | |||
@scale.setter | |||
def scale(self, value): | |||
assert value > 0 | |||
self.q_dict["scale"].set_value(value) | |||
def get_qparams(self): | |||
return self.q_dict | |||
def forward(self, x): | |||
r""" | |||
Just return the input, because :attr:`q_dict` is set by :func:`~.apply_easy_quant`. | |||
""" | |||
return x |
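A minimal usage sketch for the new ``PassiveObserver``, following the constructor signature above (the concrete scale values are arbitrary):

import megengine as mge
from megengine.quantization.observer import PassiveObserver

ob = PassiveObserver({"scale": mge.tensor(1.0)}, dtype="qint8")
ob.scale = 2.0              # the setter writes into q_dict["scale"] in place
print(ob.get_qparams())     # the same dict, now holding scale == 2.0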
@@ -13,6 +13,7 @@ from .observer import ( | |||
ExponentialMovingAverageObserver, | |||
HistogramObserver, | |||
MinMaxObserver, | |||
PassiveObserver, | |||
SyncExponentialMovingAverageObserver, | |||
SyncMinMaxObserver, | |||
) | |||
@@ -66,17 +67,22 @@ class QConfig: | |||
self.weight_fake_quant = weight_fake_quant | |||
self.act_fake_quant = act_fake_quant | |||
def __eq__(self, other): | |||
def eq(a, b): | |||
if isinstance(a, partial) and isinstance(b, partial): | |||
return all( | |||
[a.func == b.func, a.args == b.args, a.keywords == b.keywords] | |||
) | |||
else: | |||
return a == b | |||
return ( | |||
eq(self.weight_observer, other.weight_observer) | |||
and eq(self.act_observer, other.act_observer) | |||
and eq(self.weight_fake_quant, other.weight_fake_quant) | |||
and eq(self.act_fake_quant, other.act_fake_quant) | |||
) | |||
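The ``eq`` helper above is needed because ``functools.partial`` has no structural equality, so two identically-built partials compare unequal by default:

from functools import partial

f = partial(int, base=2)
g = partial(int, base=2)
assert f != g                   # plain == falls back to identity comparison
assert (f.func, f.args, f.keywords) == (g.func, g.args, g.keywords)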
tqt_quant_qconfig = QConfig( | |||
weight_observer=partial( | |||
ExponentialMovingAverageObserver, dtype="qint8", narrow_range=True | |||
), | |||
act_observer=partial( | |||
ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False | |||
), | |||
weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True), | |||
act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False), | |||
) | |||
min_max_fakequant_qconfig = QConfig( | |||
weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True), | |||
@@ -118,3 +124,17 @@ calibration_qconfig = QConfig( | |||
weight_fake_quant=None, | |||
act_fake_quant=None, | |||
) | |||
tqt_qconfig = QConfig( | |||
weight_observer=None, | |||
act_observer=None, | |||
weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True), | |||
act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False), | |||
) | |||
passive_qconfig = QConfig( | |||
weight_observer=partial(PassiveObserver, dtype="qint8", narrow_range=True), | |||
act_observer=partial(PassiveObserver, dtype="qint8", narrow_range=False), | |||
weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True), | |||
act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False), | |||
) |
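The same ``partial`` recipe extends to user-defined configurations. As a purely hypothetical example (not part of this diff), TQT on weights combined with ordinary fake quantization on activations could be assembled like this and then applied through ``reset_qconfig`` just like ``tqt_qconfig``:

from functools import partial

from megengine.quantization import QConfig
from megengine.quantization.fake_quant import TQT, FakeQuantize

mixed_qconfig = QConfig(
    weight_observer=None,
    act_observer=None,
    weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
    act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
)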
@@ -6,15 +6,18 @@ | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from copy import copy, deepcopy | |||
from functools import partial | |||
from typing import Callable, Dict, Tuple | |||
import numpy as np | |||
from .. import module as Float | |||
from ..functional import concat, norm | |||
from ..module import Module | |||
from ..module import qat as QAT | |||
from ..module import quantized as Quantized | |||
from ..module.qat import QATModule | |||
from ..module.quantized import QuantizedModule | |||
from .fake_quant import TQT | |||
from .qconfig import QConfig, ema_fakequant_qconfig | |||
@@ -32,9 +35,7 @@ def _get_quantable_module_names(): | |||
return quantable_module_names | |||
def _get_convert_dict() -> Tuple[ | |||
Dict[Module, QATModule], Dict[QATModule, QuantizedModule] | |||
]: | |||
def _get_convert_dict(): | |||
quantable_module_names = _get_quantable_module_names() | |||
quantable_modules = [getattr(Float, key) for key in quantable_module_names] | |||
@@ -47,6 +48,11 @@ def _get_convert_dict() -> Tuple[ | |||
_float2qat_dict, _qat2quantized_dict = _get_convert_dict() | |||
qat_modules = tuple(_qat2quantized_dict.keys()) | |||
def is_qat(mod: Module): | |||
return isinstance(mod, qat_modules) | |||
def quantize(module: Module, inplace: bool = True, mapping: dict = None): | |||
@@ -133,6 +139,34 @@ def quantize_qat( | |||
return module | |||
def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True): | |||
r""" | |||
Reset :class:`~._FakeQuantize` and :class:`~.Observer` according to ``qconfig``. | |||
:param module: root module to reset recursively. | |||
:param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig. | |||
:param inplace: whether to reset submodules in-place. | |||
""" | |||
if not inplace: | |||
module = deepcopy(module) | |||
def safe_call(func, q_dict): | |||
return func(q_dict=q_dict) if func is not None else None | |||
for m in list(module._flatten(predicate=is_qat)): | |||
if m.with_weight: | |||
weight_q_dict = m.get_weight_qparams() | |||
m.weight_observer = safe_call(qconfig.weight_observer, weight_q_dict) | |||
m.weight_fake_quant = safe_call(qconfig.weight_fake_quant, weight_q_dict) | |||
if m.with_act: | |||
act_q_dict = m.get_activation_qparams() | |||
m.act_observer = safe_call(qconfig.act_observer, act_q_dict) | |||
m.act_fake_quant = safe_call(qconfig.act_fake_quant, act_q_dict) | |||
return module | |||
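A rough end-to-end sketch of how ``reset_qconfig`` is meant to be used, under the assumption that a single observer-calibrating forward pass is enough to initialize the scales (the tiny ``Sequential`` model is only for illustration):

import numpy as np

import megengine as mge
from megengine import module as Float
from megengine.quantization import min_max_fakequant_qconfig, tqt_qconfig
from megengine.quantization.quantize import quantize_qat, reset_qconfig

float_net = Float.Sequential(Float.QuantStub(), Float.Linear(3, 3), Float.DequantStub())
qat_net = quantize_qat(float_net, qconfig=min_max_fakequant_qconfig)
qat_net(mge.tensor(np.random.rand(4, 3), dtype=np.float32))  # let the observers record min/max
tqt_net = reset_qconfig(qat_net, tqt_qconfig, inplace=False)  # swap in TQT fake-quant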
def _propagate(module: Module, func_str: str, *args, **kargs): | |||
def fn(mod: Module): | |||
if isinstance(mod, QATModule): | |||
@@ -151,6 +185,85 @@ def propagate_qconfig(module: QATModule, qconfig: QConfig): | |||
_propagate(module, "set_qconfig", qconfig) | |||
def hook_qat_module(module: Module, func: Callable): | |||
r""" | |||
Add hooks for all :class:`~.QATModule` submodules. | |||
""" | |||
hooks = [] | |||
for submodule in list(module._flatten(predicate=is_qat)): | |||
hooks.append(submodule.register_forward_hook(func)) | |||
return hooks | |||
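A forward hook registered through ``hook_qat_module`` receives ``(module, inputs, outputs)``, which is exactly how the ``search`` callback below is driven via ``functools.partial``; a trivial illustration (``qat_net`` and ``data`` are assumed to already exist):

def log_call(mod, inputs, outputs):
    # called once per QATModule forward
    print(type(mod).__name__)

hook_qat_module(qat_net, log_call)
qat_net(data)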
def apply_easy_quant(module, data, start=0.8, stop=1.2, num=40): | |||
r""" | |||
Implementation of ``EasyQuant``: https://arxiv.org/pdf/2006.16669. | |||
Search for optimal scales. | |||
:param module: root module. | |||
:param data: input tensor used to search optimal scale. | |||
:param start: lower bound of the search interval. | |||
:param stop: upper bound of the search interval. | |||
:param num: number of samples to search. | |||
""" | |||
batch_size = data.shape[0] | |||
def get_cosine(x, y): | |||
ndim = len(x.shape) | |||
axis = tuple(range(1, ndim)) | |||
up = (x * y).sum(axis=axis) | |||
down = norm(x, axis=axis) * norm(y, axis=axis) | |||
sim = up / down | |||
return sim.mean(axis=0) | |||
def search(mod, inputs, outputs, where): | |||
mod._forward_hooks.clear() | |||
fp32_in = [_[:batch_size] for _ in inputs] | |||
int8_in = [_[batch_size:] for _ in inputs] | |||
disable_fake_quant(mod) | |||
fp32_out = mod(*fp32_in) | |||
enable_fake_quant(mod) | |||
ob = getattr(mod, where) | |||
if ob is None: | |||
return | |||
orig_scale = ob.orig_scale | |||
distance = 0 | |||
best_scale = 0 | |||
for scale in np.linspace(start * orig_scale, stop * orig_scale, num): | |||
ob.scale = scale | |||
int8_out = mod(*int8_in) | |||
dis = get_cosine(fp32_out, int8_out) | |||
if dis > distance: | |||
distance = dis | |||
best_scale = scale | |||
ob.scale = best_scale | |||
if where == "act_observer": | |||
int8_out = mod(*int8_in) | |||
return concat([fp32_out, int8_out]) | |||
else: | |||
int8_out = outputs[batch_size:] | |||
return concat([fp32_out, int8_out]) | |||
data = concat([data, data]) | |||
hook_qat_module(module, partial(search, where="weight_observer")) | |||
module(data) | |||
hook_qat_module(module, partial(search, where="act_observer")) | |||
module(data) | |||
return module | |||
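Putting the pieces together, the intended EasyQuant flow matches the new test further below: reset a calibrated QAT network to ``passive_qconfig`` so every scale becomes directly settable, then let ``apply_easy_quant`` search around the original scales (``qat_net`` and ``calib_data`` are assumed to be an already-calibrated network and a batch of input data):

from megengine.quantization import passive_qconfig
from megengine.quantization.quantize import apply_easy_quant, reset_qconfig

eq_net = reset_qconfig(qat_net, passive_qconfig, inplace=False)
apply_easy_quant(eq_net, calib_data, start=0.8, stop=1.2, num=40)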
def disable_fake_quant(module: Module): | |||
r""" | |||
Recursively disable ``module`` fake quantization in QATModule through :meth:`~.Module.apply` | |||
@@ -54,17 +54,15 @@ class QuantMode(Enum): | |||
SYMMERTIC = 1 | |||
ASYMMERTIC = 2 | |||
TQT = 3 | |||
qparam_dict = { | |||
QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None,}, | |||
QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None}, | |||
QuantMode.ASYMMERTIC: { | |||
"mode": QuantMode.ASYMMERTIC, | |||
"scale": None, | |||
"zero_point": None, | |||
}, | |||
QuantMode.TQT: {"mode": QuantMode.TQT, "scale": None,}, | |||
} | |||
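``get_qparam_dict`` (used by the observers and TQT above) presumably hands out a private copy of one of these templates so callers can fill in ``scale``/``zero_point`` without mutating the table; a hedged sketch of that assumed behavior:

from copy import deepcopy

def get_qparam_dict_sketch(mode: QuantMode) -> dict:
    # assumed behavior: return a fresh copy of the matching template
    return deepcopy(qparam_dict[mode])

q = get_qparam_dict_sketch(QuantMode.ASYMMERTIC)
q["scale"], q["zero_point"] = 0.01, 3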
@@ -1,116 +0,0 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import numpy as np | |||
import pytest | |||
from megengine import module as Float | |||
from megengine import tensor | |||
from megengine.module import qat as QAT | |||
from megengine.quantization import min_max_fakequant_qconfig | |||
from megengine.quantization.quantize import ( | |||
_get_quantable_module_names, | |||
disable_fake_quant, | |||
quantize_qat, | |||
) | |||
def test_get_quantable_module_names(): | |||
# need to make sure names from Quantized and QAT are the same | |||
def _get_qat_module_names(): | |||
def is_qat(key: str): | |||
value = getattr(QAT, key) | |||
return ( | |||
isinstance(value, type) | |||
and issubclass(value, QAT.QATModule) | |||
and value != QAT.QATModule | |||
) | |||
# source should have all quantable modules' names | |||
quantable_module_names = [key for key in dir(QAT) if is_qat(key)] | |||
return quantable_module_names | |||
qat_module_names = _get_qat_module_names() | |||
quantized_module_names = _get_quantable_module_names() | |||
assert set(qat_module_names) == set(quantized_module_names) | |||
for key in qat_module_names: | |||
value = getattr(Float, key) | |||
assert ( | |||
isinstance(value, type) | |||
and issubclass(value, Float.Module) | |||
and value != Float.Module | |||
) | |||
def test_disable_quantize(): | |||
class Net(Float.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.conv = Float.ConvBnRelu2d(3, 3, 3) | |||
self.conv.disable_quantize() | |||
def forward(self, x): | |||
return self.conv(x) | |||
net = Net() | |||
qat_net = quantize_qat(net, inplace=False) | |||
assert isinstance(qat_net.conv, Float.ConvBnRelu2d) | |||
assert isinstance(qat_net.conv.conv, Float.Conv2d) | |||
def test_convert_with_custom_mapping(): | |||
class FloatExample(Float.Module): | |||
def forward(self, x): | |||
return x | |||
class QATExample(QAT.QATModule): | |||
def forward(self, x): | |||
return x | |||
@classmethod | |||
def from_float_module(cls, float_module): | |||
return cls() | |||
class Net(Float.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.example = FloatExample() | |||
def forward(self, x): | |||
return self.example(x) | |||
net = Net() | |||
qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample}) | |||
assert isinstance(qat_net.example, QATExample) | |||
def test_disable_fake_quant(): | |||
class Net(Float.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.quant = Float.QuantStub() | |||
self.linear = Float.Linear(3, 3) | |||
self.dequant = Float.DequantStub() | |||
self.linear.bias.set_value(np.random.rand(3)) | |||
def forward(self, x): | |||
x = self.quant(x) | |||
x = self.linear(x) | |||
x = self.dequant(x) | |||
return x | |||
x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32)) | |||
net = Net() | |||
y1 = net(x).numpy() | |||
net = quantize_qat(net, min_max_fakequant_qconfig) | |||
y2 = net(x).numpy() | |||
disable_fake_quant(net) | |||
y3 = net(x).numpy() | |||
np.testing.assert_allclose(y1, y3) | |||
with pytest.raises(AssertionError): | |||
np.testing.assert_allclose(y2, y3) |
@@ -6,17 +6,53 @@ import pytest | |||
import megengine as mge | |||
import megengine.distributed as dist | |||
import megengine.quantization.observer as ob | |||
from megengine.distributed.helper import get_device_count_by_fork | |||
from megengine.quantization.observer import ( | |||
ExponentialMovingAverageObserver, | |||
MinMaxObserver, | |||
Observer, | |||
PassiveObserver, | |||
SyncExponentialMovingAverageObserver, | |||
SyncMinMaxObserver, | |||
) | |||
def test_observer(): | |||
with pytest.raises(TypeError): | |||
Observer("qint8") | |||
def test_min_max_observer(): | |||
x = np.random.rand(3, 3, 3, 3).astype("float32") | |||
np_min, np_max = x.min(), x.max() | |||
x = mge.tensor(x) | |||
m = ob.MinMaxObserver() | |||
m = MinMaxObserver() | |||
m(x) | |||
assert m.min_val == np_min and m.max_val == np_max | |||
np.testing.assert_allclose(m.min_val.numpy(), np_min) | |||
np.testing.assert_allclose(m.max_val.numpy(), np_max) | |||
def test_exponential_moving_average_observer(): | |||
t = np.random.rand() | |||
x1 = np.random.rand(3, 3, 3, 3).astype("float32") | |||
x2 = np.random.rand(3, 3, 3, 3).astype("float32") | |||
expected_min = x1.min() * t + x2.min() * (1 - t) | |||
expected_max = x1.max() * t + x2.max() * (1 - t) | |||
m = ExponentialMovingAverageObserver(momentum=t) | |||
m(mge.tensor(x1, dtype=np.float32)) | |||
m(mge.tensor(x2, dtype=np.float32)) | |||
np.testing.assert_allclose(m.min_val.numpy(), expected_min) | |||
np.testing.assert_allclose(m.max_val.numpy(), expected_max) | |||
def test_passive_observer(): | |||
q_dict = {"scale": mge.tensor(1.0)} | |||
m = PassiveObserver(q_dict, "qint8") | |||
assert m.orig_scale == 1.0 | |||
assert m.scale == 1.0 | |||
m.scale = 2.0 | |||
assert m.scale == 2.0 | |||
assert m.get_qparams() == {"scale": mge.tensor(2.0)} | |||
@pytest.mark.skipif( | |||
@@ -35,9 +71,39 @@ def test_sync_min_max_observer(): | |||
@dist.launcher | |||
def worker(): | |||
rank = dist.get_rank() | |||
m = ob.SyncMinMaxObserver() | |||
m = SyncMinMaxObserver() | |||
y = mge.tensor(x[rank * 3 : (rank + 1) * 3]) | |||
m(y) | |||
assert m.min_val == np_min and m.max_val == np_max | |||
worker() | |||
@pytest.mark.skipif( | |||
platform.system() == "Darwin", reason="do not imp GPU mode at macos now" | |||
) | |||
@pytest.mark.skipif( | |||
platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM" | |||
) | |||
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device") | |||
@pytest.mark.isolated_distributed | |||
def test_sync_exponential_moving_average_observer(): | |||
world_size = get_device_count_by_fork("gpu") | |||
t = np.random.rand() | |||
x1 = np.random.rand(3 * world_size, 3, 3, 3).astype("float32") | |||
x2 = np.random.rand(3 * world_size, 3, 3, 3).astype("float32") | |||
expected_min = x1.min() * t + x2.min() * (1 - t) | |||
expected_max = x1.max() * t + x2.max() * (1 - t) | |||
@dist.launcher | |||
def worker(): | |||
rank = dist.get_rank() | |||
m = SyncExponentialMovingAverageObserver(momentum=t) | |||
y1 = mge.tensor(x1[rank * 3 : (rank + 1) * 3]) | |||
y2 = mge.tensor(x2[rank * 3 : (rank + 1) * 3]) | |||
m(y1) | |||
m(y2) | |||
np.testing.assert_allclose(m.min_val.numpy(), expected_min) | |||
np.testing.assert_allclose(m.max_val.numpy(), expected_max) | |||
worker() |
@@ -0,0 +1,14 @@ | |||
from functools import partial | |||
from megengine.quantization import QConfig, tqt_qconfig | |||
from megengine.quantization.fake_quant import TQT | |||
def test_equal(): | |||
qconfig = QConfig( | |||
weight_observer=None, | |||
act_observer=None, | |||
weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True), | |||
act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False), | |||
) | |||
assert qconfig == tqt_qconfig |
@@ -0,0 +1,266 @@ | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
import numpy as np | |||
import pytest | |||
from megengine import functional | |||
from megengine import module as Float | |||
from megengine import tensor | |||
from megengine.module import qat as QAT | |||
from megengine.module import quantized as Q | |||
from megengine.quantization import ( | |||
min_max_fakequant_qconfig, | |||
passive_qconfig, | |||
tqt_qconfig, | |||
) | |||
from megengine.quantization.fake_quant import TQT, FakeQuantize | |||
from megengine.quantization.observer import MinMaxObserver, PassiveObserver | |||
from megengine.quantization.quantize import ( | |||
_get_quantable_module_names, | |||
apply_easy_quant, | |||
disable_fake_quant, | |||
disable_observer, | |||
enable_fake_quant, | |||
enable_observer, | |||
propagate_qconfig, | |||
quantize, | |||
quantize_qat, | |||
reset_qconfig, | |||
) | |||
class Net(Float.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.quant = Float.QuantStub() | |||
self.linear = Float.Linear(3, 3) | |||
self.dequant = Float.DequantStub() | |||
self.linear.bias.set_value(np.random.rand(3)) | |||
def forward(self, x): | |||
x = self.quant(x) | |||
x = self.linear(x) | |||
x = self.dequant(x) | |||
return x | |||
class QATNet(Float.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.quant = QAT.QuantStub() | |||
self.linear = QAT.Linear(3, 3) | |||
self.dequant = QAT.DequantStub() | |||
self.linear.bias.set_value(np.random.rand(3)) | |||
def forward(self, x): | |||
x = self.quant(x) | |||
x = self.linear(x) | |||
x = self.dequant(x) | |||
return x | |||
def test_propagate_qconfig(): | |||
net = QATNet() | |||
propagate_qconfig(net, min_max_fakequant_qconfig) | |||
assert all( | |||
[ | |||
net.quant.weight_observer is None, | |||
net.quant.weight_fake_quant is None, | |||
isinstance(net.quant.act_observer, MinMaxObserver), | |||
isinstance(net.quant.act_fake_quant, FakeQuantize), | |||
isinstance(net.linear.weight_observer, MinMaxObserver), | |||
isinstance(net.linear.weight_fake_quant, FakeQuantize), | |||
isinstance(net.linear.act_observer, MinMaxObserver), | |||
isinstance(net.linear.act_fake_quant, FakeQuantize), | |||
net.dequant.weight_observer is None, | |||
net.dequant.weight_fake_quant is None, | |||
net.dequant.act_observer is None, | |||
net.dequant.act_fake_quant is None, | |||
] | |||
) | |||
def init_qat_net(): | |||
net = QATNet() | |||
propagate_qconfig(net, min_max_fakequant_qconfig) | |||
min_val = np.random.randint(-127, 0, size=(2,)) | |||
max_val = np.random.randint(1, 127, size=(2,)) | |||
net.linear.weight_observer.min_val.set_value(min_val[0]) | |||
net.linear.weight_observer.max_val.set_value(max_val[0]) | |||
net.linear.act_observer.min_val.set_value(min_val[1]) | |||
net.linear.act_observer.max_val.set_value(max_val[1]) | |||
return net | |||
def test_reset_qconfig(): | |||
qat_net = init_qat_net() | |||
new_qat_net = reset_qconfig(qat_net, passive_qconfig) | |||
assert ( | |||
new_qat_net.linear.get_weight_qparams() == qat_net.linear.get_weight_qparams() | |||
) | |||
assert ( | |||
new_qat_net.linear.get_activation_qparams() | |||
== qat_net.linear.get_activation_qparams() | |||
) | |||
def test_enable_and_disable_observer(): | |||
net = init_qat_net() | |||
enable_observer(net) | |||
assert net.quant.act_observer.enabled == True | |||
assert net.linear.weight_observer.enabled == True | |||
assert net.linear.act_observer.enabled == True | |||
disable_observer(net) | |||
assert net.quant.act_observer.enabled == False | |||
assert net.linear.weight_observer.enabled == False | |||
assert net.linear.act_observer.enabled == False | |||
def test_enable_and_disable_fake_quant(): | |||
net = init_qat_net() | |||
disable_fake_quant(net) | |||
assert net.quant.act_fake_quant.enabled == False | |||
assert net.linear.weight_fake_quant.enabled == False | |||
assert net.linear.act_fake_quant.enabled == False | |||
enable_fake_quant(net) | |||
assert net.quant.act_fake_quant.enabled == True | |||
assert net.linear.weight_fake_quant.enabled == True | |||
assert net.linear.act_fake_quant.enabled == True | |||
def init_observer(module, data): | |||
enable_observer(module) | |||
disable_fake_quant(module) | |||
module(data) | |||
disable_observer(module) | |||
enable_fake_quant(module) | |||
def test_enable_and_disable_all(): | |||
x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32)) | |||
net = Net() | |||
y1 = net(x).numpy() | |||
net = quantize_qat(net, min_max_fakequant_qconfig) | |||
init_observer(net, x) | |||
y2 = net(x).numpy() | |||
disable_fake_quant(net) | |||
y3 = net(x).numpy() | |||
enable_fake_quant(net) | |||
y4 = net(x).numpy() | |||
np.testing.assert_allclose(y1, y3) | |||
np.testing.assert_allclose(y2, y4) | |||
with pytest.raises(AssertionError): | |||
np.testing.assert_allclose(y2, y3) | |||
def test_quantize_qat(): | |||
net = Net() | |||
qat_net = quantize_qat(net, inplace=False, qconfig=min_max_fakequant_qconfig) | |||
assert isinstance(qat_net.quant, QAT.QuantStub) | |||
assert isinstance(qat_net.linear, QAT.Linear) | |||
assert isinstance(qat_net.dequant, QAT.DequantStub) | |||
def test_quantize(): | |||
qat_net = init_qat_net() | |||
q_net = quantize(qat_net, inplace=False) | |||
assert isinstance(q_net.quant, Q.QuantStub) | |||
assert isinstance(q_net.linear, Q.Linear) | |||
assert isinstance(q_net.dequant, Q.DequantStub) | |||
def test_apply_easy_quant(): | |||
qat_net = init_qat_net() | |||
data = tensor(np.random.rand(2, 3, 3, 3), dtype=np.float32) | |||
eq_net = reset_qconfig(qat_net, passive_qconfig, inplace=False) | |||
apply_easy_quant(eq_net, data, 0.9, 1.1, 10) | |||
assert isinstance(eq_net.quant.act_observer, PassiveObserver) | |||
assert isinstance(eq_net.linear.weight_observer, PassiveObserver) | |||
assert isinstance(eq_net.linear.act_observer, PassiveObserver) | |||
assert eq_net.dequant.act_observer is None | |||
def test_apply_tqt(): | |||
qat_net = init_qat_net() | |||
tqt_net = reset_qconfig(qat_net, tqt_qconfig, inplace=False) | |||
assert isinstance(tqt_net.quant.act_fake_quant, TQT) | |||
assert isinstance(tqt_net.linear.weight_fake_quant, TQT) | |||
assert isinstance(tqt_net.linear.act_fake_quant, TQT) | |||
assert tqt_net.dequant.act_fake_quant is None | |||
def test_get_quantable_module_names(): | |||
# need to make sure names from Quantized and QAT are the same | |||
def _get_qat_module_names(): | |||
def is_qat(key: str): | |||
value = getattr(QAT, key) | |||
return ( | |||
isinstance(value, type) | |||
and issubclass(value, QAT.QATModule) | |||
and value != QAT.QATModule | |||
) | |||
# source should have all quantable modules' names | |||
quantable_module_names = [key for key in dir(QAT) if is_qat(key)] | |||
return quantable_module_names | |||
qat_module_names = _get_qat_module_names() | |||
quantized_module_names = _get_quantable_module_names() | |||
assert set(qat_module_names) == set(quantized_module_names) | |||
for key in qat_module_names: | |||
value = getattr(Float, key) | |||
assert ( | |||
isinstance(value, type) | |||
and issubclass(value, Float.Module) | |||
and value != Float.Module | |||
) | |||
def test_disable_quantize(): | |||
class Net(Float.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.conv = Float.ConvBnRelu2d(3, 3, 3) | |||
self.conv.disable_quantize() | |||
def forward(self, x): | |||
return self.conv(x) | |||
net = Net() | |||
qat_net = quantize_qat(net, inplace=False) | |||
assert isinstance(qat_net.conv, Float.ConvBnRelu2d) | |||
assert isinstance(qat_net.conv.conv, Float.Conv2d) | |||
def test_convert_with_custom_mapping(): | |||
class FloatExample(Float.Module): | |||
def forward(self, x): | |||
return x | |||
class QATExample(QAT.QATModule): | |||
def forward(self, x): | |||
return x | |||
@classmethod | |||
def from_float_module(cls, float_module): | |||
return cls() | |||
class Net(Float.Module): | |||
def __init__(self): | |||
super().__init__() | |||
self.example = FloatExample() | |||
def forward(self, x): | |||
return self.example(x) | |||
net = Net() | |||
qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample}) | |||
assert isinstance(qat_net.example, QATExample) |