
feat(mge/quantization): add support for easyquant

GitOrigin-RevId: 060d908349
Branch: release-1.2
Author: Megvii Engine Team, 4 years ago
Parent commit: ab9f44f15c
11 changed files with 561 additions and 160 deletions
 1. +1   -3    imperative/python/megengine/module/quantized/linear.py
 2. +2   -1    imperative/python/megengine/quantization/__init__.py
 3. +19  -13   imperative/python/megengine/quantization/fake_quant.py
 4. +41  -6    imperative/python/megengine/quantization/observer.py
 5. +30  -10   imperative/python/megengine/quantization/qconfig.py
 6. +117 -4    imperative/python/megengine/quantization/quantize.py
 7. +1   -3    imperative/python/megengine/quantization/utils.py
 8. +0   -116  imperative/python/test/unit/quantization/quantize.py
 9. +70  -4    imperative/python/test/unit/quantization/test_observer.py
10. +14  -0    imperative/python/test/unit/quantization/test_qconfig.py
11. +266 -0    imperative/python/test/unit/quantization/test_quantize.py

+1 -3  imperative/python/megengine/module/quantized/linear.py

@@ -17,9 +17,7 @@ from .module import QuantizedModule
 class Linear(QuantizedModule):
     r"""Quantized version of :class:`~.qat.linear.Linear`."""

-    def __init__(
-        self, dtype: np.dtype = None,
-    ):
+    def __init__(self, dtype: np.dtype = None):
         super().__init__()
         self.weight = None
         self.bias = None


+2 -1  imperative/python/megengine/quantization/__init__.py

@@ -15,7 +15,8 @@ from .qconfig import (
     ema_fakequant_qconfig,
     ema_lowbit_fakequant_qconfig,
     min_max_fakequant_qconfig,
+    passive_qconfig,
     sync_ema_fakequant_qconfig,
-    tqt_quant_qconfig,
+    tqt_qconfig,
 )
 from .utils import QuantMode

+19 -13  imperative/python/megengine/quantization/fake_quant.py

@@ -28,7 +28,9 @@ class _FakeQuantize(Module):
     :param enable: whether do ``normal_forward`` or ``fake_quant_forward``.
     """

-    def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True):
+    def __init__(
+        self, dtype: str, narrow_range: bool = False, enable: bool = True, **kwargs
+    ):
         super().__init__()
         if not dtype in _metadata_dict.keys():
             raise ValueError(

@@ -114,24 +116,28 @@ class TQT(_FakeQuantize):
     for Accurate and Efficient Fixed-Point Inference of Deep Neural Networks.
     """

-    def __init__(self, dtype: str, narrow_range: bool = False, enable: bool = True):
-        super().__init__(dtype, narrow_range, enable)
-        self.scale = Parameter([0.0], dtype=np.float32)
+    def __init__(
+        self,
+        q_dict,
+        dtype: str,
+        narrow_range: bool = False,
+        enable: bool = True,
+        **kwargs
+    ):
+        super().__init__(dtype, narrow_range, enable, **kwargs)
+        assert (
+            q_dict["mode"] == QuantMode.SYMMERTIC
+        ), "only symmetric quantization is supported by TQT"
+        if "scale" not in q_dict or q_dict["scale"] is None:
+            raise AssertionError("Can not get an initialized scale")
+        self.scale = F.log(q_dict["scale"]) / math.log(2)

     def fake_quant_forward(self, inp, q_dict=None):
         # when enable, TQT will do fakequant forward, finetune the scale
         return TQT_Function(self.qmin, self.qmax)(inp, self.scale)

-    def normal_foward(self, inp, q_dict=None):
-        if q_dict["enable_observer"]:
-            # when disable, TQT will do normal forward, initialize scale weight
-            tmp_scale = F.maximum(F.abs(q_dict["min_val"]), F.abs(q_dict["max_val"]))
-            tmp_scale = F.log(tmp_scale / 127) / math.log(2)
-            self.scale[...] = tmp_scale
-        return inp
-
     def get_qparams(self):
-        q_dict = get_qparam_dict(QuantMode.TQT)
+        q_dict = get_qparam_dict(QuantMode.SYMMERTIC)
         q_dict["scale"] = 2 ** self.scale
         return q_dict
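
For orientation, a minimal sketch of how the reworked ``TQT`` constructor is meant to be driven: the scale is now seeded from an observer's qparams rather than a zero-initialized ``Parameter``. The observer choice and input shape below are illustrative assumptions, not part of the commit:

    import numpy as np
    import megengine as mge
    from megengine.quantization.fake_quant import TQT
    from megengine.quantization.observer import MinMaxObserver

    # assumption: MinMaxObserver defaults to symmetric mode, which TQT asserts
    ob = MinMaxObserver(dtype="qint8", narrow_range=False)
    ob(mge.tensor(np.random.rand(4, 16), dtype=np.float32))  # collect min/max statistics

    # the q_dict carries {"mode": QuantMode.SYMMERTIC, "scale": ...}, satisfying
    # both the symmetric-mode assert and the initialized-scale check
    tqt = TQT(ob.get_qparams(), dtype="qint8", narrow_range=False)
    print(tqt.get_qparams()["scale"])  # 2 ** scale, recovered from the trained log2 parameter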



+41 -6  imperative/python/megengine/quantization/observer.py

@@ -7,6 +7,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import math
 from abc import abstractmethod
+from copy import deepcopy

 import numpy as np

@@ -28,7 +29,7 @@ class Observer(Module):
         instead of 1 greater. Usually True for weight and False for activation.
     """

-    def __init__(self, dtype: str, narrow_range: bool = False):
+    def __init__(self, dtype: str, narrow_range: bool = False, **kwargs):
         super().__init__()
         if dtype not in _metadata_dict.keys():
             raise ValueError(

@@ -81,8 +82,9 @@ class MinMaxObserver(Observer):
         eps=0.00001,
         dtype="qint8",
         narrow_range: bool = False,
+        **kwargs
     ):
-        super().__init__(dtype, narrow_range)
+        super().__init__(dtype, narrow_range, **kwargs)
         self.mode = mode
         self.min_val = Tensor(np.finfo(np.float32).max, dtype=np.float32)
         self.max_val = Tensor(np.finfo(np.float32).min, dtype=np.float32)

@@ -105,7 +107,7 @@ class MinMaxObserver(Observer):
         else:
             # use maximum to avoid the scale being too small at the beginning
             q_dict["scale"] = F.maximum(
-                (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit,
+                (max_val - min_val) / (self.qmax - self.qmin), self.scale_limit
             )
         # calculate zero_point
         q_dict["zero_point"] = self.qmin - Round()((min_val / q_dict["scale"]))

@@ -148,8 +150,9 @@ class ExponentialMovingAverageObserver(MinMaxObserver):
         eps=0.00001,
         dtype="qint8",
         narrow_range: bool = False,
+        **kwargs
     ):
-        super().__init__(mode, eps, dtype, narrow_range)
+        super().__init__(mode, eps, dtype, narrow_range, **kwargs)
         self.momentum = Tensor(momentum)
         self.runtime_momentum = Tensor(0.0)

@@ -205,8 +208,9 @@ class HistogramObserver(MinMaxObserver):
         eps=0.00001,
         dtype="qint8",
         narrow_range: bool = False,
+        **kwargs
     ):
-        super().__init__(mode, eps, dtype, narrow_range)
+        super().__init__(mode, eps, dtype, narrow_range, **kwargs)
         self.bins = bins
         self.upsample_rate = upsample_rate
         self.dst_nbins = _metadata_dict[dtype].qmax - _metadata_dict[dtype].qmin + 1

@@ -417,7 +421,7 @@ class HistogramObserver(MinMaxObserver):
         # combine the existing histogram and new histogram into 1 histogram
         # We do this by first upsampling the histogram to a dense grid
         # and then downsampling the histogram efficiently
-        (new_min, new_max, downsample_rate, start_idx,) = self._adjust_min_max(
+        (new_min, new_max, downsample_rate, start_idx) = self._adjust_min_max(
             new_min, new_max, self.upsample_rate
         )

@@ -442,3 +446,34 @@ class HistogramObserver(MinMaxObserver):
     def forward(self, x_orig):
         self.sideeffect_forward(x_orig)
         return x_orig
+
+
+class PassiveObserver(Observer):
+    r"""
+    An observer that allows :attr:`scale` to be set directly.
+    """
+
+    def __init__(self, q_dict, dtype: str, narrow_range: bool = False, **kwargs):
+        super().__init__(dtype, narrow_range, **kwargs)
+        self.q_dict = deepcopy(q_dict)
+        if "scale" not in q_dict or q_dict["scale"] is None:
+            raise AssertionError("Can not get an initialized scale")
+        self.orig_scale = q_dict["scale"].numpy()
+
+    @property
+    def scale(self):
+        return self.q_dict["scale"]
+
+    @scale.setter
+    def scale(self, value):
+        assert value > 0
+        self.q_dict["scale"].set_value(value)
+
+    def get_qparams(self):
+        return self.q_dict
+
+    def forward(self, x):
+        r"""
+        Just return input because :attr:`q_dict` is set by :func:`~.apply_easy_quant`.
+        """
+        return x
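
A quick sketch of the ``PassiveObserver`` contract added above (values are illustrative): the observer deep-copies ``q_dict``, snapshots the original scale as a numpy value, and the ``scale`` setter writes through with ``set_value``:

    import megengine as mge
    from megengine.quantization.observer import PassiveObserver

    ob = PassiveObserver({"scale": mge.tensor(1.0)}, "qint8")
    assert ob.orig_scale == 1.0  # numpy snapshot of the initial scale
    ob.scale = 2.0               # setter updates the copied q_dict["scale"] in place
    assert ob.get_qparams()["scale"] == 2.0
    # forward() is the identity; scales are searched externally by apply_easy_quant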

+30 -10  imperative/python/megengine/quantization/qconfig.py

@@ -13,6 +13,7 @@ from .observer import (
     ExponentialMovingAverageObserver,
     HistogramObserver,
     MinMaxObserver,
+    PassiveObserver,
     SyncExponentialMovingAverageObserver,
     SyncMinMaxObserver,
 )

@@ -66,17 +67,22 @@
         self.weight_fake_quant = weight_fake_quant
         self.act_fake_quant = act_fake_quant

+    def __eq__(self, other):
+        def eq(a, b):
+            if isinstance(a, partial) and isinstance(b, partial):
+                return all(
+                    [a.func == b.func, a.args == b.args, a.keywords == b.keywords]
+                )
+            else:
+                return a == b
+
+        return (
+            eq(self.weight_observer, other.weight_observer)
+            and eq(self.act_observer, other.act_observer)
+            and eq(self.weight_fake_quant, other.weight_fake_quant)
+            and eq(self.act_fake_quant, other.act_fake_quant)
+        )

-tqt_quant_qconfig = QConfig(
-    weight_observer=partial(
-        ExponentialMovingAverageObserver, dtype="qint8", narrow_range=True
-    ),
-    act_observer=partial(
-        ExponentialMovingAverageObserver, dtype="qint8", narrow_range=False
-    ),
-    weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
-    act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False),
-)

 min_max_fakequant_qconfig = QConfig(
     weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),

@@ -118,3 +124,17 @@ calibration_qconfig = QConfig(
     weight_fake_quant=None,
     act_fake_quant=None,
 )
+
+tqt_qconfig = QConfig(
+    weight_observer=None,
+    act_observer=None,
+    weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
+    act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False),
+)
+
+passive_qconfig = QConfig(
+    weight_observer=partial(PassiveObserver, dtype="qint8", narrow_range=True),
+    act_observer=partial(PassiveObserver, dtype="qint8", narrow_range=False),
+    weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
+    act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
+)
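
The custom ``__eq__`` exists because two separately constructed ``functools.partial`` objects never compare equal by default, even with identical contents; comparing ``func``, ``args`` and ``keywords`` field-wise restores the intended semantics. A small illustration (not from the commit):

    from functools import partial
    from megengine.quantization.fake_quant import TQT

    a = partial(TQT, dtype="qint8", narrow_range=True)
    b = partial(TQT, dtype="qint8", narrow_range=True)
    assert a != b  # partial does not define __eq__, so identity-style comparison applies
    assert (a.func, a.args, a.keywords) == (b.func, b.args, b.keywords)  # field-wise match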

+117 -4  imperative/python/megengine/quantization/quantize.py

@@ -6,15 +6,18 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from copy import copy, deepcopy
+from functools import partial
 from typing import Callable, Dict, Tuple

+import numpy as np
+
 from .. import module as Float
+from ..functional import concat, norm
 from ..module import Module
 from ..module import qat as QAT
 from ..module import quantized as Quantized
 from ..module.qat import QATModule
 from ..module.quantized import QuantizedModule
+from .fake_quant import TQT
 from .qconfig import QConfig, ema_fakequant_qconfig

@@ -32,9 +35,7 @@ def _get_quantable_module_names():
     return quantable_module_names


-def _get_convert_dict() -> Tuple[
-    Dict[Module, QATModule], Dict[QATModule, QuantizedModule]
-]:
+def _get_convert_dict():
     quantable_module_names = _get_quantable_module_names()

     quantable_modules = [getattr(Float, key) for key in quantable_module_names]

@@ -47,6 +48,11 @@ def _get_convert_dict() -> Tuple[


 _float2qat_dict, _qat2quantized_dict = _get_convert_dict()
+qat_modules = tuple(_qat2quantized_dict.keys())
+
+
+def is_qat(mod: Module):
+    return isinstance(mod, qat_modules)


 def quantize(module: Module, inplace: bool = True, mapping: dict = None):

@@ -133,6 +139,34 @@ def quantize_qat(
     return module


+def reset_qconfig(module: Module, qconfig: QConfig, inplace: bool = True):
+    r"""
+    Reset :class:`~._FakeQuantize` and :class:`~.Observer` according to ``qconfig``.
+
+    :param module: root module to reset recursively.
+    :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
+    :param inplace: whether to reset submodules in-place.
+    """
+    if not inplace:
+        module = deepcopy(module)
+
+    def safe_call(func, q_dict):
+        return func(q_dict=q_dict) if func is not None else None
+
+    for m in list(module._flatten(predicate=is_qat)):
+        if m.with_weight:
+            weight_q_dict = m.get_weight_qparams()
+            m.weight_observer = safe_call(qconfig.weight_observer, weight_q_dict)
+            m.weight_fake_quant = safe_call(qconfig.weight_fake_quant, weight_q_dict)
+        if m.with_act:
+            act_q_dict = m.get_activation_qparams()
+            m.act_observer = safe_call(qconfig.act_observer, act_q_dict)
+            m.act_fake_quant = safe_call(qconfig.act_fake_quant, act_q_dict)
+
+    return module
+
+
 def _propagate(module: Module, func_str: str, *args, **kargs):
     def fn(mod: Module):
         if isinstance(mod, QATModule):

@@ -151,6 +185,85 @@ def propagate_qconfig(module: QATModule, qconfig: QConfig):
     _propagate(module, "set_qconfig", qconfig)


+def hook_qat_module(module: Module, func: Callable):
+    r"""
+    Add hooks for all :class:`~.QATModule` submodules.
+    """
+    hooks = []
+    for submodule in list(module._flatten(predicate=is_qat)):
+        hooks.append(submodule.register_forward_hook(func))
+
+    return hooks
+
+
+def apply_easy_quant(module, data, start=0.8, stop=1.2, num=40):
+    r"""
+    Implementation of ``EasyQuant``: https://arxiv.org/pdf/2006.16669.
+    Search for optimal scales.
+
+    :param module: root module.
+    :param data: input tensor used to search optimal scales.
+    :param start: lower bound of the search interval.
+    :param stop: upper bound of the search interval.
+    :param num: number of samples to search.
+    """
+    batch_size = data.shape[0]
+
+    def get_cosine(x, y):
+        ndim = len(x.shape)
+        axis = tuple(range(1, ndim))
+        up = (x * y).sum(axis=axis)
+        down = norm(x, axis=axis) * norm(y, axis=axis)
+        sim = up / down
+        return sim.mean(axis=0)
+
+    def search(mod, inputs, outputs, where):
+
+        mod._forward_hooks.clear()
+
+        fp32_in = [_[:batch_size] for _ in inputs]
+        int8_in = [_[batch_size:] for _ in inputs]
+
+        disable_fake_quant(mod)
+        fp32_out = mod(*fp32_in)
+        enable_fake_quant(mod)
+
+        ob = getattr(mod, where)
+        if ob is None:
+            return
+
+        orig_scale = ob.orig_scale
+        distance = 0
+        best_scale = 0
+        for scale in np.linspace(start * orig_scale, stop * orig_scale, num):
+            ob.scale = scale
+            int8_out = mod(*int8_in)
+            dis = get_cosine(fp32_out, int8_out)
+            if dis > distance:
+                distance = dis
+                best_scale = scale
+        ob.scale = best_scale
+
+        if where == "act_observer":
+            int8_out = mod(*int8_in)
+            return concat([fp32_out, int8_out])
+        else:
+            int8_out = outputs[batch_size:]
+            return concat([fp32_out, int8_out])
+
+    data = concat([data, data])
+
+    hook_qat_module(module, partial(search, where="weight_observer"))
+    module(data)
+
+    hook_qat_module(module, partial(search, where="act_observer"))
+    module(data)
+
+    return module


 def disable_fake_quant(module: Module):
     r"""
     Recursively disable ``module`` fake quantization in QATModule through :meth:`~.Module.apply`
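
Taken together, the intended EasyQuant flow is: convert a float network to QAT and warm up its observers, swap the observers for ``PassiveObserver`` via ``reset_qconfig`` with ``passive_qconfig`` (each seeded from the existing qparams), then let ``apply_easy_quant`` grid-search each scale. A hedged end-to-end sketch; ``net`` and the calibration batch are placeholders, not part of the commit:

    import numpy as np
    import megengine as mge
    from megengine.quantization import passive_qconfig
    from megengine.quantization.quantize import (
        apply_easy_quant,
        quantize_qat,
        reset_qconfig,
    )

    qat_net = quantize_qat(net)  # assumes `net` exists and observers were warmed up
    eq_net = reset_qconfig(qat_net, passive_qconfig, inplace=False)
    data = mge.tensor(np.random.rand(8, 3, 32, 32), dtype=np.float32)
    # doubles the batch so fp32 and int8 halves flow through the same hooks, then
    # greedily maximizes cosine similarity over scale * linspace(0.8, 1.2, 40),
    # first for weight observers, then for activation observers
    apply_easy_quant(eq_net, data, start=0.8, stop=1.2, num=40)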


+1 -3  imperative/python/megengine/quantization/utils.py

@@ -54,17 +54,15 @@ class QuantMode(Enum):

     SYMMERTIC = 1
     ASYMMERTIC = 2
-    TQT = 3


 qparam_dict = {
-    QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None,},
+    QuantMode.SYMMERTIC: {"mode": QuantMode.SYMMERTIC, "scale": None},
     QuantMode.ASYMMERTIC: {
         "mode": QuantMode.ASYMMERTIC,
         "scale": None,
         "zero_point": None,
     },
-    QuantMode.TQT: {"mode": QuantMode.TQT, "scale": None,},
 }
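
``QuantMode.TQT`` disappears because TQT now reports plain symmetric qparams (see ``get_qparams`` in fake_quant.py above). A sketch of how the template is consumed, assuming ``get_qparam_dict`` returns a copy of the matching ``qparam_dict`` entry:

    from megengine.quantization.utils import QuantMode, get_qparam_dict

    q_dict = get_qparam_dict(QuantMode.SYMMERTIC)  # {"mode": SYMMERTIC, "scale": None}
    q_dict["scale"] = 2 ** 3  # filled in later by the fake-quant module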




+0 -116  imperative/python/test/unit/quantization/quantize.py

@@ -1,116 +0,0 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest

from megengine import module as Float
from megengine import tensor
from megengine.module import qat as QAT
from megengine.quantization import min_max_fakequant_qconfig
from megengine.quantization.quantize import (
    _get_quantable_module_names,
    disable_fake_quant,
    quantize_qat,
)


def test_get_quantable_module_names():
    # need to make sure names from Quantized and QAT are the same
    def _get_qat_module_names():
        def is_qat(key: str):
            value = getattr(QAT, key)
            return (
                isinstance(value, type)
                and issubclass(value, QAT.QATModule)
                and value != QAT.QATModule
            )

        # source should have all quantable modules' names
        quantable_module_names = [key for key in dir(QAT) if is_qat(key)]
        return quantable_module_names

    qat_module_names = _get_qat_module_names()
    quantized_module_names = _get_quantable_module_names()
    assert set(qat_module_names) == set(quantized_module_names)

    for key in qat_module_names:
        value = getattr(Float, key)
        assert (
            isinstance(value, type)
            and issubclass(value, Float.Module)
            and value != Float.Module
        )


def test_disable_quantize():
    class Net(Float.Module):
        def __init__(self):
            super().__init__()
            self.conv = Float.ConvBnRelu2d(3, 3, 3)
            self.conv.disable_quantize()

        def forward(self, x):
            return self.conv(x)

    net = Net()
    qat_net = quantize_qat(net, inplace=False)
    assert isinstance(qat_net.conv, Float.ConvBnRelu2d)
    assert isinstance(qat_net.conv.conv, Float.Conv2d)


def test_convert_with_custom_mapping():
    class FloatExample(Float.Module):
        def forward(self, x):
            return x

    class QATExample(QAT.QATModule):
        def forward(self, x):
            return x

        @classmethod
        def from_float_module(cls, float_module):
            return cls()

    class Net(Float.Module):
        def __init__(self):
            super().__init__()
            self.example = FloatExample()

        def forward(self, x):
            return self.example(x)

    net = Net()
    qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample})
    assert isinstance(qat_net.example, QATExample)


def test_disable_fake_quant():
    class Net(Float.Module):
        def __init__(self):
            super().__init__()
            self.quant = Float.QuantStub()
            self.linear = Float.Linear(3, 3)
            self.dequant = Float.DequantStub()
            self.linear.bias.set_value(np.random.rand(3))

        def forward(self, x):
            x = self.quant(x)
            x = self.linear(x)
            x = self.dequant(x)
            return x

    x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
    net = Net()
    y1 = net(x).numpy()
    net = quantize_qat(net, min_max_fakequant_qconfig)
    y2 = net(x).numpy()
    disable_fake_quant(net)
    y3 = net(x).numpy()
    np.testing.assert_allclose(y1, y3)
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(y2, y3)

+70 -4  imperative/python/test/unit/quantization/test_observer.py

@@ -6,17 +6,53 @@ import pytest

 import megengine as mge
 import megengine.distributed as dist
-import megengine.quantization.observer as ob
 from megengine.distributed.helper import get_device_count_by_fork
+from megengine.quantization.observer import (
+    ExponentialMovingAverageObserver,
+    MinMaxObserver,
+    Observer,
+    PassiveObserver,
+    SyncExponentialMovingAverageObserver,
+    SyncMinMaxObserver,
+)
+
+
+def test_observer():
+    with pytest.raises(TypeError):
+        Observer("qint8")


 def test_min_max_observer():
     x = np.random.rand(3, 3, 3, 3).astype("float32")
     np_min, np_max = x.min(), x.max()
     x = mge.tensor(x)
-    m = ob.MinMaxObserver()
+    m = MinMaxObserver()
     m(x)
-    assert m.min_val == np_min and m.max_val == np_max
+    np.testing.assert_allclose(m.min_val.numpy(), np_min)
+    np.testing.assert_allclose(m.max_val.numpy(), np_max)
+
+
+def test_exponential_moving_average_observer():
+    t = np.random.rand()
+    x1 = np.random.rand(3, 3, 3, 3).astype("float32")
+    x2 = np.random.rand(3, 3, 3, 3).astype("float32")
+    expected_min = x1.min() * t + x2.min() * (1 - t)
+    expected_max = x1.max() * t + x2.max() * (1 - t)
+    m = ExponentialMovingAverageObserver(momentum=t)
+    m(mge.tensor(x1, dtype=np.float32))
+    m(mge.tensor(x2, dtype=np.float32))
+    np.testing.assert_allclose(m.min_val.numpy(), expected_min)
+    np.testing.assert_allclose(m.max_val.numpy(), expected_max)
+
+
+def test_passive_observer():
+    q_dict = {"scale": mge.tensor(1.0)}
+    m = PassiveObserver(q_dict, "qint8")
+    assert m.orig_scale == 1.0
+    assert m.scale == 1.0
+    m.scale = 2.0
+    assert m.scale == 2.0
+    assert m.get_qparams() == {"scale": mge.tensor(2.0)}


 @pytest.mark.skipif(
@@ -35,9 +71,39 @@ def test_sync_min_max_observer():
     @dist.launcher
     def worker():
         rank = dist.get_rank()
-        m = ob.SyncMinMaxObserver()
+        m = SyncMinMaxObserver()
         y = mge.tensor(x[rank * 3 : (rank + 1) * 3])
         m(y)
         assert m.min_val == np_min and m.max_val == np_max

     worker()
+
+
+@pytest.mark.skipif(
+    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
+)
+@pytest.mark.skipif(
+    platform.system() == "Windows", reason="windows disable MGB_ENABLE_OPR_MM"
+)
+@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
+@pytest.mark.isolated_distributed
+def test_sync_exponential_moving_average_observer():
+    word_size = get_device_count_by_fork("gpu")
+    t = np.random.rand()
+    x1 = np.random.rand(3 * word_size, 3, 3, 3).astype("float32")
+    x2 = np.random.rand(3 * word_size, 3, 3, 3).astype("float32")
+    expected_min = x1.min() * t + x2.min() * (1 - t)
+    expected_max = x1.max() * t + x2.max() * (1 - t)
+
+    @dist.launcher
+    def worker():
+        rank = dist.get_rank()
+        m = SyncExponentialMovingAverageObserver(momentum=t)
+        y1 = mge.tensor(x1[rank * 3 : (rank + 1) * 3])
+        y2 = mge.tensor(x2[rank * 3 : (rank + 1) * 3])
+        m(y1)
+        m(y2)
+        np.testing.assert_allclose(m.min_val.numpy(), expected_min)
+        np.testing.assert_allclose(m.max_val.numpy(), expected_max)
+
+    worker()

+14 -0  imperative/python/test/unit/quantization/test_qconfig.py

@@ -0,0 +1,14 @@
from functools import partial

from megengine.quantization import QConfig, tqt_qconfig
from megengine.quantization.fake_quant import TQT


def test_equal():
    qconfig = QConfig(
        weight_observer=None,
        act_observer=None,
        weight_fake_quant=partial(TQT, dtype="qint8", narrow_range=True),
        act_fake_quant=partial(TQT, dtype="qint8", narrow_range=False),
    )
    assert qconfig == tqt_qconfig

+266 -0  imperative/python/test/unit/quantization/test_quantize.py

@@ -0,0 +1,266 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest

from megengine import functional
from megengine import module as Float
from megengine import tensor
from megengine.module import qat as QAT
from megengine.module import quantized as Q
from megengine.quantization import (
    min_max_fakequant_qconfig,
    passive_qconfig,
    tqt_qconfig,
)
from megengine.quantization.fake_quant import TQT, FakeQuantize
from megengine.quantization.observer import MinMaxObserver, PassiveObserver
from megengine.quantization.quantize import (
    _get_quantable_module_names,
    apply_easy_quant,
    disable_fake_quant,
    disable_observer,
    enable_fake_quant,
    enable_observer,
    propagate_qconfig,
    quantize,
    quantize_qat,
    reset_qconfig,
)


class Net(Float.Module):
    def __init__(self):
        super().__init__()
        self.quant = Float.QuantStub()
        self.linear = Float.Linear(3, 3)
        self.dequant = Float.DequantStub()
        self.linear.bias.set_value(np.random.rand(3))

    def forward(self, x):
        x = self.quant(x)
        x = self.linear(x)
        x = self.dequant(x)
        return x


class QATNet(Float.Module):
    def __init__(self):
        super().__init__()
        self.quant = QAT.QuantStub()
        self.linear = QAT.Linear(3, 3)
        self.dequant = QAT.DequantStub()
        self.linear.bias.set_value(np.random.rand(3))

    def forward(self, x):
        x = self.quant(x)
        x = self.linear(x)
        x = self.dequant(x)
        return x


def test_propagate_qconfig():
    net = QATNet()
    propagate_qconfig(net, min_max_fakequant_qconfig)
    assert all(
        [
            net.quant.weight_observer is None,
            net.quant.weight_fake_quant is None,
            isinstance(net.quant.act_observer, MinMaxObserver),
            isinstance(net.quant.act_fake_quant, FakeQuantize),
            isinstance(net.linear.weight_observer, MinMaxObserver),
            isinstance(net.linear.weight_fake_quant, FakeQuantize),
            isinstance(net.linear.act_observer, MinMaxObserver),
            isinstance(net.linear.act_fake_quant, FakeQuantize),
            net.dequant.weight_observer is None,
            net.dequant.weight_fake_quant is None,
            net.dequant.act_observer is None,
            net.dequant.act_fake_quant is None,
        ]
    )


def init_qat_net():
    net = QATNet()
    propagate_qconfig(net, min_max_fakequant_qconfig)
    min_val = np.random.randint(-127, 0, size=(2,))
    max_val = np.random.randint(1, 127, size=(2,))
    net.linear.weight_observer.min_val.set_value(min_val[0])
    net.linear.weight_observer.max_val.set_value(max_val[0])
    net.linear.act_observer.min_val.set_value(min_val[1])
    net.linear.act_observer.max_val.set_value(max_val[1])
    return net


def test_reset_qconfig():
    qat_net = init_qat_net()
    new_qat_net = reset_qconfig(qat_net, passive_qconfig)
    assert (
        new_qat_net.linear.get_weight_qparams() == qat_net.linear.get_weight_qparams()
    )
    assert (
        new_qat_net.linear.get_activation_qparams()
        == qat_net.linear.get_activation_qparams()
    )


def test_enable_and_disable_observer():
    net = init_qat_net()
    enable_observer(net)
    assert net.quant.act_observer.enabled == True
    assert net.linear.weight_observer.enabled == True
    assert net.linear.act_observer.enabled == True
    disable_observer(net)
    assert net.quant.act_observer.enabled == False
    assert net.linear.weight_observer.enabled == False
    assert net.linear.act_observer.enabled == False


def test_enable_and_disable_fake_quant():
    net = init_qat_net()
    disable_fake_quant(net)
    assert net.quant.act_fake_quant.enabled == False
    assert net.linear.weight_fake_quant.enabled == False
    assert net.linear.act_fake_quant.enabled == False
    enable_fake_quant(net)
    assert net.quant.act_fake_quant.enabled == True
    assert net.linear.weight_fake_quant.enabled == True
    assert net.linear.act_fake_quant.enabled == True


def init_observer(module, data):
    enable_observer(module)
    disable_fake_quant(module)
    module(data)
    disable_observer(module)
    enable_fake_quant(module)


def test_enable_and_disable_all():
    x = tensor(np.random.randint(1, 10, size=(3, 3)).astype(np.float32))
    net = Net()
    y1 = net(x).numpy()
    net = quantize_qat(net, min_max_fakequant_qconfig)

    init_observer(net, x)

    y2 = net(x).numpy()
    disable_fake_quant(net)
    y3 = net(x).numpy()
    enable_fake_quant(net)
    y4 = net(x).numpy()
    np.testing.assert_allclose(y1, y3)
    np.testing.assert_allclose(y2, y4)
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(y2, y3)


def test_quantize_qat():
    net = Net()
    qat_net = quantize_qat(net, inplace=False, qconfig=min_max_fakequant_qconfig)
    assert isinstance(qat_net.quant, QAT.QuantStub)
    assert isinstance(qat_net.linear, QAT.Linear)
    assert isinstance(qat_net.dequant, QAT.DequantStub)


def test_quantize():
    qat_net = init_qat_net()
    q_net = quantize(qat_net, inplace=False)
    assert isinstance(q_net.quant, Q.QuantStub)
    assert isinstance(q_net.linear, Q.Linear)
    assert isinstance(q_net.dequant, Q.DequantStub)


def test_apply_easy_quant():
    qat_net = init_qat_net()
    data = tensor(np.random.rand(2, 3, 3, 3), dtype=np.float32)
    eq_net = reset_qconfig(qat_net, passive_qconfig, inplace=False)
    apply_easy_quant(eq_net, data, 0.9, 1.1, 10)
    assert isinstance(eq_net.quant.act_observer, PassiveObserver)
    assert isinstance(eq_net.linear.weight_observer, PassiveObserver)
    assert isinstance(eq_net.linear.act_observer, PassiveObserver)
    assert eq_net.dequant.act_observer is None


def test_apply_tqt():
    qat_net = init_qat_net()
    tqt_net = reset_qconfig(qat_net, tqt_qconfig, inplace=False)
    assert isinstance(tqt_net.quant.act_fake_quant, TQT)
    assert isinstance(tqt_net.linear.weight_fake_quant, TQT)
    assert isinstance(tqt_net.linear.act_fake_quant, TQT)
    assert tqt_net.dequant.act_fake_quant is None


def test_get_quantable_module_names():
    # need to make sure names from Quantized and QAT are the same
    def _get_qat_module_names():
        def is_qat(key: str):
            value = getattr(QAT, key)
            return (
                isinstance(value, type)
                and issubclass(value, QAT.QATModule)
                and value != QAT.QATModule
            )

        # source should have all quantable modules' names
        quantable_module_names = [key for key in dir(QAT) if is_qat(key)]
        return quantable_module_names

    qat_module_names = _get_qat_module_names()
    quantized_module_names = _get_quantable_module_names()
    assert set(qat_module_names) == set(quantized_module_names)

    for key in qat_module_names:
        value = getattr(Float, key)
        assert (
            isinstance(value, type)
            and issubclass(value, Float.Module)
            and value != Float.Module
        )


def test_disable_quantize():
    class Net(Float.Module):
        def __init__(self):
            super().__init__()
            self.conv = Float.ConvBnRelu2d(3, 3, 3)
            self.conv.disable_quantize()

        def forward(self, x):
            return self.conv(x)

    net = Net()
    qat_net = quantize_qat(net, inplace=False)
    assert isinstance(qat_net.conv, Float.ConvBnRelu2d)
    assert isinstance(qat_net.conv.conv, Float.Conv2d)


def test_convert_with_custom_mapping():
    class FloatExample(Float.Module):
        def forward(self, x):
            return x

    class QATExample(QAT.QATModule):
        def forward(self, x):
            return x

        @classmethod
        def from_float_module(cls, float_module):
            return cls()

    class Net(Float.Module):
        def __init__(self):
            super().__init__()
            self.example = FloatExample()

        def forward(self, x):
            return self.example(x)

    net = Net()
    qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample})
    assert isinstance(qat_net.example, QATExample)
