GitOrigin-RevId: 80cfb12d10
tags/v0.5.0
@@ -27,10 +27,10 @@ from .utils import _decide_comp_node_and_comp_graph
def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
    """Applies a linear transformation to the input.

    Refer to :class:`~.Linear` for more information.
    Refer to :class:`~.module.linear.Linear` for more information.

    :param inp: the input tensor with shape `(N, in_features)`.
    :param weight: the weight with shape `(out_features, in_features)`.
    :param bias: the bias with shape `(out_features,)`.
        Default: ``None``
    """
@@ -300,9 +300,9 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
def softplus(inp: Tensor, beta: float = 1, threshold: float = 20) -> Tensor:
    r"""
    Performs the elementwise function:

    .. math::
        \mathsf{softplus}(x) = \log(1 + \exp(\beta x)) / \beta.

    For numerical stability the identity function is used when :math:`\beta x > \textrm{threshold}`.
@@ -16,7 +16,7 @@ from .elemwise import Elemwise
from .embedding import Embedding
from .identity import Identity
from .linear import Linear
from .module import Module, QATModule
from .module import Module
from .parampack import ParamPack
from .pooling import AvgPool2d, MaxPool2d
from .quant_dequant import DequantStub, QuantStub
@@ -9,19 +9,14 @@ from typing import Iterable
from .. import functional as F
from ..core.tensor import Tensor
from .module import QATModule
from .module import Module


class Concat(QATModule):
class Concat(Module):
    r"""
    A :class:`~.QATModule` to do functional concat, should replace concat with this module,
    supporting ``qat`` mode and ``quantized`` mode.
    A :class:`~.Module` to do functional concat. Could be replaced with :class:`~.QATModule`
    version :class:`~.qat.concat.Concat` using :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inps: Iterable[Tensor], axis: int = 0):
        return F.concat(inps, axis)

    def forward_qat(self, inps: Iterable[Tensor], axis: int = 0):
        return self.apply_fakequant_with_observer(
            self.forward(inps, axis), self.act_fake_quant, self.act_observer
        )
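With the QAT machinery stripped out, the float ``Concat`` is a plain functional wrapper; the QAT variant is now swapped in by conversion rather than selected by a mode flag. A minimal usage sketch under that reading, using the ``quantize_qat`` entry point introduced later in this diff:

import numpy as np
import megengine as mge
from megengine.module import Concat
from megengine.quantization.quantize import quantize_qat

m = Concat()
x = mge.tensor(np.random.randn(2, 3).astype("float32"))
out = m([x, x], axis=0)                 # plain float concat, shape (4, 3)
qat_m = quantize_qat(m, inplace=False)  # m is replaced by qat.concat.Concat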
@@ -7,14 +7,13 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Tuple, Union

from ..core import ones, zeros
from ..functional import add_update, flatten, relu, sqrt, sum, zero_grad
from ..functional import relu
from .batchnorm import BatchNorm2d
from .conv import Conv2d
from .module import QATModule
from .module import Module


class _ConvBn2d(QATModule):
class _ConvBnActivation2d(Module):
    def __init__(
        self,
        in_channels: int,
@@ -47,171 +46,24 @@ class _ConvBn2d(QATModule):
        )
        self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
    def get_batch_mean_var(self, inp):
        def _sum_channel(inp, axis=0, keepdims=True):
            if isinstance(axis, int):
                out = sum(inp, axis=axis, keepdims=keepdims)
            elif isinstance(axis, tuple):
                for idx, elem in enumerate(axis):
                    out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
            return out

        sum1 = _sum_channel(inp, (0, 2, 3))
        sum2 = _sum_channel(inp ** 2, (0, 2, 3))
        reduce_size = inp.shapeof().prod() / inp.shapeof(1)
        batch_mean = sum1 / reduce_size
        batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
        return batch_mean, batch_var

    def fold_weight_bias(self, bn_mean, bn_var):
        # get fold bn conv param
        # bn_istd = 1 / bn_std
        # w_fold = gamma / bn_std * W
        # b_fold = gamma * (b - bn_mean) / bn_std + beta
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features), dtype="float32")
        beta = beta.reshape(1, -1, 1, 1)

        if bn_mean is None:
            bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
        if bn_var is None:
            bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")

        conv_bias = self.conv.bias
        if conv_bias is None:
            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        # bn_istd = 1 / bn_std
        # w_fold = gamma / bn_std * W
        scale_factor = gamma * bn_istd
        if self.conv.groups == 1:
            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
        else:
            w_fold = self.conv.weight * scale_factor.reshape(
                self.conv.groups, -1, 1, 1, 1
            )
        # b_fold = gamma * (b - bn_mean) / bn_std + beta
        b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
        return w_fold, b_fold
    def update_running_mean_and_running_var(
        self, bn_mean, bn_var, num_elements_per_channel
    ):
        # update running mean and running var. no grad, use unbiased bn var
        bn_mean = zero_grad(bn_mean)
        bn_var = (
            zero_grad(bn_var)
            * num_elements_per_channel
            / (num_elements_per_channel - 1)
        )
        exponential_average_factor = 1 - self.bn.momentum
        add_update(
            self.bn.running_mean,
            delta=bn_mean,
            alpha=1 - exponential_average_factor,
            beta=exponential_average_factor,
        )
        add_update(
            self.bn.running_var,
            delta=bn_var,
            alpha=1 - exponential_average_factor,
            beta=exponential_average_factor,
        )
    def calc_conv_bn_qat(self, inp, approx=True):
        if self.training and not approx:
            conv = self.conv(inp)
            bn_mean, bn_var = self.get_batch_mean_var(conv)
            num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
            self.update_running_mean_and_running_var(
                bn_mean, bn_var, num_elements_per_channel
            )
        else:
            bn_mean, bn_var = self.bn.running_mean, self.bn.running_var

        # get gamma and beta in BatchNorm
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features), dtype="float32")
        beta = beta.reshape(1, -1, 1, 1)
        # conv_bias
        conv_bias = self.conv.bias
        if conv_bias is None:
            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        # bn_istd = 1 / bn_std
        # w_fold = gamma / bn_std * W
        scale_factor = gamma * bn_istd
        if self.conv.groups == 1:
            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
        else:
            w_fold = self.conv.weight * scale_factor.reshape(
                self.conv.groups, -1, 1, 1, 1
            )
        b_fold = None
        if not (self.training and approx):
            # b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta
            b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd

        w_qat = self.apply_fakequant_with_observer(
            w_fold, self.weight_fake_quant, self.weight_observer
        )
        conv = self.conv.calc_conv(inp, w_qat, b_fold)
        if not (self.training and approx):
            return conv

        # rescale conv to get original conv output
        orig_conv = conv / scale_factor.reshape(1, -1, 1, 1)
        if self.conv.bias is not None:
            orig_conv = orig_conv + self.conv.bias
        # calculate batch norm
        bn_mean, bn_var = self.get_batch_mean_var(orig_conv)
        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        conv = gamma * bn_istd * (orig_conv - bn_mean) + beta
        num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
        self.update_running_mean_and_running_var(
            bn_mean, bn_var, num_elements_per_channel
        )
        return conv
class ConvBn2d(_ConvBn2d):
class ConvBn2d(_ConvBnActivation2d):
    r"""
    A fused :class:`~.QATModule` including Conv2d and BatchNorm2d, supporting ``qat`` mode
    and ``normal`` mode.
    A fused :class:`~.Module` including Conv2d and BatchNorm2d. Could be replaced
    with the :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBn2d` using
    :func:`~.quantize.quantize_qat`.
    """

    def forward_qat(self, inp):
        return self.apply_fakequant_with_observer(
            self.calc_conv_bn_qat(inp), self.act_fake_quant, self.act_observer
        )

    def forward(self, inp):
        return self.bn(self.conv(inp))


class ConvBnRelu2d(_ConvBn2d):
class ConvBnRelu2d(_ConvBnActivation2d):
    r"""
    A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu, supporting ``qat``
    mode and ``normal`` mode.
    A fused :class:`~.Module` including Conv2d, BatchNorm2d and relu. Could be replaced
    with the :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBnRelu2d` using
    :func:`~.quantize.quantize_qat`.
    """

    def forward_qat(self, inp):
        return self.apply_fakequant_with_observer(
            relu(self.calc_conv_bn_qat(inp)), self.act_fake_quant, self.act_observer
        )

    def forward(self, inp):
        return relu(self.bn(self.conv(inp)))
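The folding identity behind ``fold_weight_bias`` (``w_fold = gamma / bn_std * W``, ``b_fold = beta + gamma * (b - bn_mean) / bn_std``) is easy to verify numerically. A hedged numpy sketch using a 1x1 "conv" written as a matmul, so only the per-output-channel algebra is exercised; every name below is local to the sketch:

import numpy as np

rng = np.random.RandomState(0)
cin, cout, eps = 4, 3, 1e-5
W, b = rng.randn(cout, cin), rng.randn(cout)        # conv weight / bias
gamma, beta = rng.randn(cout), rng.randn(cout)      # bn affine parameters
mean, var = rng.randn(cout), rng.rand(cout) + 0.5   # bn statistics

x = rng.randn(5, cin)
bn_out = gamma * ((x @ W.T + b) - mean) / np.sqrt(var + eps) + beta

istd = 1.0 / np.sqrt(var + eps)          # bn_istd = 1 / bn_std
w_fold = W * (gamma * istd)[:, None]     # fold the scale into the weight
b_fold = beta + gamma * (b - mean) * istd
assert np.allclose(bn_out, x @ w_fold.T + b_fold)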
@@ -8,7 +8,7 @@
from .. import _internal as mgb
from ..core import Tensor, wrap_io_tensor
from ..core.graph import _use_default_if_none
from .module import QATModule
from .module import Module


@wrap_io_tensor
@@ -22,10 +22,10 @@ def _elemwise_func(mode, *inputs, **kwargs) -> Tensor:
    return mgb.opr.elemwise(*inputs, mode=mode, **kwargs)


class Elemwise(QATModule):
class Elemwise(Module):
    r"""
    A :class:`~.QATModule` to do elemwise operator, should replace functional operator with this module,
    supporting ``qat`` mode and ``normal`` mode.
    A :class:`~.Module` to do elemwise operator. Could be replaced with :class:`~.QATModule`
    version :class:`~.qat.elemwise.Elemwise` using :func:`~.quantize.quantize_qat`.

    :param method: the elemwise method, supporting the following strings.
        It will do the normal elemwise operator for float.
@@ -88,8 +88,3 @@ class Elemwise(QATModule):
    def forward(self, *inps):
        return _elemwise_func(self.method, *inps)

    def forward_qat(self, *inps):
        return self.apply_fakequant_with_observer(
            self.forward(*inps), self.act_fake_quant, self.act_observer,
        )
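A short float-mode sketch: ``Elemwise`` exists as a ``Module`` wrapper around a functional op precisely so that ``quantize_qat`` can later substitute the QAT version; ``"ADD"`` is assumed here to be one of the supported method strings:

import numpy as np
import megengine as mge
from megengine.module import Elemwise

add = Elemwise("ADD")
x = mge.tensor(np.ones((2, 2), dtype="float32"))
out = add(x, x)  # plain float elementwise add; every entry == 2.0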
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -11,10 +10,10 @@ import numpy as np
from .. import functional as F
from ..core import Parameter
from . import init
from .module import QATModule
from .module import Module


class Linear(QATModule):
class Linear(Module):
    r"""Applies a linear transformation to the input. For instance, if input
    is x, then output y is:
@@ -60,13 +59,3 @@ class Linear(QATModule):
    def forward(self, x):
        return self._calc_linear(x, self.weight, self.bias)

    def forward_qat(self, x):
        w_qat = self.apply_fakequant_with_observer(
            self.weight, self.weight_fake_quant, self.weight_observer
        )
        return self.apply_fakequant_with_observer(
            self._calc_linear(x, w_qat, self.bias),
            self.act_fake_quant,
            self.act_observer,
        )
@@ -7,7 +7,6 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from enum import Enum
from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union

import numpy as np
@@ -443,98 +442,3 @@ class Module(metaclass=ABCMeta):
            loaded.append(k)

        return set(loaded), set(skipped)


class QATModule(Module):
    r"""
    Base class of quantization related Module. Add extra forward methods
    :meth:`~.QATModule.forward_qat` and :meth:`~.QATModule.forward_quantized` for
    ``qat`` (quantization aware training) mode and ``quantized`` mode respectively.

    Use :meth:`~.QATModule.quant` to switch between ``QAT`` and ``NORMAL`` mode,
    and use :meth:`~.QATModule.to_quantized` to switch to ``quantized`` mode,
    which is irreversible.

    If you want to recursively switch mode for all QATModule in network, use
    functions in :mod:`~.quantization.quantize`.
    """

    class QATMode(Enum):
        DISABLED = 1
        QAT = 2
        CALIBRATION = 3

    def __init__(self):
        from ..quantization import (
            QConfig,
            FakeQuantize,
            Observer,
        )  # pylint: disable=all

        super().__init__()
        self.quantizing = self.QATMode.DISABLED
        self.scale = None

        self.weight_observer = None  # type: Observer
        self.act_observer = None  # type: Observer

        self.weight_fake_quant = None  # type: FakeQuantize
        self.act_fake_quant = None  # type: FakeQuantize

    def set_qconfig(self, qconfig: "QConfig"):
        self.weight_observer = qconfig.weight_observer()
        self.act_observer = qconfig.act_observer()

        self.weight_fake_quant = (
            None
            if qconfig.fake_quant is None
            else qconfig.fake_quant(self.weight_observer.dtype)
        )
        self.act_fake_quant = (
            None
            if qconfig.fake_quant is None
            else qconfig.fake_quant(self.act_observer.dtype)
        )

    def apply_observer(self, target: Tensor, obs: "Observer"):
        return obs(target)

    def apply_fakequant_with_observer(
        self, target: Tensor, fq: "FakeQuantize", obs: "Observer"
    ):
        oup = self.apply_observer(target, obs)
        if fq is not None:
            q_dict = obs.get_qparams()
            oup = fq(oup, q_dict)
        return oup

    def set_qat_mode(self, mode: QATMode):
        r"""
        Change ``self.quantizing`` mode, available values: ``DISABLED``,
        ``QAT``, ``CALIBRATION``.
        """
        if not isinstance(mode, self.QATMode):
            raise TypeError("mode must be QATMode Enum type")
        self.quantizing = mode

    def to_quantized(self):
        r"""
        Return a new :class:`~.Module` with quantized parameters of ``self``
        according to scale and zero_point in ``self.xxx_observer``.
        """
        raise NotImplementedError(
            "Use megengine.quantization.quantize to register the method."
        )

    @abstractmethod
    def forward_qat(self, *args, **kwargs):
        r"""
        Forward method for ``qat`` mode.
        """

    def __call__(self, *args, **kwargs):
        if self.quantizing == self.QATMode.DISABLED:
            return self.forward(*args, **kwargs)
        else:
            return self.forward_qat(*args, **kwargs)
@@ -0,0 +1,13 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .concat import Concat
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .elemwise import Elemwise
from .linear import Linear
from .module import QATModule
from .quant_dequant import DequantStub, QuantStub
@@ -0,0 +1,30 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable

from ...core.tensor import Tensor
from .. import concat as Float
from .module import QATModule


class Concat(Float.Concat, QATModule):
    r"""
    A :class:`~.QATModule` to do functional concat with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
    """

    def forward(self, inps: Iterable[Tensor], axis: int = 0):
        return self.apply_quant_activation(super().forward(inps, axis))

    @classmethod
    def from_float_module(cls, float_module):
        r"""
        Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.
        """
        return cls()
@@ -0,0 +1,193 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ...core import ones, zeros
from ...functional import add_update, relu, sqrt, sum, zero_grad
from .. import conv_bn_relu as Float
from .module import QATModule


class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
    def get_batch_mean_var(self, inp):
        def _sum_channel(inp, axis=0, keepdims=True):
            if isinstance(axis, int):
                out = sum(inp, axis=axis, keepdims=keepdims)
            elif isinstance(axis, tuple):
                for idx, elem in enumerate(axis):
                    out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
            return out

        sum1 = _sum_channel(inp, (0, 2, 3))
        sum2 = _sum_channel(inp ** 2, (0, 2, 3))
        reduce_size = inp.shapeof().prod() / inp.shapeof(1)
        batch_mean = sum1 / reduce_size
        batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
        return batch_mean, batch_var

    def fold_weight_bias(self, bn_mean, bn_var):
        # get fold bn conv param
        # bn_istd = 1 / bn_std
        # w_fold = gamma / bn_std * W
        # b_fold = gamma * (b - bn_mean) / bn_std + beta
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features), dtype="float32")
        beta = beta.reshape(1, -1, 1, 1)

        if bn_mean is None:
            bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
        if bn_var is None:
            bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")

        conv_bias = self.conv.bias
        if conv_bias is None:
            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        # bn_istd = 1 / bn_std
        # w_fold = gamma / bn_std * W
        scale_factor = gamma * bn_istd
        if self.conv.groups == 1:
            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
        else:
            w_fold = self.conv.weight * scale_factor.reshape(
                self.conv.groups, -1, 1, 1, 1
            )
        # b_fold = gamma * (b - bn_mean) / bn_std + beta
        b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
        return w_fold, b_fold
    def update_running_mean_and_running_var(
        self, bn_mean, bn_var, num_elements_per_channel
    ):
        # update running mean and running var. no grad, use unbiased bn var
        bn_mean = zero_grad(bn_mean)
        bn_var = (
            zero_grad(bn_var)
            * num_elements_per_channel
            / (num_elements_per_channel - 1)
        )
        exponential_average_factor = 1 - self.bn.momentum
        add_update(
            self.bn.running_mean,
            delta=bn_mean,
            alpha=1 - exponential_average_factor,
            beta=exponential_average_factor,
        )
        add_update(
            self.bn.running_var,
            delta=bn_var,
            alpha=1 - exponential_average_factor,
            beta=exponential_average_factor,
        )

    def calc_conv_bn_qat(self, inp, approx=True):
        if self.training and not approx:
            conv = self.conv(inp)
            bn_mean, bn_var = self.get_batch_mean_var(conv)
            num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
            self.update_running_mean_and_running_var(
                bn_mean, bn_var, num_elements_per_channel
            )
        else:
            bn_mean, bn_var = self.bn.running_mean, self.bn.running_var

        # get gamma and beta in BatchNorm
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
        beta = self.bn.bias
        if beta is None:
            beta = zeros((self.bn.num_features), dtype="float32")
        beta = beta.reshape(1, -1, 1, 1)
        # conv_bias
        conv_bias = self.conv.bias
        if conv_bias is None:
            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        # bn_istd = 1 / bn_std
        # w_fold = gamma / bn_std * W
        scale_factor = gamma * bn_istd
        if self.conv.groups == 1:
            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
        else:
            w_fold = self.conv.weight * scale_factor.reshape(
                self.conv.groups, -1, 1, 1, 1
            )
        b_fold = None
        if not (self.training and approx):
            # b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta
            b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd

        w_qat = self.apply_quant_weight(w_fold)
        conv = self.conv.calc_conv(inp, w_qat, b_fold)
        if not (self.training and approx):
            return conv

        # rescale conv to get original conv output
        orig_conv = conv / scale_factor.reshape(1, -1, 1, 1)
        if self.conv.bias is not None:
            orig_conv = orig_conv + self.conv.bias
        # calculate batch norm
        bn_mean, bn_var = self.get_batch_mean_var(orig_conv)
        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        conv = gamma * bn_istd * (orig_conv - bn_mean) + beta
        num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
        self.update_running_mean_and_running_var(
            bn_mean, bn_var, num_elements_per_channel
        )
        return conv
    @classmethod
    def from_float_module(cls, float_module: Float._ConvBnActivation2d):
        r"""
        Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.
        """
        qat_module = cls(
            float_module.conv.in_channels,
            float_module.conv.out_channels,
            float_module.conv.kernel_size,
            float_module.conv.stride,
            float_module.conv.padding,
            float_module.conv.dilation,
            float_module.conv.groups,
            bool(float_module.conv.bias),
            float_module.conv.conv_mode.name,
            float_module.conv.compute_mode.name,
        )
        qat_module.conv.weight = float_module.conv.weight
        qat_module.conv.bias = float_module.conv.bias
        qat_module.bn = float_module.bn
        return qat_module


class ConvBn2d(_ConvBnActivation2d):
    r"""
    A fused :class:`~.QATModule` including Conv2d and BatchNorm2d with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
    """

    def forward(self, inp):
        return self.apply_quant_activation(self.calc_conv_bn_qat(inp))


class ConvBnRelu2d(_ConvBnActivation2d):
    r"""
    A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
    """

    def forward(self, inp):
        return self.apply_quant_activation(relu(self.calc_conv_bn_qat(inp)))
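For reference, ``update_running_mean_and_running_var`` above is a plain exponential moving average over unbiased batch statistics: with ``add_update(dest, delta, alpha, beta)`` computing ``dest = alpha * dest + beta * delta``, the ``alpha`` passed in works out to ``momentum``. A scalar sketch of the arithmetic (plain Python stand-ins, not MegEngine tensors):

momentum = 0.9                               # self.bn.momentum
running_mean, running_var = 0.0, 1.0         # current running statistics
batch_mean, batch_var, n = 0.3, 0.25, 64.0   # n = num_elements_per_channel

unbiased_var = batch_var * n / (n - 1.0)     # "use unbiased bn var"
running_mean = momentum * running_mean + (1 - momentum) * batch_mean
running_var = momentum * running_var + (1 - momentum) * unbiased_var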
@@ -0,0 +1,29 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .. import elemwise as Float
from .module import QATModule


class Elemwise(Float.Elemwise, QATModule):
    r"""
    A :class:`~.QATModule` to do elemwise operator with QAT support.
    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.

    :param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail.
    """

    def forward(self, *inps):
        return self.apply_quant_activation(super().forward(*inps))

    @classmethod
    def from_float_module(cls, float_module: Float.Elemwise):
        r"""
        Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.
        """
        return cls(float_module.method.name)
@@ -0,0 +1,37 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .. import linear as Float
from .module import QATModule


class Linear(Float.Linear, QATModule):
    r"""
    A :class:`~.QATModule` version of :class:`~.module.linear.Linear`.
    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.

    :param in_features: size of each input sample.
    :param out_features: size of each output sample.
    :param bias: if set to ``False``, the layer will not learn an additive bias.
        Default: ``True``
    """

    def forward(self, x):
        w_qat = self.apply_quant_weight(self.weight)
        return self.apply_quant_activation(self._calc_linear(x, w_qat, self.bias))

    @classmethod
    def from_float_module(cls, float_module: Float.Linear):
        r"""
        Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.
        """
        qmod = cls(float_module.in_features, float_module.out_features)
        qmod.weight = float_module.weight
        qmod.bias = float_module.bias
        return qmod
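Putting the pieces together for ``Linear``: a hedged sketch of the manual conversion path that ``quantize_qat`` automates, assuming ``min_max_fakequant_qconfig`` is exported by ``megengine.quantization`` as shown in the qconfig import hunk further below:

import numpy as np
import megengine as mge
from megengine.module import Linear as FloatLinear
from megengine.module.qat import Linear as QATLinear
from megengine.quantization import min_max_fakequant_qconfig

float_m = FloatLinear(8, 4)
qat_m = QATLinear.from_float_module(float_m)  # shares weight/bias with float_m
qat_m.set_qconfig(min_max_fakequant_qconfig)  # attach observers + fake_quant
x = mge.tensor(np.random.randn(2, 8).astype("float32"))
out = qat_m(x)  # observes and fake-quantizes both the weight and the activation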
@@ -0,0 +1,96 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod

from ...core import Tensor
from ...quantization import FakeQuantize, Observer, QConfig
from ..module import Module


class QATModule(Module):
    r"""
    Base class of quantized-float related Modules, basically for QAT and calibration.

    Use :meth:`~.QATModule.from_float_module` to generate an instance from a float
    :class:`~.Module`, or use :func:`~.quantize.quantize_qat` to do it recursively
    and automatically. It can further be converted to :class:`~.QuantizedModule`
    for deployment using :func:`~.quantize.quantize`.
    """

    def __init__(self):
        super().__init__()

        self.scale = None
        self.weight_observer = None  # type: Observer
        self.act_observer = None  # type: Observer

        self.weight_fake_quant = None  # type: FakeQuantize
        self.act_fake_quant = None  # type: FakeQuantize

    def set_qconfig(self, qconfig: QConfig):
        r"""
        Set quantization related configs with ``qconfig``, including
        observer and fake_quant for weight and activation.
        """
        self.weight_observer = qconfig.weight_observer()
        self.act_observer = qconfig.act_observer()

        if qconfig.fake_quant is None:
            self.weight_fake_quant = None
            self.act_fake_quant = None
        else:
            self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype)
            self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype)

    def _apply_fakequant_with_observer(
        self, target: Tensor, fake_quant: FakeQuantize, observer: Observer
    ):
        oup = observer(target)
        if fake_quant is None:
            return oup
        else:
            q_dict = observer.get_qparams()
            return fake_quant(oup, q_dict)

    def apply_quant_weight(self, target: Tensor):
        r"""
        Apply weight's observer and fake_quant from ``qconfig`` on ``target``.
        """
        return self._apply_fakequant_with_observer(
            target, self.weight_fake_quant, self.weight_observer
        )

    def apply_quant_activation(self, target: Tensor):
        r"""
        Apply activation's observer and fake_quant from ``qconfig`` on ``target``.
        """
        return self._apply_fakequant_with_observer(
            target, self.act_fake_quant, self.act_observer
        )

    def get_weight_dtype(self):
        r"""
        Get weight's quantization dtype as configured by ``qconfig``.
        """
        return self.weight_observer.get_dtype()

    def get_activation_dtype(self):
        r"""
        Get activation's quantization dtype as configured by ``qconfig``.
        """
        return self.act_observer.get_dtype()

    @classmethod
    @abstractmethod
    def from_float_module(cls, float_module: Module):
        r"""
        Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.
        """
@@ -0,0 +1,45 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .. import quant_dequant as Float
from .module import QATModule


class QuantStub(Float.QuantStub, QATModule):
    r"""
    A helper QATModule that simply returns its input, but will quantize
    the input after being converted to :class:`~.QuantizedModule`.
    """

    def forward(self, inp):
        return self.apply_quant_activation(inp)

    @classmethod
    def from_float_module(cls, float_module: Float.QuantStub):
        r"""
        Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.
        """
        return cls()


class DequantStub(Float.DequantStub, QATModule):
    r"""
    A helper QATModule that simply returns its input, but will de-quantize
    the input after being converted to :class:`~.QuantizedModule`.
    """

    def forward(self, inp):
        return inp

    @classmethod
    def from_float_module(cls, float_module: Float.DequantStub):
        r"""
        Return a :class:`~.QATModule` instance converted from
        a float :class:`~.Module` instance.
        """
        return cls()
@@ -5,30 +5,24 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .module import QATModule
from .module import Module


class QuantStub(QATModule):
class QuantStub(Module):
    r"""
    A helper QATModule doing quantize operation on input.
    A helper :class:`~.Module` simply returning input. Could be replaced with :class:`~.QATModule`
    version :class:`~.qat.QuantStub` using :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        return inp

    def forward_qat(self, inp):
        return self.apply_fakequant_with_observer(
            inp, self.act_fake_quant, self.act_observer
        )


class DequantStub(QATModule):
class DequantStub(Module):
    r"""
    A helper QATModule doing de-quantize operation on input.
    A helper :class:`~.Module` simply returning input. Could be replaced with :class:`~.QATModule`
    version :class:`~.qat.DequantStub` using :func:`~.quantize.quantize_qat`.
    """

    def forward(self, inp):
        return inp

    def forward_qat(self, inp):
        return inp
@@ -9,4 +9,5 @@ from .concat import Concat
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .elemwise import Elemwise
from .linear import Linear
from .module import QuantizedModule
from .quant_dequant import DequantStub, QuantStub
@@ -7,17 +7,15 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable

from ... import _internal as mgb
from ... import functional as F
from ... import module as Float
from ...core.tensor import Tensor
from ...quantization.utils import register_method_to_class
from ..module import Module
from ..qat import concat as QAT
from .module import QuantizedModule


class Concat(Module):
class Concat(QuantizedModule):
    r"""
    A :class:`~.Module` to do quantized concat, inference only.
    A :class:`~.QuantizedModule` to do quantized concat, inference only.
    """

    def __init__(self, dtype=None):
@@ -25,16 +23,13 @@ class Concat(Module):
        self.output_dtype = dtype

    def forward(self, inps: Iterable[Tensor], axis: int = 0):
        if self.training:
            raise ValueError("quantized module only support inference.")
        new_inps = (x.astype(self.output_dtype) for x in inps)
        return F.concat(new_inps, axis)


@register_method_to_class(Float.Concat)
def to_quantized(float_module):
    r"""
    Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
    Implemented here to avoid circular import.
    """
    return Concat(float_module.act_observer.get_dtype())

    @classmethod
    def from_qat_module(cls, qat_module: QAT.Concat):
        r"""
        Return a :class:`~.QuantizedModule` instance converted from a
        :class:`~.QATModule` instance.
        """
        return cls(qat_module.get_activation_dtype())
@@ -5,7 +5,6 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from functools import partial
from typing import Tuple, Union

import megengine._internal as mgb
@@ -13,11 +12,11 @@ import megengine._internal as mgb
from ... import module as Float
from ...core import Parameter
from ...functional import conv_bias_activation
from ...module import Conv2d
from ...quantization.utils import register_method_to_class
from ..qat import conv_bn_relu as QAT
from .module import QuantizedModule


class _ConvBnActivation2d(Conv2d):
class _ConvBnActivation2d(Float.Conv2d, QuantizedModule):
    r"""Applies a 2D convolution over a quantized input tensor, inference only.

    The parameters are the same as :class:`~.Conv2d`.
@@ -68,44 +67,41 @@ class _ConvBnActivation2d(Conv2d):
            nonlinear_mode=nonlinear_mode,
        )

    @classmethod
    def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d):
        r"""
        Return a :class:`~.QuantizedModule` instance converted from a
        :class:`~.QATModule` instance.
        """
        output_dtype = qat_module.get_activation_dtype()
        qconv = cls(
            qat_module.conv.in_channels,
            qat_module.conv.out_channels,
            qat_module.conv.kernel_size,
            qat_module.conv.stride,
            qat_module.conv.padding,
            qat_module.conv.dilation,
            qat_module.conv.groups,
            dtype=output_dtype,
        )
        w_fold, b_fold = qat_module.fold_weight_bias(
            qat_module.bn.running_mean, qat_module.bn.running_var
        )
        weight = w_fold.astype(qat_module.get_weight_dtype())
        qconv.weight = Parameter(weight.numpy())
        qconv.bias = Parameter(b_fold.numpy())
        return qconv


class ConvBn2d(_ConvBnActivation2d):
    r"""Quantized version of :class:`~.qat.conv_bn_relu.ConvBn2d`."""

    def forward(self, inp):
        if self.training:
            raise ValueError("quantized module only support inference.")
        return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")


class ConvBnRelu2d(_ConvBnActivation2d):
    r"""Quantized version of :class:`~.qat.conv_bn_relu.ConvBnRelu2d`."""

    def forward(self, inp):
        if self.training:
            raise ValueError("quantized module only support inference.")
        return self.calc_conv_quantized(inp, nonlinear_mode="RELU")


def to_quantized(quantized_class, float_module):
    output_dtype = float_module.act_observer.get_dtype()
    qconv = quantized_class(
        float_module.conv.in_channels,
        float_module.conv.out_channels,
        float_module.conv.kernel_size,
        float_module.conv.stride,
        float_module.conv.padding,
        float_module.conv.dilation,
        float_module.conv.groups,
        dtype=output_dtype,
    )
    w_fold, b_fold = float_module.fold_weight_bias(
        float_module.bn.running_mean, float_module.bn.running_var
    )
    weight = w_fold.astype(float_module.weight_observer.get_dtype())
    qconv.weight = Parameter(weight.numpy())
    qconv.bias = Parameter(b_fold.numpy())
    return qconv


# replace :class:`~.module.QATModule`'s ``to_quantized`` method,
# implemented here to avoid circular import.
register_method_to_class(Float.ConvBn2d)(partial(to_quantized, ConvBn2d))
register_method_to_class(Float.ConvBnRelu2d)(partial(to_quantized, ConvBnRelu2d))
@@ -6,11 +6,10 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import _internal as mgb
from ... import module as Float
from ...core import Tensor, wrap_io_tensor
from ...core.graph import _use_default_if_none
from ...quantization.utils import register_method_to_class
from ..module import Module
from ..qat import elemwise as QAT
from .module import QuantizedModule


@wrap_io_tensor
@@ -24,13 +23,8 @@ def _elemwise_multi_type(mode, *inputs, **kwargs) -> Tensor:
    return mgb.opr.elemwise_multi_type(*inputs, mode=mode, **kwargs)


class Elemwise(Module):
    r"""
    Quantized module for elemwise operator, inference only.

    :param method: the elemwise method, supported strings refer to :class:`~.module.elemwise.Elemwise`.
        It will do the quantized operator with the specified output quantized dtype.
    """
class Elemwise(QuantizedModule):
    r"""Quantized version of :class:`~.qat.elemwise.Elemwise`."""

    _elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode

@@ -44,11 +38,10 @@ class Elemwise(Module):
            raise ValueError("quantized module only support inference.")
        return _elemwise_multi_type(self.method, *inps, dtype=self.output_dtype)


@register_method_to_class(Float.Elemwise)
def to_quantized(float_module):
    r"""
    Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
    Implemented here to avoid circular import.
    """
    return Elemwise(float_module.method.name, float_module.act_observer.get_dtype())

    @classmethod
    def from_qat_module(cls, qat_module: QAT.Elemwise):
        r"""
        Return a :class:`~.QuantizedModule` instance converted from a
        :class:`~.QATModule` instance.
        """
        return cls(qat_module.method.name, qat_module.get_activation_dtype())
@@ -10,19 +10,13 @@ import numpy as np
import megengine._internal as mgb

from ... import functional as F
from ... import module as Float
from ...core import Parameter
from ...quantization.utils import register_method_to_class
from ..module import Module
from ..qat import linear as QAT
from .module import QuantizedModule


class Linear(Module):
    r"""Applies a quantized linear transformation to the input. The module
    is usually converted from a QAT module by the to_quantized method.

    :param dtype: output data type.
    """
class Linear(QuantizedModule):
    r"""Quantized version of :class:`~.qat.linear.Linear`."""

    def __init__(
        self, dtype: np.dtype = None,
@@ -44,17 +38,16 @@ class Linear(Module):
            None if self.bias is None else self.bias.astype(bias_dtype),
        ).astype(self.output_dtype)


@register_method_to_class(Float.Linear)
def to_quantized(float_module):
    r"""
    Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
    Implemented here to avoid circular import.
    """
    output_dtype = float_module.act_observer.get_dtype()
    qmod = Linear(dtype=output_dtype,)
    weight = float_module.weight.astype(float_module.weight_observer.get_dtype())
    qmod.weight = Parameter(weight.numpy())
    if float_module.bias is not None:
        qmod.bias = Parameter(float_module.bias.numpy())
    return qmod

    @classmethod
    def from_qat_module(cls, qat_module: QAT.Linear):
        r"""
        Return a :class:`~.QuantizedModule` instance converted from a
        :class:`~.QATModule` instance.
        """
        output_dtype = qat_module.get_activation_dtype()
        qmod = cls(dtype=output_dtype)
        weight = qat_module.weight.astype(qat_module.get_weight_dtype())
        qmod.weight = Parameter(weight.numpy())
        if qat_module.bias is not None:
            qmod.bias = Parameter(qat_module.bias.numpy())
        return qmod
@@ -0,0 +1,31 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod

from ..module import Module
from ..qat import QATModule


class QuantizedModule(Module):
    r"""
    Base class of quantized Modules, which should be converted from a QATModule
    and does not support training.
    """

    def __call__(self, *inputs, **kwargs):
        if self.training:
            raise ValueError("quantized module only support inference.")
        return super().__call__(*inputs, **kwargs)

    @classmethod
    @abstractmethod
    def from_qat_module(cls, qat_module: QATModule):
        r"""
        Return a :class:`~.QuantizedModule` instance converted from a
        :class:`~.QATModule` instance.
        """
@@ -5,15 +5,14 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import _internal as mgb
from ... import module as Float
from ...quantization.utils import register_method_to_class
from ..module import Module
from ..qat import quant_dequant as QAT
from .module import QuantizedModule


class QuantStub(Module):
class QuantStub(QuantizedModule):
    r"""
    A helper quantize operation on input, inference only.
    Quantized version of :class:`~.qat.quant_dequant.QuantStub`,
    will convert input to quantized dtype.
    """

    def __init__(self, dtype=None):
@@ -21,35 +20,30 @@ class QuantStub(Module):
        self.output_dtype = dtype

    def forward(self, inp):
        if self.training:
            raise ValueError("quantized module only support inference.")
        return inp.astype(self.output_dtype)

    @classmethod
    def from_qat_module(cls, qat_module: QAT.QuantStub):
        r"""
        Return a :class:`~.QuantizedModule` instance converted from a
        :class:`~.QATModule` instance.
        """
        return cls(qat_module.get_activation_dtype())


class DequantStub(Module):
class DequantStub(QuantizedModule):
    r"""
    A helper de-quantize operation, inference only.
    Quantized version of :class:`~.qat.quant_dequant.DequantStub`,
    will restore quantized input to float32 dtype.
    """

    def forward(self, inp):
        if self.training:
            raise ValueError("quantized module only support inference.")
        return inp.astype("float32")


@register_method_to_class(Float.QuantStub)
def to_quantized(float_module):
    r"""
    Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
    Implemented here to avoid circular import.
    """
    return QuantStub(float_module.act_observer.get_dtype())


@register_method_to_class(Float.DequantStub)
def to_quantized(float_module):
    r"""
    Replace :class:`~.module.QATModule`'s ``to_quantized`` method.
    Implemented here to avoid circular import.
    """
    return DequantStub()

    @classmethod
    def from_qat_module(cls, qat_module: QAT.DequantStub):
        r"""
        Return a :class:`~.QuantizedModule` instance converted from a
        :class:`~.QATModule` instance.
        """
        return cls()
@@ -13,12 +13,3 @@ from .qconfig import (
    ema_fakequant_qconfig,
    min_max_fakequant_qconfig,
)

from .quantize import (
    disable_fake_quant,
    disable_observer,
    enable_fake_quant,
    enable_observer,
    quantize,
    quantize_calibration,
    quantize_qat,
)
@@ -15,16 +15,12 @@ from .observer import (
class QConfig:
    """
    r"""
    A config class indicating how to do quantize toward :class:`~.QATModule`'s
    ``activation`` and ``weight``.
    And ``fake_quant`` parameter to indicate
    See :meth:`~.QATModule.set_qconfig` for detail usage.
    ``activation`` and ``weight``. See :meth:`~.QATModule.set_qconfig` for detailed usage.

    :param weight_observer: interface to instantiate an :class:`~.Observer` indicating
    - how to collect scales and zero_point of weight.
        how to collect scales and zero_point of weight.
    :param act_observer: similar to ``weight_observer`` but toward activation.
    :param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
        how to do fake_quant calculation. Can be invoked multiple times to get different
@@ -6,68 +6,125 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy

from ..module import Module, QATModule, Sequential, quantized
from typing import Dict, Tuple

from .. import module as Float
from ..module import Module
from ..module import qat as QAT
from ..module import quantized as Quantized
from ..module.qat import QATModule
from ..module.quantized import QuantizedModule
from .qconfig import QConfig, ema_fakequant_qconfig


def _get_quantable_module_names():
    def is_quantable(key: str):
        value = getattr(Quantized, key)
        return (
            isinstance(value, type)
            and issubclass(value, QuantizedModule)
            and value != QuantizedModule
        )

    # source should have all quantable modules' names
    quantable_module_names = [key for key in dir(Quantized) if is_quantable(key)]
    return quantable_module_names


def _get_convert_dict() -> Tuple[
    Dict[Module, QATModule], Dict[QATModule, QuantizedModule]
]:
    quantable_module_names = _get_quantable_module_names()
    quantable_modules = [getattr(Float, key) for key in quantable_module_names]
    qat_modules = [getattr(QAT, key) for key in quantable_module_names]
    quantized_modules = [getattr(Quantized, key) for key in quantable_module_names]

    float2qat_dict = dict(zip(quantable_modules, qat_modules))
    qat2quantized_dict = dict(zip(qat_modules, quantized_modules))
    return float2qat_dict, qat2quantized_dict


_float2qat_dict, _qat2quantized_dict = _get_convert_dict()


def quantize(module: Module, inplace=True):
    r"""
    Recursively convert `module` to `quantized` mode through :meth:`~.Module.apply`.
    Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule`
    through :meth:`~.Module.apply`.

    :param module: root module to do convert recursively.
    :param inplace: whether to convert submodules in-place.
    """
    if not inplace:
        module = deepcopy(module)

    def is_qat_module(obj):
        return isinstance(obj, QATModule)

    qat_modules = tuple(_qat2quantized_dict.keys())

    def is_qat(mod: Module):
        return isinstance(mod, qat_modules)

    # no need to pass prefix and get pure key of parent Module.
    for key, submodule, parent in module._flatten(
        with_key=True, with_parent=True, predicate=is_qat_module
        with_key=True, with_parent=True, predicate=is_qat
    ):
        new_mod = _qat2quantized_dict[type(submodule)].from_qat_module(submodule)
        if isinstance(parent, Sequential):
        if isinstance(parent, Float.Sequential):
            # cannot use setattr, to stay compatible with Sequential's ``__setitem__``
            parent[int(key.split(".")[-1])] = submodule.to_quantized()
            parent[int(key.split(".")[-1])] = new_mod
        else:
            setattr(parent, key.split(".")[-1], submodule.to_quantized())
            setattr(parent, key.split(".")[-1], new_mod)
    return module


def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
def quantize_qat(
    module: Module, inplace=True, qconfig: QConfig = ema_fakequant_qconfig
):
    r"""
    Recursively convert `module` to `qat` mode through :meth:`~.Module.apply`
    and set qconfig relatively.
    Recursively convert a float :class:`~.Module` to :class:`~.QATModule`
    through :meth:`~.Module.apply` and set qconfig relatively.

    :param module: root module to do convert recursively.
    :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
        default is :any:`~.qconfig.ema_fakequant_qconfig`.
    :param inplace: whether to convert submodules in-place.
    :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
        default is ``ema_fakequant_qconfig``.
    """

    def fn(mod: Module):
        if isinstance(mod, QATModule):
            mod.set_qat_mode(QATModule.QATMode.QAT)
            mod.set_qconfig(qconfig)

    if not inplace:
        module = deepcopy(module)

    module.apply(fn)
    quantable_modules = tuple(_float2qat_dict.keys())

    def is_quantable(mod: Module):
        return isinstance(mod, quantable_modules)

    # no need to pass prefix and get pure key of parent Module.
    for key, submodule, parent in module._flatten(
        with_key=True, with_parent=True, predicate=is_quantable
    ):
        new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule)
        if isinstance(parent, Float.Sequential):
            # cannot use setattr, to stay compatible with Sequential's ``__setitem__``
            parent[int(key.split(".")[-1])] = new_mod
        else:
            setattr(parent, key.split(".")[-1], new_mod)

    propagate_qconfig(module, qconfig)
    return module


def quantize_calibration(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
def propagate_qconfig(module: QATModule, qconfig: QConfig):
    r"""
    Recursively convert `module` to `calibration` mode through :meth:`~.Module.apply`
    and set qconfig relatively.
    Recursively set ``module``'s qconfig through :meth:`~.Module.apply`.

    :param module: root module to do convert recursively.
    :param module: root module to traverse recursively.
    :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
        default is :any:`~.qconfig.ema_fakequant_qconfig`.
    """

    def fn(mod: Module):
        if isinstance(mod, QATModule):
            mod.set_qat_mode(QATModule.QATMode.CALIBRATION)
            mod.set_qconfig(qconfig)

    module.apply(fn)
@@ -5,8 +5,7 @@ import numpy as np
from megengine import tensor
from megengine.module import ConvBn2d
from megengine.quantization import quantize_qat
from megengine.quantization.quantize import disable_fake_quant
from megengine.quantization.quantize import disable_fake_quant, quantize_qat
from megengine.test import assertTensorClose
@@ -14,18 +13,17 @@ def test_convbn2d():
    in_channels = 32
    out_channels = 64
    kernel_size = 3
    module = ConvBn2d(in_channels, out_channels, kernel_size)
    quantize_qat(module)
    for groups, bias in product([1, 4], [True, False]):
        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
        module = ConvBn2d(
            in_channels, out_channels, kernel_size, groups=groups, bias=bias
        )
        module.train()
        qat_module = copy.deepcopy(module)
        qat_module = quantize_qat(module, inplace=False)
        disable_fake_quant(qat_module)
        normal_outputs = module.forward(inputs)
        qat_outputs = qat_module.forward_qat(inputs)
        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
        normal_outputs = module(inputs)
        qat_outputs = qat_module(inputs)
        assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
        a = module.bn.running_mean.numpy()
        b = qat_module.bn.running_mean.numpy()
        assertTensorClose(
            module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8
        )
@@ -33,7 +31,7 @@ def test_convbn2d():
            module.bn.running_var, qat_module.bn.running_var, max_err=5e-7
        )
        module.eval()
        normal_outputs = module.forward(inputs)
        normal_outputs = module(inputs)
        qat_module.eval()
        qat_outputs = qat_module.forward_qat(inputs)
        qat_outputs = qat_module(inputs)
        assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
@@ -0,0 +1,38 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from megengine import module as Float
from megengine.module import qat as QAT
from megengine.quantization.quantize import _get_quantable_module_names


def test_get_quantable_module_names():
    # need to make sure names from Quantized and QAT are the same
    def _get_qat_module_names():
        def is_qat(key: str):
            value = getattr(QAT, key)
            return (
                isinstance(value, type)
                and issubclass(value, QAT.QATModule)
                and value != QAT.QATModule
            )

        # source should have all quantable modules' names
        quantable_module_names = [key for key in dir(QAT) if is_qat(key)]
        return quantable_module_names

    qat_module_names = _get_qat_module_names()
    quantized_module_names = _get_quantable_module_names()
    assert set(qat_module_names) == set(quantized_module_names)

    for key in qat_module_names:
        value = getattr(Float, key)
        assert (
            isinstance(value, type)
            and issubclass(value, Float.Module)
            and value != Float.Module
        )
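Taken together, the diff replaces runtime mode flags with a one-way conversion chain, Float -> QAT -> Quantized. An end-to-end hedged sketch of the intended flow; a single dummy forward stands in for real QAT training/calibration so the observers have statistics to export:

import numpy as np
import megengine as mge
from megengine.module import ConvBnRelu2d, DequantStub, QuantStub, Sequential
from megengine.quantization.quantize import quantize, quantize_qat

net = Sequential(QuantStub(), ConvBnRelu2d(3, 8, 3), DequantStub())
qat_net = quantize_qat(net, inplace=False)  # Float -> QAT, default ema qconfig
data = mge.tensor(np.random.randn(1, 3, 32, 32).astype("float32"))
_ = qat_net(data)         # stand-in for QAT training / calibration
qat_net.eval()
qnet = quantize(qat_net)  # QAT -> Quantized: BN folded, weights cast via observers
qnet.eval()               # quantized modules refuse to run in training mode
out = qnet(data)          # QuantStub/DequantStub handle the dtype round-trip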