GitOrigin-RevId: 80cfb12d10
tags/v0.5.0
@@ -27,10 +27,10 @@ from .utils import _decide_comp_node_and_comp_graph
 def linear(inp: Tensor, weight: Tensor, bias: Optional[Tensor] = None) -> Tensor:
     """Applies a linear transformation to the input.

-    Refer to :class:`~.Linear` for more information.
+    Refer to :class:`~.module.linear.Linear` for more information.

     :param inp: the input tensor with shape `(N, in_features)`.
-    :param weight: the weight with shape `(out_features, in_features)`.
+    :param weight: the weight with shape `(out_features, in_features)`.
     :param bias: the bias with shape `(out_features,)`.
         Default: ``None``
     """
@@ -300,9 +300,9 @@ def leaky_relu(inp: Tensor, negative_slope: float = 0.01) -> Tensor:
 def softplus(inp: Tensor, beta: float = 1, threshold: float = 20) -> Tensor:
     r"""
     Performs the elementwise function:

     .. math::
         \mathsf{softplus}(x) = \log(1 + \exp(\beta x)) / \beta.

     For numerical stability the identity function is used when :math:`\beta x > \textrm{threshold}`.
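For intuition, here is a minimal NumPy sketch of the same thresholded softplus (illustrative only, not MegEngine's kernel):

```python
import numpy as np

def softplus_np(x: np.ndarray, beta: float = 1.0, threshold: float = 20.0) -> np.ndarray:
    """log(1 + exp(beta * x)) / beta, switching to the identity where beta * x > threshold."""
    scaled = beta * x
    # Clamp the exp argument so the unused branch cannot overflow before `where` selects.
    stable = np.log1p(np.exp(np.minimum(scaled, threshold))) / beta
    return np.where(scaled > threshold, x, stable)

print(softplus_np(np.array([-10.0, 0.0, 30.0])))  # ~ [4.54e-05, 0.693, 30.0]
```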
@@ -16,7 +16,7 @@ from .elemwise import Elemwise
 from .embedding import Embedding
 from .identity import Identity
 from .linear import Linear
-from .module import Module, QATModule
+from .module import Module
 from .parampack import ParamPack
 from .pooling import AvgPool2d, MaxPool2d
 from .quant_dequant import DequantStub, QuantStub
@@ -9,19 +9,14 @@ from typing import Iterable
 from .. import functional as F
 from ..core.tensor import Tensor
-from .module import QATModule
+from .module import Module


-class Concat(QATModule):
+class Concat(Module):
     r"""
-    A :class:`~.QATModule` to do functional concat, should replace concat with this module,
-    supporting ``qat`` mode and ``quantized`` mode.
+    A :class:`~.Module` to do functional concat. Could be replaced with the
+    :class:`~.QATModule` version :class:`~.qat.concat.Concat` using
+    :func:`~.quantize.quantize_qat`.
     """

     def forward(self, inps: Iterable[Tensor], axis: int = 0):
         return F.concat(inps, axis)
-
-    def forward_qat(self, inps: Iterable[Tensor], axis: int = 0):
-        return self.apply_fakequant_with_observer(
-            self.forward(inps, axis), self.act_fake_quant, self.act_observer
-        )
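The float `Concat` above exists so a functional `F.concat` call site can be written as a module, which `quantize_qat` can later swap for the QAT version. A hedged sketch of that usage (module paths assumed from this diff):

```python
import megengine.module as M

class TwoBranch(M.Module):  # hypothetical user network
    def __init__(self):
        super().__init__()
        self.concat = M.Concat()  # module form, replaceable by quantize_qat

    def forward(self, a, b):
        return self.concat([a, b], axis=1)
```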
@@ -7,14 +7,13 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from typing import Tuple, Union

-from ..core import ones, zeros
-from ..functional import add_update, flatten, relu, sqrt, sum, zero_grad
+from ..functional import relu
 from .batchnorm import BatchNorm2d
 from .conv import Conv2d
-from .module import QATModule
+from .module import Module


-class _ConvBn2d(QATModule):
+class _ConvBnActivation2d(Module):
     def __init__(
         self,
         in_channels: int,
@@ -47,171 +46,24 @@ class _ConvBn2d(QATModule):
         )
         self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)

-    def get_batch_mean_var(self, inp):
-        def _sum_channel(inp, axis=0, keepdims=True):
-            if isinstance(axis, int):
-                out = sum(inp, axis=axis, keepdims=keepdims)
-            elif isinstance(axis, tuple):
-                for idx, elem in enumerate(axis):
-                    out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
-            return out
-
-        sum1 = _sum_channel(inp, (0, 2, 3))
-        sum2 = _sum_channel(inp ** 2, (0, 2, 3))
-        reduce_size = inp.shapeof().prod() / inp.shapeof(1)
-        batch_mean = sum1 / reduce_size
-        batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
-        return batch_mean, batch_var
-
-    def fold_weight_bias(self, bn_mean, bn_var):
-        # get fold bn conv param
-        # bn_istd = 1 / bn_std
-        # w_fold = gamma / bn_std * W
-        # b_fold = gamma * (b - bn_mean) / bn_std + beta
-        gamma = self.bn.weight
-        if gamma is None:
-            gamma = ones((self.bn.num_features), dtype="float32")
-        gamma = gamma.reshape(1, -1, 1, 1)
-        beta = self.bn.bias
-        if beta is None:
-            beta = zeros((self.bn.num_features), dtype="float32")
-        beta = beta.reshape(1, -1, 1, 1)
-
-        if bn_mean is None:
-            bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
-        if bn_var is None:
-            bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")
-
-        conv_bias = self.conv.bias
-        if conv_bias is None:
-            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")
-
-        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
-        # bn_istd = 1 / bn_std
-        # w_fold = gamma / bn_std * W
-        scale_factor = gamma * bn_istd
-        if self.conv.groups == 1:
-            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
-        else:
-            w_fold = self.conv.weight * scale_factor.reshape(
-                self.conv.groups, -1, 1, 1, 1
-            )
-        # b_fold = gamma * (b - bn_mean) / bn_std + beta
-        b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
-        return w_fold, b_fold
-
-    def update_running_mean_and_running_var(
-        self, bn_mean, bn_var, num_elements_per_channel
-    ):
-        # update running mean and running var. no grad, use unbiased bn var
-        bn_mean = zero_grad(bn_mean)
-        bn_var = (
-            zero_grad(bn_var)
-            * num_elements_per_channel
-            / (num_elements_per_channel - 1)
-        )
-        exponential_average_factor = 1 - self.bn.momentum
-        add_update(
-            self.bn.running_mean,
-            delta=bn_mean,
-            alpha=1 - exponential_average_factor,
-            beta=exponential_average_factor,
-        )
-        add_update(
-            self.bn.running_var,
-            delta=bn_var,
-            alpha=1 - exponential_average_factor,
-            beta=exponential_average_factor,
-        )
-
-    def calc_conv_bn_qat(self, inp, approx=True):
-        if self.training and not approx:
-            conv = self.conv(inp)
-            bn_mean, bn_var = self.get_batch_mean_var(conv)
-            num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
-            self.update_running_mean_and_running_var(
-                bn_mean, bn_var, num_elements_per_channel
-            )
-        else:
-            bn_mean, bn_var = self.bn.running_mean, self.bn.running_var
-
-        # get gamma and beta in BatchNorm
-        gamma = self.bn.weight
-        if gamma is None:
-            gamma = ones((self.bn.num_features), dtype="float32")
-        gamma = gamma.reshape(1, -1, 1, 1)
-        beta = self.bn.bias
-        if beta is None:
-            beta = zeros((self.bn.num_features), dtype="float32")
-        beta = beta.reshape(1, -1, 1, 1)
-        # conv_bias
-        conv_bias = self.conv.bias
-        if conv_bias is None:
-            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")
-
-        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
-        # bn_istd = 1 / bn_std
-        # w_fold = gamma / bn_std * W
-        scale_factor = gamma * bn_istd
-        if self.conv.groups == 1:
-            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
-        else:
-            w_fold = self.conv.weight * scale_factor.reshape(
-                self.conv.groups, -1, 1, 1, 1
-            )
-        b_fold = None
-        if not (self.training and approx):
-            # b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta
-            b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
-
-        w_qat = self.apply_fakequant_with_observer(
-            w_fold, self.weight_fake_quant, self.weight_observer
-        )
-        conv = self.conv.calc_conv(inp, w_qat, b_fold)
-        if not (self.training and approx):
-            return conv
-
-        # rescale conv to get original conv output
-        orig_conv = conv / scale_factor.reshape(1, -1, 1, 1)
-        if self.conv.bias is not None:
-            orig_conv = orig_conv + self.conv.bias
-        # calculate batch norm
-        bn_mean, bn_var = self.get_batch_mean_var(orig_conv)
-        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
-        conv = gamma * bn_istd * (orig_conv - bn_mean) + beta
-        num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
-        self.update_running_mean_and_running_var(
-            bn_mean, bn_var, num_elements_per_channel
-        )
-        return conv
-

-class ConvBn2d(_ConvBn2d):
+class ConvBn2d(_ConvBnActivation2d):
     r"""
-    A fused :class:`~.QATModule` including Conv2d and BatchNorm2d, supporting ``qat`` mode
-    and ``normal`` mode.
+    A fused :class:`~.Module` including Conv2d and BatchNorm2d. Could be replaced
+    with the :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBn2d` using
+    :func:`~.quantize.quantize_qat`.
     """

-    def forward_qat(self, inp):
-        return self.apply_fakequant_with_observer(
-            self.calc_conv_bn_qat(inp), self.act_fake_quant, self.act_observer
-        )
-
     def forward(self, inp):
         return self.bn(self.conv(inp))


-class ConvBnRelu2d(_ConvBn2d):
+class ConvBnRelu2d(_ConvBnActivation2d):
     r"""
-    A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu, supporting ``qat``
-    mode and ``normal`` mode.
+    A fused :class:`~.Module` including Conv2d, BatchNorm2d and relu. Could be replaced
+    with the :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBnRelu2d` using
+    :func:`~.quantize.quantize_qat`.
     """

-    def forward_qat(self, inp):
-        return self.apply_fakequant_with_observer(
-            relu(self.calc_conv_bn_qat(inp)), self.act_fake_quant, self.act_observer
-        )
-
     def forward(self, inp):
         return relu(self.bn(self.conv(inp)))
@@ -8,7 +8,7 @@
 from .. import _internal as mgb
 from ..core import Tensor, wrap_io_tensor
 from ..core.graph import _use_default_if_none
-from .module import QATModule
+from .module import Module


 @wrap_io_tensor
@@ -22,10 +22,10 @@ def _elemwise_func(mode, *inputs, **kwargs) -> Tensor:
     return mgb.opr.elemwise(*inputs, mode=mode, **kwargs)


-class Elemwise(QATModule):
+class Elemwise(Module):
     r"""
-    A :class:`~.QATModule` to do elemwise operator, should functional operator with this module,
-    supporting ``qat`` mode and ``normal`` mode.
+    A :class:`~.Module` to do an elemwise operator. Could be replaced with the
+    :class:`~.QATModule` version :class:`~.qat.elemwise.Elemwise` using
+    :func:`~.quantize.quantize_qat`.

     :param method: the elemwise method, support the following string.
         It will do the normal elemwise operator for float.
@@ -88,8 +88,3 @@ class Elemwise(QATModule):

     def forward(self, *inps):
         return _elemwise_func(self.method, *inps)
-
-    def forward_qat(self, *inps):
-        return self.apply_fakequant_with_observer(
-            self.forward(*inps), self.act_fake_quant, self.act_observer,
-        )
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 #
 # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -11,10 +10,10 @@ import numpy as np
 from .. import functional as F
 from ..core import Parameter
 from . import init
-from .module import QATModule
+from .module import Module


-class Linear(QATModule):
+class Linear(Module):
     r"""Applies a linear transformation to the input. For instance, if input
     is x, then output y is:
@@ -60,13 +59,3 @@ class Linear(QATModule):

     def forward(self, x):
         return self._calc_linear(x, self.weight, self.bias)
-
-    def forward_qat(self, x):
-        w_qat = self.apply_fakequant_with_observer(
-            self.weight, self.weight_fake_quant, self.weight_observer
-        )
-        return self.apply_fakequant_with_observer(
-            self._calc_linear(x, w_qat, self.bias),
-            self.act_fake_quant,
-            self.act_observer,
-        )
@@ -7,7 +7,6 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from abc import ABCMeta, abstractmethod
 from collections import OrderedDict
-from enum import Enum
 from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union

 import numpy as np
@@ -443,98 +442,3 @@ class Module(metaclass=ABCMeta):
             loaded.append(k)

         return set(loaded), set(skipped)
-
-
-class QATModule(Module):
-    r"""
-    Base class of quantization related Module. Add extra forward methods
-    :meth:`~.QATModule.forward_qat` and :meth:`~.QATModule.forward_quantized` for
-    ``qat`` (quantization aware training) mode and ``quantized`` mode respectively.
-
-    Use :meth:`~.QATModule.quant` to switch between ``QAT`` and ``NORMAL`` mode,
-    and use :meth:`~.QATModule.to_quantized` to switch to ``quantized`` mode,
-    which is irreversible.
-
-    If you want to recursively switch mode for all QATModule in network, use
-    functions in :mod:`~.quantization.quantize`.
-    """
-
-    class QATMode(Enum):
-        DISABLED = 1
-        QAT = 2
-        CALIBRATION = 3
-
-    def __init__(self):
-        from ..quantization import (
-            QConfig,
-            FakeQuantize,
-            Observer,
-        )  # pylint: disable=all
-
-        super().__init__()
-        self.quantizing = self.QATMode.DISABLED
-        self.scale = None
-
-        self.weight_observer = None  # type: Observer
-        self.act_observer = None  # type: Observer
-
-        self.weight_fake_quant = None  # type: FakeQuantize
-        self.act_fake_quant = None  # type: FakeQuantize
-
-    def set_qconfig(self, qconfig: "QConfig"):
-        self.weight_observer = qconfig.weight_observer()
-        self.act_observer = qconfig.act_observer()
-
-        self.weight_fake_quant = (
-            None
-            if qconfig.fake_quant is None
-            else qconfig.fake_quant(self.weight_observer.dtype)
-        )
-        self.act_fake_quant = (
-            None
-            if qconfig.fake_quant is None
-            else qconfig.fake_quant(self.act_observer.dtype)
-        )
-
-    def apply_observer(self, target: Tensor, obs: "Observer"):
-        return obs(target)
-
-    def apply_fakequant_with_observer(
-        self, target: Tensor, fq: "FakeQuantize", obs: "Observer"
-    ):
-        oup = self.apply_observer(target, obs)
-        if fq is not None:
-            q_dict = obs.get_qparams()
-            oup = fq(oup, q_dict)
-        return oup
-
-    def set_qat_mode(self, mode: QATMode):
-        r"""
-        Change ``self.quantizing`` mode, available values: ``self.QATMode.DISABLED``,
-        ``QAT``, ``CALIBRATION``.
-        """
-        if not isinstance(mode, self.QATMode):
-            raise TypeError("mode must be QATMode Enum type")
-        self.quantizing = mode
-
-    def to_quantized(self):
-        r"""
-        Return a new :class:`~.Module` with quantized parameters of ``self``
-        according to scale and zero_point in ``self.xxx_observer``.
-        """
-        raise NotImplementedError(
-            "Use megengine.quantization.quantize to register the method."
-        )
-
-    @abstractmethod
-    def forward_qat(self, *args, **kwargs):
-        r"""
-        Forward method for ``qat`` mode.
-        """
-
-    def __call__(self, *args, **kwargs):
-        if self.quantizing == self.QATMode.DISABLED:
-            return self.forward(*args, **kwargs)
-        else:
-            return self.forward_qat(*args, **kwargs)
@@ -0,0 +1,13 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from .concat import Concat
+from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
+from .elemwise import Elemwise
+from .linear import Linear
+from .module import QATModule
+from .quant_dequant import DequantStub, QuantStub
@@ -0,0 +1,30 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from typing import Iterable
+
+from ...core.tensor import Tensor
+from .. import concat as Float
+from .module import QATModule
+
+
+class Concat(Float.Concat, QATModule):
+    r"""
+    A :class:`~.QATModule` to do functional concat with QAT support.
+    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
+    """
+
+    def forward(self, inps: Iterable[Tensor], axis: int = 0):
+        return self.apply_quant_activation(super().forward(inps, axis))
+
+    @classmethod
+    def from_float_module(cls, float_module):
+        r"""
+        Return a :class:`~.QATModule` instance converted from
+        a float :class:`~.Module` instance.
+        """
+        return cls()
@@ -0,0 +1,193 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from ...core import ones, zeros
+from ...functional import add_update, relu, sqrt, sum, zero_grad
+from .. import conv_bn_relu as Float
+from .module import QATModule
+
+
+class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
+    def get_batch_mean_var(self, inp):
+        def _sum_channel(inp, axis=0, keepdims=True):
+            if isinstance(axis, int):
+                out = sum(inp, axis=axis, keepdims=keepdims)
+            elif isinstance(axis, tuple):
+                for idx, elem in enumerate(axis):
+                    out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
+            return out
+
+        sum1 = _sum_channel(inp, (0, 2, 3))
+        sum2 = _sum_channel(inp ** 2, (0, 2, 3))
+        reduce_size = inp.shapeof().prod() / inp.shapeof(1)
+        batch_mean = sum1 / reduce_size
+        batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size
+        return batch_mean, batch_var
+
+    def fold_weight_bias(self, bn_mean, bn_var):
+        # get fold bn conv param
+        # bn_istd = 1 / bn_std
+        # w_fold = gamma / bn_std * W
+        # b_fold = gamma * (b - bn_mean) / bn_std + beta
+        gamma = self.bn.weight
+        if gamma is None:
+            gamma = ones((self.bn.num_features), dtype="float32")
+        gamma = gamma.reshape(1, -1, 1, 1)
+        beta = self.bn.bias
+        if beta is None:
+            beta = zeros((self.bn.num_features), dtype="float32")
+        beta = beta.reshape(1, -1, 1, 1)
+
+        if bn_mean is None:
+            bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
+        if bn_var is None:
+            bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")
+
+        conv_bias = self.conv.bias
+        if conv_bias is None:
+            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")
+
+        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
+        # bn_istd = 1 / bn_std
+        # w_fold = gamma / bn_std * W
+        scale_factor = gamma * bn_istd
+        if self.conv.groups == 1:
+            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
+        else:
+            w_fold = self.conv.weight * scale_factor.reshape(
+                self.conv.groups, -1, 1, 1, 1
+            )
+        # b_fold = gamma * (b - bn_mean) / bn_std + beta
+        b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
+        return w_fold, b_fold
+
+    def update_running_mean_and_running_var(
+        self, bn_mean, bn_var, num_elements_per_channel
+    ):
+        # update running mean and running var. no grad, use unbiased bn var
+        bn_mean = zero_grad(bn_mean)
+        bn_var = (
+            zero_grad(bn_var)
+            * num_elements_per_channel
+            / (num_elements_per_channel - 1)
+        )
+        exponential_average_factor = 1 - self.bn.momentum
+        add_update(
+            self.bn.running_mean,
+            delta=bn_mean,
+            alpha=1 - exponential_average_factor,
+            beta=exponential_average_factor,
+        )
+        add_update(
+            self.bn.running_var,
+            delta=bn_var,
+            alpha=1 - exponential_average_factor,
+            beta=exponential_average_factor,
+        )
+
+    def calc_conv_bn_qat(self, inp, approx=True):
+        if self.training and not approx:
+            conv = self.conv(inp)
+            bn_mean, bn_var = self.get_batch_mean_var(conv)
+            num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
+            self.update_running_mean_and_running_var(
+                bn_mean, bn_var, num_elements_per_channel
+            )
+        else:
+            bn_mean, bn_var = self.bn.running_mean, self.bn.running_var
+
+        # get gamma and beta in BatchNorm
+        gamma = self.bn.weight
+        if gamma is None:
+            gamma = ones((self.bn.num_features), dtype="float32")
+        gamma = gamma.reshape(1, -1, 1, 1)
+        beta = self.bn.bias
+        if beta is None:
+            beta = zeros((self.bn.num_features), dtype="float32")
+        beta = beta.reshape(1, -1, 1, 1)
+        # conv_bias
+        conv_bias = self.conv.bias
+        if conv_bias is None:
+            conv_bias = zeros(self.conv._infer_bias_shape(), dtype="float32")
+
+        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
+        # bn_istd = 1 / bn_std
+        # w_fold = gamma / bn_std * W
+        scale_factor = gamma * bn_istd
+        if self.conv.groups == 1:
+            w_fold = self.conv.weight * scale_factor.reshape(-1, 1, 1, 1)
+        else:
+            w_fold = self.conv.weight * scale_factor.reshape(
+                self.conv.groups, -1, 1, 1, 1
+            )
+        b_fold = None
+        if not (self.training and approx):
+            # b_fold = gamma * (conv_bias - bn_mean) / bn_std + beta
+            b_fold = beta + gamma * (conv_bias - bn_mean) * bn_istd
+
+        w_qat = self.apply_quant_weight(w_fold)
+        conv = self.conv.calc_conv(inp, w_qat, b_fold)
+        if not (self.training and approx):
+            return conv
+
+        # rescale conv to get original conv output
+        orig_conv = conv / scale_factor.reshape(1, -1, 1, 1)
+        if self.conv.bias is not None:
+            orig_conv = orig_conv + self.conv.bias
+        # calculate batch norm
+        bn_mean, bn_var = self.get_batch_mean_var(orig_conv)
+        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
+        conv = gamma * bn_istd * (orig_conv - bn_mean) + beta
+        num_elements_per_channel = conv.shapeof().prod() / conv.shapeof(1)
+        self.update_running_mean_and_running_var(
+            bn_mean, bn_var, num_elements_per_channel
+        )
+        return conv
+
+    @classmethod
+    def from_float_module(cls, float_module: Float._ConvBnActivation2d):
+        r"""
+        Return a :class:`~.QATModule` instance converted from
+        a float :class:`~.Module` instance.
+        """
+        qat_module = cls(
+            float_module.conv.in_channels,
+            float_module.conv.out_channels,
+            float_module.conv.kernel_size,
+            float_module.conv.stride,
+            float_module.conv.padding,
+            float_module.conv.dilation,
+            float_module.conv.groups,
+            bool(float_module.conv.bias),
+            float_module.conv.conv_mode.name,
+            float_module.conv.compute_mode.name,
+        )
+        qat_module.conv.weight = float_module.conv.weight
+        qat_module.conv.bias = float_module.conv.bias
+        qat_module.bn = float_module.bn
+        return qat_module
+
+
+class ConvBn2d(_ConvBnActivation2d):
+    r"""
+    A fused :class:`~.QATModule` including Conv2d and BatchNorm2d with QAT support.
+    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
+    """
+
+    def forward(self, inp):
+        return self.apply_quant_activation(self.calc_conv_bn_qat(inp))
+
+
+class ConvBnRelu2d(_ConvBnActivation2d):
+    r"""
+    A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu with QAT support.
+    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
+    """
+
+    def forward(self, inp):
+        return self.apply_quant_activation(relu(self.calc_conv_bn_qat(inp)))
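The folding identity implemented by `fold_weight_bias` can be checked numerically. Below is a small self-contained NumPy sketch using a 1x1 convolution written as an einsum (shapes and names are illustrative, not MegEngine code):

```python
import numpy as np

rng = np.random.default_rng(0)
N, Ci, Co, H, W, eps = 2, 3, 4, 5, 5, 1e-5

x = rng.normal(size=(N, Ci, H, W))
w = rng.normal(size=(Co, Ci))                  # 1x1 conv kernel
b = rng.normal(size=Co)
gamma, beta = rng.normal(size=Co), rng.normal(size=Co)
mean, var = rng.normal(size=Co), rng.uniform(0.5, 2.0, size=Co)

def conv1x1(x, w, b):
    return np.einsum("nchw,oc->nohw", x, w) + b[None, :, None, None]

# Reference path: conv followed by batch norm with the given statistics.
y = conv1x1(x, w, b)
istd = 1.0 / np.sqrt(var + eps)
bn = lambda t: (gamma * (t.transpose(0, 2, 3, 1) - mean) * istd + beta).transpose(0, 3, 1, 2)

# Folded path: w_fold = (gamma * istd) * w, b_fold = beta + gamma * (b - mean) * istd.
w_fold = w * (gamma * istd)[:, None]
b_fold = beta + gamma * (b - mean) * istd

assert np.allclose(conv1x1(x, w_fold, b_fold), bn(y))  # the two paths agree
```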
@@ -0,0 +1,29 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from .. import elemwise as Float
+from .module import QATModule
+
+
+class Elemwise(Float.Elemwise, QATModule):
+    r"""
+    A :class:`~.QATModule` to do an elemwise operator with QAT support.
+    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
+
+    :param method: the elemwise method, see :class:`~.module.elemwise.Elemwise` for detail.
+    """
+
+    def forward(self, *inps):
+        return self.apply_quant_activation(super().forward(*inps))
+
+    @classmethod
+    def from_float_module(cls, float_module: Float.Elemwise):
+        r"""
+        Return a :class:`~.QATModule` instance converted from
+        a float :class:`~.Module` instance.
+        """
+        return cls(float_module.method.name)
@@ -0,0 +1,37 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from .. import linear as Float
+from .module import QATModule
+
+
+class Linear(Float.Linear, QATModule):
+    r"""
+    A :class:`~.QATModule` version of :class:`~.module.linear.Linear`.
+    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
+
+    :param in_features: size of each input sample.
+    :param out_features: size of each output sample.
+    :param bias: if set to ``False``, the layer will not learn an additive bias.
+        Default: ``True``
+    """
+
+    def forward(self, x):
+        w_qat = self.apply_quant_weight(self.weight)
+        return self.apply_quant_activation(self._calc_linear(x, w_qat, self.bias))
+
+    @classmethod
+    def from_float_module(cls, float_module: Float.Linear):
+        r"""
+        Return a :class:`~.QATModule` instance converted from
+        a float :class:`~.Module` instance.
+        """
+        qmod = cls(float_module.in_features, float_module.out_features)
+        qmod.weight = float_module.weight
+        qmod.bias = float_module.bias
+        return qmod
@@ -0,0 +1,96 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from abc import abstractmethod
+
+from ...core import Tensor
+from ...quantization import FakeQuantize, Observer, QConfig
+from ..module import Module
+
+
+class QATModule(Module):
+    r"""
+    Base class of quantized-float related Module, basically for QAT and Calibration.
+
+    Use :meth:`~.QATModule.from_float_module` to generate an instance from a float
+    :class:`~.Module`, or use :func:`~.quantize.quantize_qat` to do it recursively
+    and automatically. It can be further converted to :class:`~.QuantizedModule`
+    for deployment using :func:`~.quantize.quantize`.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+        self.scale = None
+
+        self.weight_observer = None  # type: Observer
+        self.act_observer = None  # type: Observer
+
+        self.weight_fake_quant = None  # type: FakeQuantize
+        self.act_fake_quant = None  # type: FakeQuantize
+
+    def set_qconfig(self, qconfig: QConfig):
+        r"""
+        Set quantization related configs with ``qconfig``, including
+        observer and fake_quant for weight and activation.
+        """
+        self.weight_observer = qconfig.weight_observer()
+        self.act_observer = qconfig.act_observer()
+
+        if qconfig.fake_quant is None:
+            self.weight_fake_quant = None
+            self.act_fake_quant = None
+        else:
+            self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype)
+            self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype)
+
+    def _apply_fakequant_with_observer(
+        self, target: Tensor, fake_quant: FakeQuantize, observer: Observer
+    ):
+        oup = observer(target)
+        if fake_quant is None:
+            return oup
+        else:
+            q_dict = observer.get_qparams()
+            return fake_quant(oup, q_dict)
+
+    def apply_quant_weight(self, target: Tensor):
+        r"""
+        Apply weight's observer and fake_quant from ``qconfig`` on ``target``.
+        """
+        return self._apply_fakequant_with_observer(
+            target, self.weight_fake_quant, self.weight_observer
+        )
+
+    def apply_quant_activation(self, target: Tensor):
+        r"""
+        Apply activation's observer and fake_quant from ``qconfig`` on ``target``.
+        """
+        return self._apply_fakequant_with_observer(
+            target, self.act_fake_quant, self.act_observer
+        )
+
+    def get_weight_dtype(self):
+        r"""
+        Get weight's quantization dtype as configured by ``qconfig``.
+        """
+        return self.weight_observer.get_dtype()
+
+    def get_activation_dtype(self):
+        r"""
+        Get activation's quantization dtype as configured by ``qconfig``.
+        """
+        return self.act_observer.get_dtype()
+
+    @classmethod
+    @abstractmethod
+    def from_float_module(cls, float_module: Module):
+        r"""
+        Return a :class:`~.QATModule` instance converted from
+        a float :class:`~.Module` instance.
+        """
@@ -0,0 +1,45 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from .. import quant_dequant as Float
+from .module import QATModule
+
+
+class QuantStub(Float.QuantStub, QATModule):
+    r"""
+    A helper :class:`~.QATModule` that simply returns its input, but will quantize
+    the input after being converted to :class:`~.QuantizedModule`.
+    """
+
+    def forward(self, inp):
+        return self.apply_quant_activation(inp)
+
+    @classmethod
+    def from_float_module(cls, float_module: Float.QuantStub):
+        r"""
+        Return a :class:`~.QATModule` instance converted from
+        a float :class:`~.Module` instance.
+        """
+        return cls()
+
+
+class DequantStub(Float.DequantStub, QATModule):
+    r"""
+    A helper :class:`~.QATModule` that simply returns its input, but will de-quantize
+    the input after being converted to :class:`~.QuantizedModule`.
+    """
+
+    def forward(self, inp):
+        return inp
+
+    @classmethod
+    def from_float_module(cls, float_module: Float.DequantStub):
+        r"""
+        Return a :class:`~.QATModule` instance converted from
+        a float :class:`~.Module` instance.
+        """
+        return cls()
@@ -5,30 +5,24 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from .module import QATModule
+from .module import Module


-class QuantStub(QATModule):
+class QuantStub(Module):
     r"""
-    A helper QATModule doing quantize operation on input.
+    A helper :class:`~.Module` simply returning input. Could be replaced with the
+    :class:`~.QATModule` version :class:`~.qat.QuantStub` using :func:`~.quantize.quantize_qat`.
     """

     def forward(self, inp):
         return inp

-    def forward_qat(self, inp):
-        return self.apply_fakequant_with_observer(
-            inp, self.act_fake_quant, self.act_observer
-        )
-

-class DequantStub(QATModule):
+class DequantStub(Module):
     r"""
-    A helper QATModule doing de-quantize operation on input.
+    A helper :class:`~.Module` simply returning input. Could be replaced with the
+    :class:`~.QATModule` version :class:`~.qat.DequantStub` using :func:`~.quantize.quantize_qat`.
     """

     def forward(self, inp):
         return inp

-    def forward_qat(self, inp):
-        return inp
@@ -9,4 +9,5 @@ from .concat import Concat
 from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
 from .elemwise import Elemwise
 from .linear import Linear
+from .module import QuantizedModule
 from .quant_dequant import DequantStub, QuantStub
@@ -7,17 +7,15 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from typing import Iterable

-from ... import _internal as mgb
 from ... import functional as F
-from ... import module as Float
 from ...core.tensor import Tensor
-from ...quantization.utils import register_method_to_class
-from ..module import Module
+from ..qat import concat as QAT
+from .module import QuantizedModule


-class Concat(Module):
+class Concat(QuantizedModule):
     r"""
-    A :class:`~.Module` to do quantized concat, inference only.
+    A :class:`~.QuantizedModule` to do quantized concat, inference only.
     """

     def __init__(self, dtype=None):
@@ -25,16 +23,13 @@ class Concat(Module):
         self.output_dtype = dtype

     def forward(self, inps: Iterable[Tensor], axis: int = 0):
-        if self.training:
-            raise ValueError("quantized module only support inference.")
         new_inps = (x.astype(self.output_dtype) for x in inps)
         return F.concat(new_inps, axis)

-
-@register_method_to_class(Float.Concat)
-def to_quantized(float_module):
-    r"""
-    Replace :class:`~.module.QATModule`'s ``to_quantized`` method,
-    implemented here to avoid circular import.
-    """
-    return Concat(float_module.act_observer.get_dtype())
+    @classmethod
+    def from_qat_module(cls, qat_module: QAT.Concat):
+        r"""
+        Return a :class:`~.QuantizedModule` instance converted from a
+        :class:`~.QATModule` instance.
+        """
+        return cls(qat_module.get_activation_dtype())
@@ -5,7 +5,6 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from functools import partial
 from typing import Tuple, Union

 import megengine._internal as mgb
@@ -13,11 +12,11 @@ import megengine._internal as mgb
 from ... import module as Float
 from ...core import Parameter
 from ...functional import conv_bias_activation
-from ...module import Conv2d
-from ...quantization.utils import register_method_to_class
+from ..qat import conv_bn_relu as QAT
+from .module import QuantizedModule


-class _ConvBnActivation2d(Conv2d):
+class _ConvBnActivation2d(Float.Conv2d, QuantizedModule):
     r"""Applies a 2D convolution over a quantized input tensor, inference only.

     The parameters are the same as :class:`~.Conv2d`.
@@ -68,44 +67,41 @@ class _ConvBnActivation2d(Conv2d):
             nonlinear_mode=nonlinear_mode,
         )

+    @classmethod
+    def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d):
+        r"""
+        Return a :class:`~.QuantizedModule` instance converted from a
+        :class:`~.QATModule` instance.
+        """
+        output_dtype = qat_module.get_activation_dtype()
+        qconv = cls(
+            qat_module.conv.in_channels,
+            qat_module.conv.out_channels,
+            qat_module.conv.kernel_size,
+            qat_module.conv.stride,
+            qat_module.conv.padding,
+            qat_module.conv.dilation,
+            qat_module.conv.groups,
+            dtype=output_dtype,
+        )
+        w_fold, b_fold = qat_module.fold_weight_bias(
+            qat_module.bn.running_mean, qat_module.bn.running_var
+        )
+        weight = w_fold.astype(qat_module.get_weight_dtype())
+        qconv.weight = Parameter(weight.numpy())
+        qconv.bias = Parameter(b_fold.numpy())
+        return qconv
+

 class ConvBn2d(_ConvBnActivation2d):
+    r"""Quantized version of :class:`~.qat.conv_bn_relu.ConvBn2d`."""
+
     def forward(self, inp):
-        if self.training:
-            raise ValueError("quantized module only support inference.")
         return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")


 class ConvBnRelu2d(_ConvBnActivation2d):
+    r"""Quantized version of :class:`~.qat.conv_bn_relu.ConvBnRelu2d`."""
+
     def forward(self, inp):
-        if self.training:
-            raise ValueError("quantized module only support inference.")
         return self.calc_conv_quantized(inp, nonlinear_mode="RELU")
-
-
-def to_quantized(quantized_class, float_module):
-    output_dtype = float_module.act_observer.get_dtype()
-    qconv = quantized_class(
-        float_module.conv.in_channels,
-        float_module.conv.out_channels,
-        float_module.conv.kernel_size,
-        float_module.conv.stride,
-        float_module.conv.padding,
-        float_module.conv.dilation,
-        float_module.conv.groups,
-        dtype=output_dtype,
-    )
-    w_fold, b_fold = float_module.fold_weight_bias(
-        float_module.bn.running_mean, float_module.bn.running_var
-    )
-    weight = w_fold.astype(float_module.weight_observer.get_dtype())
-    qconv.weight = Parameter(weight.numpy())
-    qconv.bias = Parameter(b_fold.numpy())
-    return qconv
-
-
-# replace :class:`~.module.QATModule`'s ``to_quantized`` method.
-# implemented here to avoid circular import.
-register_method_to_class(Float.ConvBn2d)(partial(to_quantized, ConvBn2d))
-register_method_to_class(Float.ConvBnRelu2d)(partial(to_quantized, ConvBnRelu2d))
@@ -6,11 +6,10 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ... import _internal as mgb
-from ... import module as Float
 from ...core import Tensor, wrap_io_tensor
 from ...core.graph import _use_default_if_none
-from ...quantization.utils import register_method_to_class
-from ..module import Module
+from ..qat import elemwise as QAT
+from .module import QuantizedModule


 @wrap_io_tensor
@@ -24,13 +23,8 @@ def _elemwise_multi_type(mode, *inputs, **kwargs) -> Tensor:
     return mgb.opr.elemwise_multi_type(*inputs, mode=mode, **kwargs)


-class Elemwise(Module):
-    r"""
-    quantized module for elemwise operator, inference only.
-
-    :param method: the elemwise method, supported string refer to :class:`~.module.elemwise.Elemwise`.
-        it will do quantized operator with specified output quantized dtype.
-    """
+class Elemwise(QuantizedModule):
+    r"""Quantized version of :class:`~.qat.elemwise.Elemwise`."""

     _elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode

@@ -44,11 +38,10 @@ class Elemwise(Module):
             raise ValueError("quantized module only support inference.")
         return _elemwise_multi_type(self.method, *inps, dtype=self.output_dtype)

-
-@register_method_to_class(Float.Elemwise)
-def to_quantized(float_module):
-    r"""
-    Replace :class:`~.module.QATModule`'s ``to_quantized`` method,
-    implemented here to avoid circular import.
-    """
-    return Elemwise(float_module.method.name, float_module.act_observer.get_dtype())
+    @classmethod
+    def from_qat_module(cls, qat_module: QAT.Elemwise):
+        r"""
+        Return a :class:`~.QuantizedModule` instance converted from a
+        :class:`~.QATModule` instance.
+        """
+        return cls(qat_module.method.name, qat_module.get_activation_dtype())
@@ -10,19 +10,13 @@ import numpy as np
 import megengine._internal as mgb

 from ... import functional as F
-from ... import module as Float
 from ...core import Parameter
-from ...quantization.utils import register_method_to_class
-from ..module import Module
+from ..qat import linear as QAT
+from .module import QuantizedModule


-class Linear(Module):
-    r"""Applies a quantized linear transformation to the input. The module
-    usually convert from QAT module by to_quantized method.
-
-    :param dtype: output data type.
-    """
+class Linear(QuantizedModule):
+    r"""Quantized version of :class:`~.qat.linear.Linear`."""

     def __init__(
         self, dtype: np.dtype = None,
@@ -44,17 +38,16 @@ class Linear(Module):
             None if self.bias is None else self.bias.astype(bias_dtype),
         ).astype(self.output_dtype)

-
-@register_method_to_class(Float.Linear)
-def to_quantized(float_module):
-    r"""
-    Replace :class:`~.module.QATModule`'s ``to_quantized`` method,
-    implemented here to avoid circular import.
-    """
-    output_dtype = float_module.act_observer.get_dtype()
-    qmod = Linear(dtype=output_dtype,)
-    weight = float_module.weight.astype(float_module.weight_observer.get_dtype())
-    qmod.weight = Parameter(weight.numpy())
-    if float_module.bias is not None:
-        qmod.bias = Parameter(float_module.bias.numpy())
-    return qmod
+    @classmethod
+    def from_qat_module(cls, qat_module: QAT.Linear):
+        r"""
+        Return a :class:`~.QuantizedModule` instance converted from a
+        :class:`~.QATModule` instance.
+        """
+        output_dtype = qat_module.get_activation_dtype()
+        qmod = cls(dtype=output_dtype)
+        weight = qat_module.weight.astype(qat_module.get_weight_dtype())
+        qmod.weight = Parameter(weight.numpy())
+        if qat_module.bias is not None:
+            qmod.bias = Parameter(qat_module.bias.numpy())
+        return qmod
@@ -0,0 +1,31 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from abc import abstractmethod
+
+from ..module import Module
+from ..qat import QATModule
+
+
+class QuantizedModule(Module):
+    r"""
+    Base class of quantized Module, which should be converted from a QATModule
+    and does not support training.
+    """
+
+    def __call__(self, *inputs, **kwargs):
+        if self.training:
+            raise ValueError("quantized module only support inference.")
+        return super().__call__(*inputs, **kwargs)
+
+    @classmethod
+    @abstractmethod
+    def from_qat_module(cls, qat_module: QATModule):
+        r"""
+        Return a :class:`~.QuantizedModule` instance converted from a
+        :class:`~.QATModule` instance.
+        """
@@ -5,15 +5,14 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-from ... import _internal as mgb
-from ... import module as Float
-from ...quantization.utils import register_method_to_class
-from ..module import Module
+from ..qat import quant_dequant as QAT
+from .module import QuantizedModule


-class QuantStub(Module):
+class QuantStub(QuantizedModule):
     r"""
-    A helper quantize operation on input and inference only.
+    Quantized version of :class:`~.qat.quant_dequant.QuantStub`,
+    will convert input to quantized dtype.
     """

     def __init__(self, dtype=None):
@@ -21,35 +20,30 @@ class QuantStub(Module):
         self.output_dtype = dtype

     def forward(self, inp):
-        if self.training:
-            raise ValueError("quantized module only support inference.")
         return inp.astype(self.output_dtype)

+    @classmethod
+    def from_qat_module(cls, qat_module: QAT.QuantStub):
+        r"""
+        Return a :class:`~.QuantizedModule` instance converted from a
+        :class:`~.QATModule` instance.
+        """
+        return cls(qat_module.get_activation_dtype())

-class DequantStub(Module):
+
+class DequantStub(QuantizedModule):
     r"""
-    A helper de-quantize operation and inference only.
+    Quantized version of :class:`~.qat.quant_dequant.DequantStub`,
+    will restore quantized input to float32 dtype.
     """

     def forward(self, inp):
-        if self.training:
-            raise ValueError("quantized module only support inference.")
         return inp.astype("float32")

-
-@register_method_to_class(Float.QuantStub)
-def to_quantized(float_module):
-    r"""
-    Replace :class:`~.module.QATModule`'s ``to_quantized`` method,
-    implemented here to avoid circular import.
-    """
-    return QuantStub(float_module.act_observer.get_dtype())
-
-
-@register_method_to_class(Float.DequantStub)
-def to_quantized(float_module):
-    r"""
-    Replace :class:`~.module.QATModule`'s ``to_quantized`` method,
-    implemented here to avoid circular import.
-    """
-    return DequantStub()
+    @classmethod
+    def from_qat_module(cls, qat_module: QAT.DequantStub):
+        r"""
+        Return a :class:`~.QuantizedModule` instance converted from a
+        :class:`~.QATModule` instance.
+        """
+        return cls()
@@ -13,12 +13,3 @@ from .qconfig import (
     ema_fakequant_qconfig,
     min_max_fakequant_qconfig,
 )
-from .quantize import (
-    disable_fake_quant,
-    disable_observer,
-    enable_fake_quant,
-    enable_observer,
-    quantize,
-    quantize_calibration,
-    quantize_qat,
-)
@@ -15,16 +15,12 @@ from .observer import (

 class QConfig:
-    """
+    r"""
     A config class indicating how to do quantize toward :class:`~.QATModule`'s
-    ``activation`` and ``weight``.
-    And ``fake_quant`` parameter to indicate
-    See :meth:`~.QATModule.set_qconfig` for detail usage.
+    ``activation`` and ``weight``. See :meth:`~.QATModule.set_qconfig` for detail usage.

     :param weight_observer: interface to instantiate an :class:`~.Observer` indicating
-        - how to collect scales and zero_point of wegiht.
+        how to collect scales and zero_point of weight.
     :param act_observer: similar to ``weight_observer`` but toward activation.
     :param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
         how to do fake_quant calculation. can be invoked multi times to get different
@@ -6,68 +6,125 @@ | |||||
# software distributed under the License is distributed on an | # software distributed under the License is distributed on an | ||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
from copy import deepcopy | from copy import deepcopy | ||||
from ..module import Module, QATModule, Sequential, quantized | |||||
from typing import Dict, Tuple | |||||
from .. import module as Float | |||||
from ..module import Module | |||||
from ..module import qat as QAT | |||||
from ..module import quantized as Quantized | |||||
from ..module.qat import QATModule | |||||
from ..module.quantized import QuantizedModule | |||||
from .qconfig import QConfig, ema_fakequant_qconfig | from .qconfig import QConfig, ema_fakequant_qconfig | ||||
def _get_quantable_module_names(): | |||||
def is_quantable(key: str): | |||||
value = getattr(Quantized, key) | |||||
return ( | |||||
isinstance(value, type) | |||||
and issubclass(value, QuantizedModule) | |||||
and value != QuantizedModule | |||||
) | |||||
# source should have all quantable modules' names | |||||
quantable_module_names = [key for key in dir(Quantized) if is_quantable(key)] | |||||
return quantable_module_names | |||||
def _get_convert_dict() -> Tuple[ | |||||
Dict[Module, QATModule], Dict[QATModule, QuantizedModule] | |||||
]: | |||||
quantable_module_names = _get_quantable_module_names() | |||||
quantable_modules = [getattr(Float, key) for key in quantable_module_names] | |||||
qat_modules = [getattr(QAT, key) for key in quantable_module_names] | |||||
quantized_modules = [getattr(Quantized, key) for key in quantable_module_names] | |||||
float2qat_dict = dict(zip(quantable_modules, qat_modules)) | |||||
qat2quantized_dict = dict(zip(qat_modules, quantized_modules)) | |||||
return float2qat_dict, qat2quantized_dict | |||||
_float2qat_dict, _qat2quantized_dict = _get_convert_dict() | |||||
def quantize(module: Module, inplace=True): | def quantize(module: Module, inplace=True): | ||||
r""" | r""" | ||||
Recursively convert `module` to `quantized` mode through :meth:`~.Module.apply`. | |||||
Recursively convert :class:`~.QATModule` to :class:`~.QuantizedModule` | |||||
through :meth:`~.Module.apply`. | |||||
:param module: root module to do convert recursively. | :param module: root module to do convert recursively. | ||||
:param inplace: whether to convert submodules in-place. | |||||
""" | """ | ||||
if not inplace: | if not inplace: | ||||
module = deepcopy(module) | module = deepcopy(module) | ||||
def is_qat_module(obj): | |||||
return isinstance(obj, QATModule) | |||||
qat_modules = tuple(_qat2quantized_dict.keys()) | |||||
def is_qat(mod: Module): | |||||
return isinstance(mod, qat_modules) | |||||
# no need to pass prefix and get pure key of parent Module. | # no need to pass prefix and get pure key of parent Module. | ||||
for key, submodule, parent in module._flatten( | for key, submodule, parent in module._flatten( | ||||
with_key=True, with_parent=True, predicate=is_qat_module | |||||
with_key=True, with_parent=True, predicate=is_qat | |||||
): | ): | ||||
if isinstance(parent, Sequential): | |||||
new_mod = _qat2quantized_dict[type(submodule)].from_qat_module(submodule) | |||||
if isinstance(parent, Float.Sequential): | |||||
# cannot use setattr to be compatible with Sequential's ``__setitem__`` | # cannot use setattr to be compatible with Sequential's ``__setitem__`` | ||||
parent[int(key.split(".")[-1])] = submodule.to_quantized() | |||||
parent[int(key.split(".")[-1])] = new_mod | |||||
else: | else: | ||||
setattr(parent, key.split(".")[-1], submodule.to_quantized()) | |||||
setattr(parent, key.split(".")[-1], new_mod) | |||||
return module | return module | ||||
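Typical use, sketched under the assumption that ``net`` has already been converted to QAT modules (e.g. by ``quantize_qat``):

q_net = quantize(net, inplace=False)  # keep the QAT net, get a quantized copy
quantize(net)                         # or convert in place (the default)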
def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig): | |||||
def quantize_qat( | |||||
module: Module, inplace=True, qconfig: QConfig = ema_fakequant_qconfig | |||||
): | |||||
r""" | r""" | ||||
Recursively convert `module` to `qat` mode through :meth:`~.Module.apply` | |||||
and set qconfig accordingly. | |||||
Recursively convert float :class:`~.Module` to :class:`~.QATModule` | |||||
through :meth:`~.Module.apply` and set qconfig accordingly. | |||||
:param module: root module to do convert recursively. | :param module: root module to do convert recursively. | ||||
:param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig. | |||||
default is :any:`~.qconfig.ema_fakequant_qconfig`. | |||||
:param inplace: whether to convert submodules in-place. | |||||
:param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig. | |||||
default is ``ema_fakequant_qconfig``. | |||||
""" | """ | ||||
def fn(mod: Module): | |||||
if isinstance(mod, QATModule): | |||||
mod.set_qat_mode(QATModule.QATMode.QAT) | |||||
mod.set_qconfig(qconfig) | |||||
if not inplace: | |||||
module = deepcopy(module) | |||||
module.apply(fn) | |||||
quantable_modules = tuple(_float2qat_dict.keys()) | |||||
def is_quantable(mod: Module): | |||||
return isinstance(mod, quantable_modules) | |||||
# no need to pass prefix and get pure key of parent Module. | |||||
for key, submodule, parent in module._flatten( | |||||
with_key=True, with_parent=True, predicate=is_quantable | |||||
): | |||||
new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule) | |||||
if isinstance(parent, Float.Sequential): | |||||
# cannot use setattr to be compatible with Sequential's ``__setitem__`` | |||||
parent[int(key.split(".")[-1])] = new_mod | |||||
else: | |||||
setattr(parent, key.split(".")[-1], new_mod) | |||||
propagate_qconfig(module, qconfig) | |||||
return module | |||||
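Putting the two entry points together, a minimal end-to-end flow could look like the sketch below; ``MyFloatNet`` and the training step are placeholders:

from megengine.quantization.quantize import quantize, quantize_qat

net = MyFloatNet()                          # hypothetical float model
qat_net = quantize_qat(net, inplace=False)  # Float modules -> QAT modules
# ... run QAT training so observers collect scale/zero_point statistics ...
q_net = quantize(qat_net)                   # QAT modules -> Quantized modules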
def quantize_calibration(module: Module, qconfig: QConfig = ema_fakequant_qconfig): | |||||
def propagate_qconfig(module: QATModule, qconfig: QConfig): | |||||
r""" | r""" | ||||
Recursively convert `module` to `calibration` mode through :meth:`~.Module.apply` | |||||
and set qconfig accordingly. | |||||
Recursively set ``module``'s qconfig through :meth:`~.Module.apply`. | |||||
:param module: root module to do convert recursively. | |||||
:param module: root module to traverse recursively. | |||||
:param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig. | :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig. | ||||
default is :any:`~.qconfig.ema_fakequant_qconfig`. | |||||
""" | """ | ||||
def fn(mod: Module): | def fn(mod: Module): | ||||
if isinstance(mod, QATModule): | if isinstance(mod, QATModule): | ||||
mod.set_qat_mode(QATModule.QATMode.CALIBRATION) | |||||
mod.set_qconfig(qconfig) | mod.set_qconfig(qconfig) | ||||
module.apply(fn) | module.apply(fn) | ||||
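Since ``propagate_qconfig`` only sets qconfig and replaces no modules, it can also re-point an already-converted QAT network at a different qconfig; a hypothetical example:

# reuse the default qconfig on an existing QAT net (qat_net is assumed)
from megengine.quantization.qconfig import ema_fakequant_qconfig
propagate_qconfig(qat_net, ema_fakequant_qconfig)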
@@ -5,8 +5,7 @@ import numpy as np | |||||
from megengine import tensor | from megengine import tensor | ||||
from megengine.module import ConvBn2d | from megengine.module import ConvBn2d | ||||
from megengine.quantization import quantize_qat | |||||
from megengine.quantization.quantize import disable_fake_quant | |||||
from megengine.quantization.quantize import disable_fake_quant, quantize_qat | |||||
from megengine.test import assertTensorClose | from megengine.test import assertTensorClose | ||||
@@ -14,18 +13,17 @@ def test_convbn2d(): | |||||
in_channels = 32 | in_channels = 32 | ||||
out_channels = 64 | out_channels = 64 | ||||
kernel_size = 3 | kernel_size = 3 | ||||
module = ConvBn2d(in_channels, out_channels, kernel_size) | |||||
quantize_qat(module) | |||||
for groups, bias in product([1, 4], [True, False]): | for groups, bias in product([1, 4], [True, False]): | ||||
inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) | |||||
module = ConvBn2d( | |||||
in_channels, out_channels, kernel_size, groups=groups, bias=bias | |||||
) | |||||
module.train() | module.train() | ||||
qat_module = copy.deepcopy(module) | |||||
qat_module = quantize_qat(module, inplace=False) | |||||
disable_fake_quant(qat_module) | disable_fake_quant(qat_module) | ||||
normal_outputs = module.forward(inputs) | |||||
qat_outputs = qat_module.forward_qat(inputs) | |||||
inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) | |||||
normal_outputs = module(inputs) | |||||
qat_outputs = qat_module(inputs) | |||||
assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6) | assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6) | ||||
a = module.bn.running_mean.numpy() | |||||
b = qat_module.bn.running_mean.numpy() | |||||
assertTensorClose( | assertTensorClose( | ||||
module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8 | module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8 | ||||
) | ) | ||||
@@ -33,7 +31,7 @@ def test_convbn2d(): | |||||
module.bn.running_var, qat_module.bn.running_var, max_err=5e-7 | module.bn.running_var, qat_module.bn.running_var, max_err=5e-7 | ||||
) | ) | ||||
module.eval() | module.eval() | ||||
normal_outputs = module.forward(inputs) | |||||
normal_outputs = module(inputs) | |||||
qat_module.eval() | qat_module.eval() | ||||
qat_outputs = qat_module.forward_qat(inputs) | |||||
qat_outputs = qat_module(inputs) | |||||
assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6) | assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6) |
@@ -0,0 +1,38 @@ | |||||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||||
# | |||||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||||
# | |||||
# Unless required by applicable law or agreed to in writing, | |||||
# software distributed under the License is distributed on an | |||||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
from megengine import module as Float | |||||
from megengine.module import qat as QAT | |||||
from megengine.quantization.quantize import _get_quantable_module_names | |||||
def test_get_quantable_module_names(): | |||||
# need to make sure names from Quantized and QAT are the same | |||||
def _get_qat_module_names(): | |||||
def is_qat(key: str): | |||||
value = getattr(QAT, key) | |||||
return ( | |||||
isinstance(value, type) | |||||
and issubclass(value, QAT.QATModule) | |||||
and value != QAT.QATModule | |||||
) | |||||
# source should have all quantable modules' names | |||||
quantable_module_names = [key for key in dir(QAT) if is_qat(key)] | |||||
return quantable_module_names | |||||
qat_module_names = _get_qat_module_names() | |||||
quantized_module_names = _get_quantable_module_names() | |||||
assert set(qat_module_names) == set(quantized_module_names) | |||||
for key in qat_module_names: | |||||
value = getattr(Float, key) | |||||
assert ( | |||||
isinstance(value, type) | |||||
and issubclass(value, Float.Module) | |||||
and value != Float.Module | |||||
) |