
refactor(mge/quantization): split `QATModule` and refactor convert api

GitOrigin-RevId: 80cfb12d10
tags/v0.5.0
Megvii Engine Team · 5 years ago · commit caf1fac251
27 changed files with 735 additions and 480 deletions
  1. python_module/megengine/functional/nn.py (+4 / -4)
  2. python_module/megengine/module/__init__.py (+1 / -1)
  3. python_module/megengine/module/concat.py (+4 / -9)
  4. python_module/megengine/module/conv_bn_relu.py (+11 / -159)
  5. python_module/megengine/module/elemwise.py (+4 / -9)
  6. python_module/megengine/module/linear.py (+2 / -13)
  7. python_module/megengine/module/module.py (+0 / -96)
  8. python_module/megengine/module/qat/__init__.py (+13 / -0)
  9. python_module/megengine/module/qat/concat.py (+30 / -0)
  10. python_module/megengine/module/qat/conv_bn_relu.py (+193 / -0)
  11. python_module/megengine/module/qat/elemwise.py (+29 / -0)
  12. python_module/megengine/module/qat/linear.py (+37 / -0)
  13. python_module/megengine/module/qat/module.py (+96 / -0)
  14. python_module/megengine/module/qat/quant_dequant.py (+45 / -0)
  15. python_module/megengine/module/quant_dequant.py (+7 / -13)
  16. python_module/megengine/module/quantized/__init__.py (+1 / -0)
  17. python_module/megengine/module/quantized/concat.py (+11 / -16)
  18. python_module/megengine/module/quantized/conv_bn_relu.py (+32 / -36)
  19. python_module/megengine/module/quantized/elemwise.py (+11 / -18)
  20. python_module/megengine/module/quantized/linear.py (+17 / -24)
  21. python_module/megengine/module/quantized/module.py (+31 / -0)
  22. python_module/megengine/module/quantized/quant_dequant.py (+23 / -29)
  23. python_module/megengine/quantization/__init__.py (+0 / -9)
  24. python_module/megengine/quantization/qconfig.py (+3 / -7)
  25. python_module/megengine/quantization/quantize.py (+82 / -25)
  26. python_module/test/unit/module/test_conv_bn_relu.py (+10 / -12)
  27. python_module/test/unit/quantization/quantize.py (+38 / -0)
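
Taken together, the commit splits every quantable module into three parallel variants: float in ``megengine.module``, QAT in ``megengine.module.qat``, and inference-only quantized in ``megengine.module.quantized``. The old ``set_qat_mode``/``to_quantized`` machinery is replaced by ``from_float_module``/``from_qat_module`` classmethods, driven by the rewritten ``quantize_qat`` and ``quantize`` functions. A minimal sketch of the resulting workflow (the toy network is illustrative, not part of the commit):

import megengine.module as M
from megengine.quantization.quantize import quantize, quantize_qat

# A float network assembled from quantable float modules.
net = M.Sequential(
    M.QuantStub(),        # entry point into the quantized domain
    M.ConvBn2d(3, 8, 3),  # fused Conv2d + BatchNorm2d float module
    M.DequantStub(),      # exit point back to float
)

# 1) float -> QAT: swap each float module for its qat.* twin and
#    propagate the default ema_fakequant_qconfig to every QATModule.
qat_net = quantize_qat(net, inplace=False)

# ... QAT training / calibration runs on qat_net here ...

# 2) QAT -> quantized: swap each QATModule for its quantized.* twin,
#    folding the observed scales and dtypes into the parameters.
q_net = quantize(qat_net, inplace=False)
q_net.eval()  # QuantizedModule is inference only and raises in training mode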

python_module/megengine/functional/nn.py (+4 / -4)

@@ -27,10 +27,10 @@ from .utils import _decide_comp_node_and_comp_graph
def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
"""Applies a linear transformation to the input.

Refer to :class:`~.Linear` for more information.
Refer to :class:`~.module.linear.Linear` for more information.

:param inp: the input tensor with shape `(N, in_features)`.
:param weight: the weight with shape `(out_features, in_features)`.
:param weight: the weight with shape `(out_features, in_features)`.
:param bias: the bias with shape `(out_features,)`.
Default: ``None``
"""
@@ -300,9 +300,9 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
def softplus(inp: Tensor, beta: float = 1, threshold: float = 20) -> Tensor:
r"""
Performs the elementwise function:
.. math::
\mathsf{softplus}(x) = \log(1+\exp(\beta x)) / \beta.

For numerical stability the identity function is used when :math:`\beta x > \textrm{threshold}`.


python_module/megengine/module/__init__.py (+1 / -1)

@@ -16,7 +16,7 @@ from .elemwise import Elemwise
from .embedding import Embedding
from .identity import Identity
from .linear import Linear
from .module import Module, QATModule
from .module import Module
from .parampack import ParamPack
from .pooling import AvgPool2d, MaxPool2d
from .quant_dequant import DequantStub, QuantStub


python_module/megengine/module/concat.py (+4 / -9)

@@ -9,19 +9,14 @@ from typing import Iterable

from .. import functional as F
from ..core.tensor import Tensor
from .module import QATModule
from .module import Module


class Concat(QATModule):
class Concat(Module):
r"""
A :class:`~.QATModule` to do functional concat, should replace concat with this module,
supporting ``qat`` mode and ``quantized`` mode.
A :class:`~.Module` to do functional concat. Could be replaced with :class:`~.QATModule`
version :class:`~.qat.concat.Concat` using :func:`~.quantize.quantize_qat`.
"""

def forward(self, inps: Iterable[Tensor], axis: int = 0):
return F.concat(inps, axis)

def forward_qat(self, inps: Iterable[Tensor], axis: int = 0):
return self.apply_fakequant_with_observer(
self.forward(inps, axis), self.act_fake_quant, self.act_observer
)

python_module/megengine/module/conv_bn_relu.py (+11 / -159)

@@ -7,14 +7,13 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Tuple, Union

from ..core import ones, zeros
from ..functional import add_update, flatten, relu, sqrt, sum, zero_grad
from ..functional import relu
from .batchnorm import BatchNorm2d
from .conv import Conv2d
from .module import QATModule
from .module import Module


class _ConvBn2d(QATModule):
class _ConvBnActivation2d(Module):
def __init__(
self,
in_channels: int,
@@ -47,171 +46,24 @@ class _ConvBn2d(QATModule):
)
self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)

def get_batch_mean_var(self, inp):
def _sum_channel(inp, axis=0, keepdims=True):
if isinstance(axis, int):
out = sum(inp, axis=axis, keepdims=keepdims)
elif isinstance(axis, tuple):
for idx, elem in enumerate(axis):
out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
return out

sum1 = _sum_channel(inp, (0, 2, 3))
sum2 = _sum_channel(inp ** 2, (0, 2, 3))
reduce_size = inp.shapeof().prod() / inp.shapeof(1)
batch_mean = sum1 / reduce_size
batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
return batch_mean, batch_var

def fold_weight_bias(self, bn_mean, bn_var):
# get fold bn conv param
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
# b_fold = gamma * (b - bn_mean) / bn_std + beta
gamma = self.bn.weight
if gamma is None:
gamma = ones((self.bn.num_features), dtype="float32")
gamma = gamma.reshape(1, -1, 1, 1)
beta = self.bn.bias
if beta is None:
beta = zeros((self.bn.num_features), dtype="float32")
beta = beta.reshape(1, -1, 1, 1)

if bn_mean is None:
bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
if bn_var is None:
bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")

conv_bias = self.conv.bias
if conv_bias is None:
conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")

bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
scale_factor = gamma * bn_istd
if self.conv.groups == 1:
w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
else:
w_fold = self.conv.weight * scale_factor.reshape(
self.conv.groups, -1, 1, 1, 1
)

# b_fold = gamma * (b - bn_mean) / bn_std + beta
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
return w_fold, b_fold

def update_running_mean_and_running_var(
self, bn_mean, bn_var, num_elements_per_channel
):
# update running mean and running var. no grad, use unbiased bn var
bn_mean = zero_grad(bn_mean)
bn_var = (
zero_grad(bn_var)
* num_elements_per_channel
/ (num_elements_per_channel - 1)
)
exponential_average_factor = 1 - self.bn.momentum
add_update(
self.bn.running_mean,
delta=bn_mean,
alpha=1 - exponential_average_factor,
beta=exponential_average_factor,
)
add_update(
self.bn.running_var,
delta=bn_var,
alpha=1 - exponential_average_factor,
beta=exponential_average_factor,
)

def calc_conv_bn_qat(self, inp, approx=True):
if self.training and not approx:
conv = self.conv(inp)
bn_mean, bn_var = self.get_batch_mean_var(conv)
num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
self.update_running_mean_and_running_var(
bn_mean, bn_var, num_elements_per_channel
)
else:
bn_mean, bn_var = self.bn.running_mean, self.bn.running_var

# get gamma and beta in BatchNorm
gamma = self.bn.weight
if gamma is None:
gamma = ones((self.bn.num_features), dtype="float32")
gamma = gamma.reshape(1, -1, 1, 1)
beta = self.bn.bias
if beta is None:
beta = zeros((self.bn.num_features), dtype="float32")
beta = beta.reshape(1, -1, 1, 1)
# conv_bias
conv_bias = self.conv.bias
if conv_bias is None:
conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")

bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
scale_factor = gamma * bn_istd
if self.conv.groups == 1:
w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
else:
w_fold = self.conv.weight * scale_factor.reshape(
self.conv.groups, -1, 1, 1, 1
)
b_fold = None
if not (self.training and approx):
# b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd

w_qat = self.apply_fakequant_with_observer(
w_fold, self.weight_fake_quant, self.weight_observer
)
conv = self.conv.calc_conv(inp, w_qat, b_fold)
if not (self.training and approx):
return conv

# rescale conv to get original conv output
orig_conv = conv / scale_factor.reshape(1, -1, 1, 1)
if self.conv.bias is not None:
orig_conv = orig_conv + self.conv.bias
# calculate batch norm
bn_mean, bn_var = self.get_batch_mean_var(orig_conv)
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
conv = gamma * bn_istd * (orig_conv - bn_mean) + beta
num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
self.update_running_mean_and_running_var(
bn_mean, bn_var, num_elements_per_channel
)
return conv


class ConvBn2d(_ConvBn2d):
class ConvBn2d(_ConvBnActivation2d):
r"""
A fused :class:`~.QATModule` including Conv2d and BatchNorm2d, supporting ``qat`` mode
and ``normal`` mode.
A fused :class:`~.Module` including Conv2d, BatchNorm2d. Could be replaced
with :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBn2d` using
:func:`~.quantize.quantize_qat`.
"""

def forward_qat(self, inp):
return self.apply_fakequant_with_observer(
self.calc_conv_bn_qat(inp), self.act_fake_quant, self.act_observer
)

def forward(self, inp):
return self.bn(self.conv(inp))


class ConvBnRelu2d(_ConvBn2d):
class ConvBnRelu2d(_ConvBnActivation2d):
r"""
A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu, supporting ``qat``
mode and ``normal`` mode.
A fused :class:`~.Module` including Conv2d, BatchNorm2d and relu. Could be replaced
with :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBnRelu2d` using
:func:`~.quantize.quantize_qat`.
"""

def forward_qat(self, inp):
return self.apply_fakequant_with_observer(
relu(self.calc_conv_bn_qat(inp)), self.act_fake_quant, self.act_observer
)

def forward(self, inp):
return relu(self.bn(self.conv(inp)))

python_module/megengine/module/elemwise.py (+4 / -9)

@@ -8,7 +8,7 @@
from .. import _internal as mgb
from ..core import Tensor, wrap_io_tensor
from ..core.graph import _use_default_if_none
from .module import QATModule
from .module import Module


@wrap_io_tensor
@@ -22,10 +22,10 @@ def _elemwise_func(mode, *inputs, **kwargs) -> Tensor:
return mgb.opr.elemwise(*inputs, mode=mode, **kwargs)


class Elemwise(QATModule):
class Elemwise(Module):
r"""
A :class:`~.QATModule` to do elemwise operator, should functional operator with this module,
supporting ``qat`` mode and ``normal`` mode.
A :class:`~.Module` to do elemwise operator. Could be replaced with :class:`~.QATModule`
version :class:`~.qat.elemwise.Elemwise` using :func:`~.quantize.quantize_qat`.

:param method: the elemwise method, support the following string.
It will do the normal elemwise operator for float.
@@ -88,8 +88,3 @@ class Elemwise(QATModule):

def forward(self, *inps):
return _elemwise_func(self.method, *inps)

def forward_qat(self, *inps):
return self.apply_fakequant_with_observer(
self.forward(*inps), self.act_fake_quant, self.act_observer,
)

python_module/megengine/module/linear.py (+2 / -13)

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -11,10 +10,10 @@ import numpy as np
from .. import functional as F
from ..core import Parameter
from . import init
from .module import QATModule
from .module import Module


class Linear(QATModule):
class Linear(Module):
r"""Applies a linear transformation to the input. For instance, if input
is x, then output y is:

@@ -60,13 +59,3 @@ class Linear(QATModule):

def forward(self, x):
return self._calc_linear(x, self.weight, self.bias)

def forward_qat(self, x):
w_qat = self.apply_fakequant_with_observer(
self.weight, self.weight_fake_quant, self.weight_observer
)
return self.apply_fakequant_with_observer(
self._calc_linear(x, w_qat, self.bias),
self.act_fake_quant,
self.act_observer,
)

python_module/megengine/module/module.py (+0 / -96)

@@ -7,7 +7,6 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from enum import Enum
from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union

import numpy as np
@@ -443,98 +442,3 @@ class Module(metaclass=ABCMeta):
loaded.append(k)

return set(loaded), set(skipped)


class QATModule(Module):
r"""
Base class of quantization related Module. Add extra forward methods
:meth:`~.QATModule.forward_qat` and :meth:`~.QATModule.forward_quantized` for
``qat``(quantization aware training) mode and ``quantized`` mode respectively.

Use :meth:`~.QATModule.quant` to switch between ``QAT`` and ``NORMAL`` mode,
and use :meth:`~.QATModule.to_quantized` to switch to ``quantized`` mode,
which is irreversible.

If you want to recursively switch mode for all QATModule in network, use
functions in :mod:`~.quantization.quantize`.
"""

class QATMode(Enum):
DISABLED = 1
QAT = 2
CALIBRATION = 3

def __init__(self):
from ..quantization import (
QConfig,
FakeQuantize,
Observer,
) # pylint: disable=all

super().__init__()

self.quantizing = self.QATMode.DISABLED
self.scale = None

self.weight_observer = None # type: Observer
self.act_observer = None # type: Observer

self.weight_fake_quant = None # type: FakeQuantize
self.act_fake_quant = None # type: FakeQuantize

def set_qconfig(self, qconfig: "QConfig"):
self.weight_observer = qconfig.weight_observer()
self.act_observer = qconfig.act_observer()

self.weight_fake_quant = (
None
if qconfig.fake_quant is None
else qconfig.fake_quant(self.weight_observer.dtype)
)
self.act_fake_quant = (
None
if qconfig.fake_quant is None
else qconfig.fake_quant(self.act_observer.dtype)
)

def apply_observer(self, target: Tensor, obs: "Observer"):
return obs(target)

def apply_fakequant_with_observer(
self, target: Tensor, fq: "FakeQuantize", obs: "Observer"
):
oup = self.apply_observer(target, obs)
if fq is not None:
q_dict = obs.get_qparams()
oup = fq(oup, q_dict)
return oup

def set_qat_mode(self, mode: QATMode):
r"""
Change ``self.quantizing`` mode, available values: ``self.QATMode.DISABLED``,
``QAT``,``CALIBRATION``.
"""
if not isinstance(mode, self.QATMode):
raise TypeError("mode must be QATMode Enum type")
self.quantizing = mode

def to_quantized(self):
r"""
Return a new :class:`~.Module` with quantized parameters of ``self``
according to scale and zero_point in ``self.xxx_observer``.
"""
raise NotImplementedError(
"Use megengine.quantization.quantize to register the method."
)

@abstractmethod
def forward_qat(self, *args, **kwargs):
r"""
Forward method for ``qat`` mode.
"""

def __call__(self, *args, **kwargs):
if self.quantizing == self.QATMode.DISABLED:
return self.forward(*args, **kwargs)
else:
return self.forward_qat(*args, **kwargs)
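
(The Enum-based ``QATMode`` switching removed here is not re-created elsewhere. After this commit the mode is encoded in the module's type, float, ``qat.*`` or ``quantized.*``, and switching modes means converting between those types, as the new ``module/qat/module.py`` below shows.)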

python_module/megengine/module/qat/__init__.py (+13 / -0)

@@ -0,0 +1,13 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .concat import Concat
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .elemwise import Elemwise
from .linear import Linear
from .module import QATModule
from .quant_dequant import DequantStub, QuantStub

python_module/megengine/module/qat/concat.py (+30 / -0)

@@ -0,0 +1,30 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable

from ...core.tensor import Tensor
from .. import concat as Float
from .module import QATModule


class Concat(Float.Concat, QATModule):
r"""
A :class:`~.QATModule` to do functional concat with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
"""

def forward(self, inps: Iterable[Tensor], axis: int = 0):
return self.apply_quant_activation(super().forward(inps, axis))

@classmethod
def from_float_module(cls, float_module):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
return cls()

python_module/megengine/module/qat/conv_bn_relu.py (+193 / -0)

@@ -0,0 +1,193 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...core import ones, zeros
from ...functional import add_update, relu, sqrt, sum, zero_grad
from .. import conv_bn_relu as Float
from .module import QATModule


class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
def get_batch_mean_var(self, inp):
def _sum_channel(inp, axis=0, keepdims=True):
if isinstance(axis, int):
out = sum(inp, axis=axis, keepdims=keepdims)
elif isinstance(axis, tuple):
for idx, elem in enumerate(axis):
out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
return out

sum1 = _sum_channel(inp, (0, 2, 3))
sum2 = _sum_channel(inp ** 2, (0, 2, 3))
reduce_size = inp.shapeof().prod() / inp.shapeof(1)
batch_mean = sum1 / reduce_size
batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
return batch_mean, batch_var

def fold_weight_bias(self, bn_mean, bn_var):
# get fold bn conv param
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
# b_fold = gamma * (b - bn_mean) / bn_std + beta
gamma = self.bn.weight
if gamma is None:
gamma = ones((self.bn.num_features), dtype="float32")
gamma = gamma.reshape(1, -1, 1, 1)
beta = self.bn.bias
if beta is None:
beta = zeros((self.bn.num_features), dtype="float32")
beta = beta.reshape(1, -1, 1, 1)

if bn_mean is None:
bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
if bn_var is None:
bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")

conv_bias = self.conv.bias
if conv_bias is None:
conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")

bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
scale_factor = gamma * bn_istd
if self.conv.groups == 1:
w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
else:
w_fold = self.conv.weight * scale_factor.reshape(
self.conv.groups, -1, 1, 1, 1
)

# b_fold = gamma * (b - bn_mean) / bn_std + beta
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
return w_fold, b_fold

def update_running_mean_and_running_var(
self, bn_mean, bn_var, num_elements_per_channel
):
# update running mean and running var. no grad, use unbiased bn var
bn_mean = zero_grad(bn_mean)
bn_var = (
zero_grad(bn_var)
* num_elements_per_channel
/ (num_elements_per_channel - 1)
)
exponential_average_factor = 1 - self.bn.momentum
add_update(
self.bn.running_mean,
delta=bn_mean,
alpha=1 - exponential_average_factor,
beta=exponential_average_factor,
)
add_update(
self.bn.running_var,
delta=bn_var,
alpha=1 - exponential_average_factor,
beta=exponential_average_factor,
)

def calc_conv_bn_qat(self, inp, approx=True):
if self.training and not approx:
conv = self.conv(inp)
bn_mean, bn_var = self.get_batch_mean_var(conv)
num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
self.update_running_mean_and_running_var(
bn_mean, bn_var, num_elements_per_channel
)
else:
bn_mean, bn_var = self.bn.running_mean, self.bn.running_var

# get gamma and beta in BatchNorm
gamma = self.bn.weight
if gamma is None:
gamma = ones((self.bn.num_features), dtype="float32")
gamma = gamma.reshape(1, -1, 1, 1)
beta = self.bn.bias
if beta is None:
beta = zeros((self.bn.num_features), dtype="float32")
beta = beta.reshape(1, -1, 1, 1)
# conv_bias
conv_bias = self.conv.bias
if conv_bias is None:
conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")

bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
# bn_istd = 1 / bn_std
# w_fold = gamma / bn_std * W
scale_factor = gamma * bn_istd
if self.conv.groups == 1:
w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
else:
w_fold = self.conv.weight * scale_factor.reshape(
self.conv.groups, -1, 1, 1, 1
)
b_fold = None
if not (self.training and approx):
# b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta
b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd

w_qat = self.apply_quant_weight(w_fold)
conv = self.conv.calc_conv(inp, w_qat, b_fold)
if not (self.training and approx):
return conv

# rescale conv to get original conv output
orig_conv = conv / scale_factor.reshape(1, -1, 1, 1)
if self.conv.bias is not None:
orig_conv = orig_conv + self.conv.bias
# calculate batch norm
bn_mean, bn_var = self.get_batch_mean_var(orig_conv)
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
conv = gamma * bn_istd * (orig_conv - bn_mean) + beta
num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
self.update_running_mean_and_running_var(
bn_mean, bn_var, num_elements_per_channel
)
return conv

@classmethod
def from_float_module(cls, float_module: Float._ConvBnActivation2d):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
qat_module = cls(
float_module.conv.in_channels,
float_module.conv.out_channels,
float_module.conv.kernel_size,
float_module.conv.stride,
float_module.conv.padding,
float_module.conv.dilation,
float_module.conv.groups,
bool(float_module.conv.bias),
float_module.conv.conv_mode.name,
float_module.conv.compute_mode.name,
)
qat_module.conv.weight = float_module.conv.weight
qat_module.conv.bias = float_module.conv.bias
qat_module.bn = float_module.bn
return qat_module


class ConvBn2d(_ConvBnActivation2d):
r"""
A fused :class:`~.QATModule` including Conv2d, BatchNorm2d with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
"""

def forward(self, inp):
return self.apply_quant_activation(self.calc_conv_bn_qat(inp))


class ConvBnRelu2d(_ConvBnActivation2d):
r"""
A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
"""

def forward(self, inp):
return self.apply_quant_activation(relu(self.calc_conv_bn_qat(inp)))
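
For reference, ``fold_weight_bias`` above is the standard BatchNorm-folding identity: per-channel BatchNorm applied to a convolution output is itself an affine map, so it can be absorbed into the convolution parameters. In the notation of the code comments, per output channel c:

\sigma_c^{-1} = \frac{1}{\sqrt{\operatorname{var}_c + \varepsilon}}, \qquad
W_c^{\mathrm{fold}} = \gamma_c \, \sigma_c^{-1} \, W_c, \qquad
b_c^{\mathrm{fold}} = \beta_c + \gamma_c \, \sigma_c^{-1} \, (b_c - \mu_c)

``calc_conv_bn_qat`` with ``approx=True`` uses the identity in reverse during training: the convolution runs with the fake-quantized folded weight and no bias, dividing by the scale factor \gamma_c \, \sigma_c^{-1} recovers the plain convolution output (up to fake-quant error), and BatchNorm is then applied explicitly with fresh batch statistics. Gradients therefore flow through the fake-quantized folded weights while BatchNorm still sees true batch statistics.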

python_module/megengine/module/qat/elemwise.py (+29 / -0)

@@ -0,0 +1,29 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .. import elemwise as Float
from .module import QATModule


class Elemwise(Float.Elemwise, QATModule):
r"""
A :class:`~.QATModule` to do elemwise operator with QAT support.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.

:param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail.
"""

def forward(self, *inps):
return self.apply_quant_activation(super().forward(*inps))

@classmethod
def from_float_module(cls, float_module: Float.Elemwise):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
return cls(float_module.method.name)

python_module/megengine/module/qat/linear.py (+37 / -0)

@@ -0,0 +1,37 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .. import linear as Float
from .module import QATModule


class Linear(Float.Linear, QATModule):
r"""
A :class:`~.QATModule` version of :class:`~.module.linear.Linear`.
Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.

:param in_features: size of each input sample.
:param out_features: size of each output sample.
:param bias: If set to ``False``, the layer will not learn an additive bias.
Default: ``True``

"""

def forward(self, x):
w_qat = self.apply_quant_weight(self.weight)
return self.apply_quant_activation(self._calc_linear(x, w_qat, self.bias),)

@classmethod
def from_float_module(cls, float_module: Float.Linear):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
qmod = cls(float_module.in_features, float_module.out_features)
qmod.weight = float_module.weight
qmod.bias = float_module.bias
return qmod

python_module/megengine/module/qat/module.py (+96 / -0)

@@ -0,0 +1,96 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod

from ...core import Tensor
from ...quantization import FakeQuantize, Observer, QConfig
from ..module import Module


class QATModule(Module):
r"""
Base class of quantization-aware Modules, used mainly for QAT and calibration.

Use :meth:`~.QATModule.from_float_module` to generate an instance from a float
:class:`~.Module`, or use :func:`~.quantize.quantize_qat` to do it recursively
and automatically.

Can be further converted to :class:`~.QuantizedModule` for deployment using
:func:`~.quantize.quantize`.
"""

def __init__(self):
super().__init__()

self.scale = None

self.weight_observer = None # type: Observer
self.act_observer = None # type: Observer

self.weight_fake_quant = None # type: FakeQuantize
self.act_fake_quant = None # type: FakeQuantize

def set_qconfig(self, qconfig: QConfig):
r"""
Set quantization related configs with ``qconfig``, including
observer and fake_quant for weight and activation.
"""
self.weight_observer = qconfig.weight_observer()
self.act_observer = qconfig.act_observer()

if qconfig.fake_quant is None:
self.weight_fake_quant = None
self.act_fake_quant = None
else:
self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype)
self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype)

def _apply_fakequant_with_observer(
self, target: Tensor, fake_quant: FakeQuantize, observer: Observer
):
oup = observer(target)
if fake_quant is None:
return oup
else:
q_dict = observer.get_qparams()
return fake_quant(oup, q_dict)

def apply_quant_weight(self, target: Tensor):
r"""
Apply weight's observer and fake_quant from ``qconfig`` on ``target``.
"""
return self._apply_fakequant_with_observer(
target, self.weight_fake_quant, self.weight_observer
)

def apply_quant_activation(self, target: Tensor):
r"""
Apply activation's observer and fake_quant from ``qconfig`` on ``target``.
"""
return self._apply_fakequant_with_observer(
target, self.act_fake_quant, self.act_observer
)

def get_weight_dtype(self):
r"""
Get weight's quantization dtype as the method from ``qconfig``.
"""
return self.weight_observer.get_dtype()

def get_activation_dtype(self):
r"""
Get activation's quantization dtype as the method from ``qconfig``.
"""
return self.act_observer.get_dtype()

@classmethod
@abstractmethod
def from_float_module(cls, float_module: Module):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""

python_module/megengine/module/qat/quant_dequant.py (+45 / -0)

@@ -0,0 +1,45 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .. import quant_dequant as Float
from .module import QATModule


class QuantStub(Float.QuantStub, QATModule):
r"""
A helper :class:`~.QATModule` that simply returns its input, but will
quantize the input once converted to :class:`~.QuantizedModule`.
"""

def forward(self, inp):
return self.apply_quant_activation(inp)

@classmethod
def from_float_module(cls, float_module: Float.QuantStub):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
return cls()


class DequantStub(Float.DequantStub, QATModule):
r"""
A helper :class:`~.QATModule` that simply returns its input, but will
de-quantize the input once converted to :class:`~.QuantizedModule`.
"""

def forward(self, inp):
return inp

@classmethod
def from_float_module(cls, float_module: Float.DequantStub):
r"""
Return a :class:`~.QATModule` instance converted from
a float :class:`~.Module` instance.
"""
return cls()

python_module/megengine/module/quant_dequant.py (+7 / -13)

@@ -5,30 +5,24 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .module import QATModule
from .module import Module


class QuantStub(QATModule):
class QuantStub(Module):
r"""
A helper QATModule doing quantize operation on input.
A helper :class:`~.Module` simply returning input. Could be replaced with :class:`~.QATModule`
version :class:`~.qat.QuantStub` using :func:`~.quantize.quantize_qat`.
"""

def forward(self, inp):
return inp

def forward_qat(self, inp):
return self.apply_fakequant_with_observer(
inp, self.act_fake_quant, self.act_observer
)


class DequantStub(QATModule):
class DequantStub(Module):
r"""
A helper QATModule doing de-quantize operation on input.
A helper :class:`~.Module` simply returning input. Could be replaced with :class:`~.QATModule`
version :class:`~.qat.DequantStub` using :func:`~.quantize.quantize_qat`.
"""

def forward(self, inp):
return inp

def forward_qat(self, inp):
return inp

python_module/megengine/module/quantized/__init__.py (+1 / -0)

@@ -9,4 +9,5 @@ from .concat import Concat
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .elemwise import Elemwise
from .linear import Linear
from .module import QuantizedModule
from .quant_dequant import DequantStub, QuantStub

python_module/megengine/module/quantized/concat.py (+11 / -16)

@@ -7,17 +7,15 @@
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable

from ... import _internal as mgb
from ... import functional as F
from ... import module as Float
from ...core.tensor import Tensor
from ...quantization.utils import register_method_to_class
from ..module import Module
from ..qat import concat as QAT
from .module import QuantizedModule


class Concat(Module):
class Concat(QuantizedModule):
r"""
A :class:`~.Module` to do quantized concat, inference only.
A :class:`~.QuantizedModule` to do quantized concat, inference only.
"""

def __init__(self, dtype=None):
@@ -25,16 +23,13 @@ class Concat(Module):
self.output_dtype = dtype

def forward(self, inps: Iterable[Tensor], axis: int = 0):
if self.training:
raise ValueError("quantized module only support inference.")
new_inps = (x.astype(self.output_dtype) for x in inps)
return F.concat(new_inps, axis)


@register_method_to_class(Float.Concat)
def to_quantized(float_module):
r"""
Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
implemented here to avoid circular import.
"""
return Concat(float_module.act_observer.get_dtype())
@classmethod
def from_qat_module(cls, qat_module: QAT.Concat):
r"""
return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
return cls(qat_module.get_activation_dtype())

python_module/megengine/module/quantized/conv_bn_relu.py (+32 / -36)

@@ -5,7 +5,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from functools import partial
from typing import Tuple, Union

import megengine._internal as mgb
@@ -13,11 +12,11 @@ import megengine._internal as mgb
from ... import module as Float
from ...core import Parameter
from ...functional import conv_bias_activation
from ...module import Conv2d
from ...quantization.utils import register_method_to_class
from ..qat import conv_bn_relu as QAT
from .module import QuantizedModule


class _ConvBnActivation2d(Conv2d):
class _ConvBnActivation2d(Float.Conv2d, QuantizedModule):
r"""Applies a 2D convolution over an quantized input tensor, inference only.

The parameter is same with :class: `~.Conv2d`
@@ -68,44 +67,41 @@ class _ConvBnActivation2d(Conv2d):
nonlinear_mode=nonlinear_mode,
)

@classmethod
def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d):
r"""
return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
output_dtype = qat_module.get_activation_dtype()
qconv = cls(
qat_module.conv.in_channels,
qat_module.conv.out_channels,
qat_module.conv.kernel_size,
qat_module.conv.stride,
qat_module.conv.padding,
qat_module.conv.dilation,
qat_module.conv.groups,
dtype=output_dtype,
)
w_fold, b_fold = qat_module.fold_weight_bias(
qat_module.bn.running_mean, qat_module.bn.running_var
)
weight = w_fold.astype(qat_module.get_weight_dtype())
qconv.weight = Parameter(weight.numpy())
qconv.bias = Parameter(b_fold.numpy())
return qconv


class ConvBn2d(_ConvBnActivation2d):
r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBn2d`."""

def forward(self, inp):
if self.training:
raise ValueError("quantized module only support inference.")
return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")


class ConvBnRelu2d(_ConvBnActivation2d):
r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBnRelu2d`."""

def forward(self, inp):
if self.training:
raise ValueError("quantized module only support inference.")
return self.calc_conv_quantized(inp, nonlinear_mode="RELU")


def to_quantized(quantized_class, float_module):
output_dtype = float_module.act_observer.get_dtype()
qconv = quantized_class(
float_module.conv.in_channels,
float_module.conv.out_channels,
float_module.conv.kernel_size,
float_module.conv.stride,
float_module.conv.padding,
float_module.conv.dilation,
float_module.conv.groups,
dtype=output_dtype,
)
w_fold, b_fold = float_module.fold_weight_bias(
float_module.bn.running_mean, float_module.bn.running_var
)
weight = w_fold.astype(float_module.weight_observer.get_dtype())
qconv.weight = Parameter(weight.numpy())
qconv.bias = Parameter(b_fold.numpy())

return qconv


# replace :class:`~.module.QATModule`'s ``to_quantized`` method.
# implemented here to avoid circular import.
register_method_to_class(Float.ConvBn2d)(partial(to_quantized, ConvBn2d))
register_method_to_class(Float.ConvBnRelu2d)(partial(to_quantized, ConvBnRelu2d))

python_module/megengine/module/quantized/elemwise.py (+11 / -18)

@@ -6,11 +6,10 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import _internal as mgb
from ... import module as Float
from ...core import Tensor, wrap_io_tensor
from ...core.graph import _use_default_if_none
from ...quantization.utils import register_method_to_class
from ..module import Module
from ..qat import elemwise as QAT
from .module import QuantizedModule


@wrap_io_tensor
@@ -24,13 +23,8 @@ def _elemwise_multi_type(mode, *inputs, **kwargs) -> Tensor:
return mgb.opr.elemwise_multi_type(*inputs, mode=mode, **kwargs)


class Elemwise(Module):
r"""
quantized module for elemwise operator, inference only.

:param method: the elemwise method, supported string refer to :class:`~.module.elemwise.Elemwise`.
it will do quantized operator with specified output quantized dtype.
"""
class Elemwise(QuantizedModule):
r"""quantized version of :class:`~.qat.elemwise.Elemwise`."""

_elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode

@@ -44,11 +38,10 @@ class Elemwise(Module):
raise ValueError("quantized module only support inference.")
return _elemwise_multi_type(self.method, *inps, dtype=self.output_dtype)


@register_method_to_class(Float.Elemwise)
def to_quantized(float_module):
r"""
Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
implemented here to avoid circular import.
"""
return Elemwise(float_module.method.name, float_module.act_observer.get_dtype())
@classmethod
def from_qat_module(cls, qat_module: QAT.Elemwise):
r"""
return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
return cls(qat_module.method.name, qat_module.get_activation_dtype())

python_module/megengine/module/quantized/linear.py (+17 / -24)

@@ -10,19 +10,13 @@ import numpy as np
import megengine._internal as mgb

from ... import functional as F
from ... import module as Float
from ...core import Parameter
from ...quantization.utils import register_method_to_class
from ..module import Module
from ..qat import linear as QAT
from .module import QuantizedModule


class Linear(Module):
r"""Applies a quantized linear transformation to the input. The module
usually convert from QAT module by to_quantized method.

:param dtype: output data type.

"""
class Linear(QuantizedModule):
r"""quantized version of :class:`~.qat.linear.Linear`."""

def __init__(
self, dtype: np.dtype = None,
@@ -44,17 +38,16 @@ class Linear(Module):
None if self.bias is None else self.bias.astype(bias_dtype),
).astype(self.output_dtype)


@register_method_to_class(Float.Linear)
def to_quantized(float_module):
r"""
Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
implemented here to avoid circular import.
"""
output_dtype = float_module.act_observer.get_dtype()
qmod = Linear(dtype=output_dtype,)
weight = float_module.weight.astype(float_module.weight_observer.get_dtype())
qmod.weight = Parameter(weight.numpy())
if float_module.bias is not None:
qmod.bias = Parameter(float_module.bias.numpy())
return qmod
@classmethod
def from_qat_module(cls, qat_module: QAT.Linear):
r"""
return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
output_dtype = qat_module.get_activation_dtype()
qmod = cls(dtype=output_dtype)
weight = qat_module.weight.astype(qat_module.get_weight_dtype())
qmod.weight = Parameter(weight.numpy())
if qat_module.bias is not None:
qmod.bias = Parameter(qat_module.bias.numpy())
return qmod

python_module/megengine/module/quantized/module.py (+31 / -0)

@@ -0,0 +1,31 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod

from ..module import Module
from ..qat import QATModule


class QuantizedModule(Module):
r"""
Base class of quantized Module, which should be converted from QATModule
and does not support training.
"""

def __call__(self, *inputs, **kwargs):
if self.training:
raise ValueError("quantized module only support inference.")
return super().__call__(*inputs, **kwargs)

@classmethod
@abstractmethod
def from_qat_module(cls, qat_module: QATModule):
r"""
return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""

python_module/megengine/module/quantized/quant_dequant.py (+23 / -29)

@@ -5,15 +5,14 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import _internal as mgb
from ... import module as Float
from ...quantization.utils import register_method_to_class
from ..module import Module
from ..qat import quant_dequant as QAT
from .module import QuantizedModule


class QuantStub(Module):
class QuantStub(QuantizedModule):
r"""
A helper quantize operation on input and inference only.
quantized version of :class:`~.qat.quant_dequant.QuantStub`,
will convert input to quantized dtype.
"""

def __init__(self, dtype=None):
@@ -21,35 +20,30 @@ class QuantStub(Module):
self.output_dtype = dtype

def forward(self, inp):
if self.training:
raise ValueError("quantized module only support inference.")
return inp.astype(self.output_dtype)

@classmethod
def from_qat_module(cls, qat_module: QAT.QuantStub):
r"""
return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
return cls(qat_module.get_activation_dtype())

class DequantStub(Module):

class DequantStub(QuantizedModule):
r"""
A helper de-quantize operation and inference only.
quantized version of :class:`~.qat.quant_dequant.DequantStub`,
will restore quantized input to float32 dtype.
"""

def forward(self, inp):
if self.training:
raise ValueError("quantized module only support inference.")
return inp.astype("float32")


@register_method_to_class(Float.QuantStub)
def to_quantized(float_module):
r"""
Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
implemented here to avoid circular import.
"""
return QuantStub(float_module.act_observer.get_dtype())


@register_method_to_class(Float.DequantStub)
def to_quantized(float_module):
r"""
Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
implemented here to avoid circular import.
"""
return DequantStub()
@classmethod
def from_qat_module(cls, qat_module: QAT.DequantStub):
r"""
return a :class:`~.QuantizedModule` instance converted from a
:class:`~.QATModule` instance.
"""
return cls()

python_module/megengine/quantization/__init__.py (+0 / -9)

@@ -13,12 +13,3 @@ from .qconfig import (
ema_fakequant_qconfig,
min_max_fakequant_qconfig,
)
from .quantize import (
disable_fake_quant,
disable_observer,
enable_fake_quant,
enable_observer,
quantize,
quantize_calibration,
quantize_qat,
)

python_module/megengine/quantization/qconfig.py (+3 / -7)

@@ -15,16 +15,12 @@ from .observer import (


class QConfig:
"""
r"""
A config class indicating how to do quantize toward :class:`~.QATModule`'s
``activation`` and ``weight``.

And ``fake_quant`` parameter to indicate

See :meth:`~.QATModule.set_qconfig` for detail usage.
``activation`` and ``weight``. See :meth:`~.QATModule.set_qconfig` for detail usage.

:param weight_observer: interface to instantiate an :class:`~.Observer` indicating
- how to collect scales and zero_point of wegiht.
how to collect scales and zero_point of weight.
:param act_observer: similar to ``weight_observer`` but toward activation.
:param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
how to do fake_quant calculation. can be invoked multi times to get different
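
For context, a ``QConfig`` only bundles factories: :meth:`~.QATModule.set_qconfig` calls ``weight_observer()`` and ``act_observer()`` to create per-module observers, and ``fake_quant(dtype)`` for the fake-quant operators. A sketch of assembling a custom config is below; the observer class names are inferred from the ``min_max_fakequant_qconfig`` / ``ema_fakequant_qconfig`` presets, and the ``FakeQuantize`` import path is an assumption:

from megengine.quantization.fake_quant import FakeQuantize  # path assumed
from megengine.quantization.observer import (  # names assumed from the presets
    ExponentialMovingAverageObserver,
    MinMaxObserver,
)
from megengine.quantization.qconfig import QConfig

# Pass classes (factories), not instances: set_qconfig() instantiates fresh
# observers and fake-quant objects for every QATModule.
my_qconfig = QConfig(
    weight_observer=MinMaxObserver,
    act_observer=ExponentialMovingAverageObserver,
    fake_quant=FakeQuantize,
)

Such a config would then be passed as ``quantize_qat(net, qconfig=my_qconfig)``.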


python_module/megengine/quantization/quantize.py (+82 / -25)

@@ -6,68 +6,125 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy

from ..module import Module, QATModule, Sequential, quantized
from typing import Dict, Tuple

from .. import module as Float
from ..module import Module
from ..module import qat as QAT
from ..module import quantized as Quantized
from ..module.qat import QATModule
from ..module.quantized import QuantizedModule
from .qconfig import QConfig, ema_fakequant_qconfig


def _get_quantable_module_names():
def is_quantable(key: str):
value = getattr(Quantized, key)
return (
isinstance(value, type)
and issubclass(value, QuantizedModule)
and value != QuantizedModule
)

# source should have all quantable modules' names
quantable_module_names = [key for key in dir(Quantized) if is_quantable(key)]
return quantable_module_names


def _get_convert_dict() -> Tuple[
Dict[Module, QATModule], Dict[QATModule, QuantizedModule]
]:
quantable_module_names = _get_quantable_module_names()

quantable_modules = [getattr(Float, key) for key in quantable_module_names]
qat_modules = [getattr(QAT, key) for key in quantable_module_names]
quantized_modules = [getattr(Quantized, key) for key in quantable_module_names]

float2qat_dict = dict(zip(quantable_modules, qat_modules))
qat2quantized_dict = dict(zip(qat_modules, quantized_modules))
return float2qat_dict, qat2quantized_dict


_float2qat_dict, _qat2quantized_dict = _get_convert_dict()


def quantize(module: Module, inplace=True):
r"""
Recursively convert `module` to `quantized` mode through :meth:`~.Module.apply`.
Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule`
through :meth:`~.Module.apply`.

:param module: root module to do convert recursively.
:param inplace: whether to convert submodules in-place.
"""

if not inplace:
module = deepcopy(module)

def is_qat_module(obj):
return isinstance(obj, QATModule)
qat_modules = tuple(_qat2quantized_dict.keys())

def is_qat(mod: Module):
return isinstance(mod, qat_modules)

# no need to pass prefix and get pure key of parent Module.
for key, submodule, parent in module._flatten(
with_key=True, with_parent=True, predicate=is_qat_module
with_key=True, with_parent=True, predicate=is_qat
):
if isinstance(parent, Sequential):
new_mod = _qat2quantized_dict[type(submodule)].from_qat_module(submodule)
if isinstance(parent, Float.Sequential):
# cannot use setattr to be compatible with Sequential's ``__setitem__``
parent[int(key.split(".")[-1])] = submodule.to_quantized()
parent[int(key.split(".")[-1])] = new_mod
else:
setattr(parent, key.split(".")[-1], submodule.to_quantized())
setattr(parent, key.split(".")[-1], new_mod)

return module


def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
def quantize_qat(
module: Module, inplace=True, qconfig: QConfig = ema_fakequant_qconfig
):
r"""
Recursively convert `module` to `qat` mode through :meth:`~.Module.apply`
and set qconfig relatively.
Recursively convert float :class:`~.Module` to :class:`~.QATModule`
through :meth:`~.Module.apply` and set qconfig relatively.

:param module: root module to do convert recursively.
:param qconfig: a instance of :class:`~.QConfig` to be set as submodules' qconfig.
default is :any:`~.qconfig.ema_fakequant_qconfig`.
:param inplace: whether to convert submodules in-place.
:param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
default is ``ema_fakequant_qconfig``.
"""

def fn(mod: Module):
if isinstance(mod, QATModule):
mod.set_qat_mode(QATModule.QATMode.QAT)
mod.set_qconfig(qconfig)
if not inplace:
module = deepcopy(module)

module.apply(fn)
quantable_modules = tuple(_float2qat_dict.keys())

def is_quantable(mod: Module):
return isinstance(mod, quantable_modules)

# no need to pass prefix and get pure key of parent Module.
for key, submodule, parent in module._flatten(
with_key=True, with_parent=True, predicate=is_quantable
):
new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule)
if isinstance(parent, Float.Sequential):
# cannot use setattr to be compatible with Sequential's ``__setitem__``
parent[int(key.split(".")[-1])] = new_mod
else:
setattr(parent, key.split(".")[-1], new_mod)

propagate_qconfig(module, qconfig)
return module


def quantize_calibration(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
def propagate_qconfig(module: QATModule, qconfig: QConfig):
r"""
Recursively convert `module` to `calibration` mode through :meth:`~.Module.apply`
and set qconfig relatively.
Recursively set ``module``'s qconfig through :meth:`~.Module.apply`.

:param module: root module to do convert recursively.
:param module: root module to traverse recursively.
:param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
default is :any:`~.qconfig.ema_fakequant_qconfig`.
"""

def fn(mod: Module):
if isinstance(mod, QATModule):
mod.set_qat_mode(QATModule.QATMode.CALIBRATION)
mod.set_qconfig(qconfig)

module.apply(fn)
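
The convert tables are keyed purely by class name: a module is quantable exactly when the same name is defined in the float, ``qat`` and ``quantized`` namespaces, which is what the new unit test below verifies. Illustratively:

from megengine import module as Float
from megengine.module import qat as QAT
from megengine.module import quantized as Quantized
from megengine.quantization.quantize import _float2qat_dict, _qat2quantized_dict

# One class name, three parallel implementations.
assert _float2qat_dict[Float.Linear] is QAT.Linear
assert _qat2quantized_dict[QAT.Linear] is Quantized.Linear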


python_module/test/unit/module/test_conv_bn_relu.py (+10 / -12)

@@ -5,8 +5,7 @@ import numpy as np

from megengine import tensor
from megengine.module import ConvBn2d
from megengine.quantization import quantize_qat
from megengine.quantization.quantize import disable_fake_quant
from megengine.quantization.quantize import disable_fake_quant, quantize_qat
from megengine.test import assertTensorClose


@@ -14,18 +13,17 @@ def test_convbn2d():
in_channels = 32
out_channels = 64
kernel_size = 3
module = ConvBn2d(in_channels, out_channels, kernel_size)
quantize_qat(module)
for groups, bias in product([1, 4], [True, False]):
inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
module = ConvBn2d(
in_channels, out_channels, kernel_size, groups=groups, bias=bias
)
module.train()
qat_module = copy.deepcopy(module)
qat_module = quantize_qat(module, inplace=False)
disable_fake_quant(qat_module)
normal_outputs = module.forward(inputs)
qat_outputs = qat_module.forward_qat(inputs)
inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
normal_outputs = module(inputs)
qat_outputs = qat_module(inputs)
assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
a = module.bn.running_mean.numpy()
b = qat_module.bn.running_mean.numpy()
assertTensorClose(
module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8
)
@@ -33,7 +31,7 @@ def test_convbn2d():
module.bn.running_var, qat_module.bn.running_var, max_err=5e-7
)
module.eval()
normal_outputs = module.forward(inputs)
normal_outputs = module(inputs)
qat_module.eval()
qat_outputs = qat_module.forward_qat(inputs)
qat_outputs = qat_module(inputs)
assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)

python_module/test/unit/quantization/quantize.py (+38 / -0)

@@ -0,0 +1,38 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine import module as Float
from megengine.module import qat as QAT
from megengine.quantization.quantize import _get_quantable_module_names


def test_get_quantable_module_names():
# need to make sure names from Quantized and QAT are the same
def _get_qat_module_names():
def is_qat(key: str):
value = getattr(QAT, key)
return (
isinstance(value, type)
and issubclass(value, QAT.QATModule)
and value != QAT.QATModule
)

# source should have all quantable modules' names
quantable_module_names = [key for key in dir(QAT) if is_qat(key)]
return quantable_module_names

qat_module_names = _get_qat_module_names()
quantized_module_names = _get_quantable_module_names()
assert set(qat_module_names) == set(quantized_module_names)

for key in qat_module_names:
value = getattr(Float, key)
assert (
isinstance(value, type)
and issubclass(value, Float.Module)
and value != Float.Module
)
