GitOrigin-RevId: 4fe1233ec3
tags/v0.4.0
@@ -74,12 +74,14 @@ from .nn import (
    softmax,
    warp_perspective,
)
from .quantized import conv_bias_activation
from .sort import argsort, sort, top_k
from .tensor import (
    add_axis,
    arange,
    broadcast_to,
    concat,
    cond_take,
    dimshuffle,
    gather,
    linspace,
@@ -0,0 +1,84 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=too-many-lines
from typing import Tuple, Union

from .. import _internal as mgb
from ..core import Tensor, wrap_io_tensor
from ..utils.types import _pair, _pair_nonzero
from .debug_param import get_conv_execution_strategy


@wrap_io_tensor
def conv_bias_activation(
    inp: Tensor,
    weight: Tensor,
    bias: Tensor,
    dtype=None,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    groups: int = 1,
    nonlinear_mode="IDENTITY",
    conv_mode="CROSS_CORRELATION",
    compute_mode="DEFAULT",
) -> Tensor:
    """Convolution with bias and activation, for inference only.

    :param inp: the feature map of the convolution operation.
    :param weight: the convolution kernel.
    :param bias: the bias added to the result of convolution.
    :param stride: stride of the 2D convolution operation. Default: 1
    :param padding: size of the paddings added to the input on both sides of its
        spatial dimensions. Only zero-padding is supported. Default: 0
    :param dilation: dilation of the 2D convolution operation. Default: 1
    :param groups: number of groups to divide input and output channels into,
        so as to perform a "grouped convolution". When ``groups`` is not 1,
        ``in_channels`` and ``out_channels`` must be divisible by ``groups``,
        and the shape of weight should be ``(groups, out_channel // groups,
        in_channels // groups, height, width)``.
    :type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode`
    :param conv_mode: supports 'CROSS_CORRELATION' or 'CONVOLUTION'.
        Default: 'CROSS_CORRELATION'
    :param dtype: the output quantized dtype, a np.dtype. Default: np.int8
    :param scale: scale used for quantization. Default: 0.0
    :param zero_point: zero point used for quint8 quantization. Default: 0.0
    :type compute_mode: string or
        :class:`mgb.opr_param_defs.Convolution.ComputeMode`
    :param compute_mode: when set to 'DEFAULT', no special requirements will be
        placed on the precision of intermediate results. When set to 'FLOAT32',
        Float32 would be used for accumulator and intermediate result, but only
        effective when input and output are of Float16 dtype.
    """
    ph, pw = _pair(padding)
    sh, sw = _pair_nonzero(stride)
    dh, dw = _pair_nonzero(dilation)
    sparse_type = "DENSE" if groups == 1 else "GROUP"
    res = mgb.opr.conv_bias_activation(
        inp,
        weight,
        bias,
        compute_mode=compute_mode,
        dtype=dtype,
        strategy=get_conv_execution_strategy(),
        nonlineMode=nonlinear_mode,
        sparse=sparse_type,
        format="NCHW",
        pad_h=ph,
        pad_w=pw,
        stride_h=sh,
        stride_w=sw,
        dilate_h=dh,
        dilate_w=dw,
        mode=conv_mode,
    )
    return res
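A minimal usage sketch of this function (not part of the commit itself), mirroring the quantized-dtype helpers exercised by the test_conv_bias test at the end of this diff; the scales and shapes here are illustrative:

import numpy as np

import megengine._internal as mgb
import megengine.functional as F
from megengine import Parameter, tensor

inp_dtype = mgb.dtype.qint8(0.01)        # input scale
w_dtype = mgb.dtype.qint8(0.02)          # weight scale
b_dtype = mgb.dtype.qint32(0.01 * 0.02)  # bias scale = inp_scale * w_scale
out_dtype = mgb.dtype.qint8(0.1)         # output scale

inp = tensor(
    mgb.dtype.convert_to_qint8(np.random.normal(size=(1, 4, 8, 8)), inp_dtype),
    dtype=inp_dtype,
)
w = Parameter(
    mgb.dtype.convert_to_qint8(np.random.normal(size=(4, 4, 1, 1)), w_dtype),
    dtype=w_dtype,
)
b = Parameter(
    mgb.dtype.convert_to_qint32(np.random.normal(size=(1, 4, 1, 1)), b_dtype),
    dtype=b_dtype,
)

# fused int8 convolution + bias + ReLU, producing a qint8 result
out = F.conv_bias_activation(inp, w, b, dtype=out_dtype, nonlinear_mode="RELU")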
@@ -359,6 +359,41 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
    return out


@wrap_io_tensor
def cond_take(mask: Tensor, x: Tensor, val=1) -> Tensor:
    r"""
    Take elements from ``x`` where the condition given by ``mask`` is satisfied.
    This operator has two outputs: the first is the elements taken, and the
    second is the indices corresponding to those elements; both are
    1-dimensional. Higher-dimensional input is flattened first.

    :param mask: condition tensor; must have the same shape as ``x``.
    :param x: input tensor from which to take elements.
    :param val: value that ``mask`` is compared against by the mode.

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32))
        x = tensor(np.array([[1, np.inf], [np.nan, 4]], dtype=np.float32))
        v, index = F.cond_take(mask, x, 1)
        print(v, index)

    Outputs:

    .. testoutput::

        Tensor([1. 4.]) Tensor([0 3], dtype=int32)

    """
    v, index = mgb.opr.cond_take(
        x, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=val
    )
    return v, index


def shapeof(x: Tensor, axis=None):
    r"""
    The shape of input tensor.
@@ -8,12 +8,16 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
from .batchnorm import BatchNorm1d, BatchNorm2d
from .concat import Concat
from .conv import Conv2d, ConvTranspose2d
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .dropout import Dropout
from .elemwise import Elemwise
from .embedding import Embedding
from .identity import Identity
from .linear import Linear
from .module import Module, QATModule
from .parampack import ParamPack
from .pooling import AvgPool2d, MaxPool2d
from .quant_dequant import DequantStub, QuantStub
from .sequential import Sequential
@@ -0,0 +1,27 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable

from .. import functional as F
from ..core.tensor import Tensor
from .module import QATModule


class Concat(QATModule):
    r"""
    A :class:`~.QATModule` that wraps functional concat; replace calls to
    concat with this module to support ``qat`` mode and ``quantized`` mode.
    """

    def forward(self, inps: Iterable[Tensor], axis: int = 0):
        return F.concat(inps, axis)

    def forward_qat(self, inps: Iterable[Tensor], axis: int = 0):
        return self.apply_fakequant_with_observer(
            self.forward(inps, axis), self.act_fake_quant, self.act_observer
        )
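A small usage sketch (not part of the change), assuming the float-module import path added to module/__init__.py above:

import numpy as np

from megengine import tensor
from megengine.module import Concat

cat = Concat()  # dispatches to forward or forward_qat depending on QAT mode
x = tensor(np.array([[1.0, 2.0]], dtype=np.float32))
y = tensor(np.array([[3.0, 4.0]], dtype=np.float32))
out = cat([x, y], axis=0)  # same result as F.concat while in DISABLED mode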
@@ -182,11 +182,11 @@ class Conv2d(_ConvNd):
        # Assume format is NCHW
        return (1, self.out_channels, 1, 1)

    def calc_conv(self, inp, weight, bias):
        return conv2d(
            inp,
            weight,
            bias,
            self.stride,
            self.padding,
            self.dilation,
@@ -195,6 +195,9 @@ class Conv2d(_ConvNd):
            self.compute_mode,
        )

    def forward(self, inp):
        return self.calc_conv(inp, self.weight, self.bias)


class ConvTranspose2d(_ConvNd):
    r"""Applies a 2D transposed convolution over an input tensor.
@@ -0,0 +1,168 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Tuple, Union

from ..core import ones, zeros
from ..functional import flatten, relu, sqrt, sum
from .batchnorm import BatchNorm2d
from .conv import Conv2d
from .module import QATModule


class _ConvBn2d(QATModule):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        freeze_bn=False,
    ):
        super().__init__()
        self.conv = Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            conv_mode,
            compute_mode,
        )
        self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
        self.freeze_bn = freeze_bn

    def update_bn_stats(self):
        self.freeze_bn = False
        return self

    def freeze_bn_stats(self):
        self.freeze_bn = True
        return self

    def get_bn_gamma_beta(self):
        if self.bn.weight is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        else:
            gamma = self.bn.weight

        if self.bn.bias is None:
            beta = zeros((self.bn.num_features), dtype="float32")
        else:
            beta = self.bn.bias

        return gamma, beta

    def get_batch_mean_var(self, inp):
        def _sum_channel(inp, axis=0, keepdims=True):
            if isinstance(axis, int):
                out = sum(inp, axis=axis, keepdims=keepdims)
            elif isinstance(axis, tuple):
                for idx, elem in enumerate(axis):
                    out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
            return out

        sum1 = _sum_channel(inp, (0, 2, 3))
        sum2 = _sum_channel(inp ** 2, (0, 2, 3))
        reduce_size = inp.shapeof().prod() / inp.shapeof(1)
        batch_mean = sum1 / reduce_size
        # unbiased estimate of the per-channel variance
        batch_var = (sum2 - sum1 ** 2 / reduce_size) / (reduce_size - 1)
        return batch_mean, batch_var

    def fold_weight_bias(self, bn_mean, bn_var):
        # get folded conv parameters from the BN statistics:
        # bn_istd = 1 / bn_std
        # w_fold = gamma / bn_std * W
        # b_fold = gamma * (b - bn_mean) / bn_std + beta
        gamma, beta = self.get_bn_gamma_beta()
        b = self.conv.bias
        if b is None:
            b = zeros(self.conv._infer_bias_shape(), dtype="float32")
        if bn_mean is None:
            bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
        if bn_var is None:
            bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        if self.conv.groups == 1:
            w_fold = (
                self.conv.weight
                * gamma.reshape(-1, 1, 1, 1)
                * bn_istd.reshape(-1, 1, 1, 1)
            )
        else:
            w_fold = (
                self.conv.weight
                * gamma.reshape(self.conv.groups, -1, 1, 1, 1)
                * bn_istd.reshape(self.conv.groups, -1, 1, 1, 1)
            )
        b_fold = flatten(beta) + (
            flatten(gamma) * (flatten(b) - flatten(bn_mean)) * flatten(bn_istd)
        )
        b_fold = b_fold.reshape(self.conv._infer_bias_shape())
        return w_fold, b_fold

    def calc_conv_bn_qat(self, inp):
        # TODO: use pytorch method as
        conv = self.conv(inp)
        # run BN on the raw conv output to update its running statistics
        self.bn(conv)

        if self.training:
            bn_mean, bn_var = self.get_batch_mean_var(conv)
        else:
            bn_mean, bn_var = self.bn.running_mean, self.bn.running_var
        w_fold, b_fold = self.fold_weight_bias(bn_mean, bn_var)
        w_qat = self.apply_fakequant_with_observer(
            w_fold, self.weight_fake_quant, self.weight_observer
        )
        return self.conv.calc_conv(inp, w_qat, b_fold)


class ConvBn2d(_ConvBn2d):
    r"""
    A fused :class:`~.QATModule` including Conv2d and BatchNorm2d, supporting
    ``qat`` mode and ``normal`` mode.
    """

    def forward_qat(self, inp):
        return self.apply_fakequant_with_observer(
            self.calc_conv_bn_qat(inp), self.act_fake_quant, self.act_observer
        )

    def forward(self, inp):
        return self.bn(self.conv(inp))


class ConvBnRelu2d(_ConvBn2d):
    r"""
    A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu,
    supporting ``qat`` mode and ``normal`` mode.
    """

    def forward_qat(self, inp):
        return self.apply_fakequant_with_observer(
            relu(self.calc_conv_bn_qat(inp)), self.act_fake_quant, self.act_observer
        )

    def forward(self, inp):
        return relu(self.bn(self.conv(inp)))
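The folding identity used by fold_weight_bias can be checked numerically; here is a scalar (single-channel) numpy sketch, not part of the change, under exactly the formulas quoted in the comments above:

import numpy as np

rng = np.random.RandomState(0)
x, w, b = rng.randn(), rng.randn(), rng.randn()          # input, weight, conv bias
gamma, beta = rng.rand() + 0.5, rng.randn()              # BN affine parameters
mean, var, eps = rng.randn(), rng.rand() + 0.1, 1e-5     # BN statistics
istd = 1.0 / np.sqrt(var + eps)                          # bn_istd = 1 / bn_std

bn_of_conv = gamma * ((w * x + b) - mean) * istd + beta  # BN applied after conv
w_fold = w * gamma * istd                                # w_fold = gamma / bn_std * W
b_fold = gamma * (b - mean) * istd + beta                # b_fold = gamma * (b - mean) / bn_std + beta
assert np.allclose(bn_of_conv, w_fold * x + b_fold)      # the folded conv agrees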
@@ -0,0 +1,95 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .. import _internal as mgb
from ..core import Tensor, wrap_io_tensor
from ..core.graph import _use_default_if_none
from .module import QATModule


@wrap_io_tensor
def _elemwise_func(mode, *inputs, **kwargs) -> Tensor:
    if all(isinstance(i, (int, float)) for i in inputs):
        device, comp_graph = _use_default_if_none(None, None)
        ret = mgb.opr.elemwise(
            *inputs, mode=mode, comp_node=device, comp_graph=comp_graph, **kwargs
        )
        return ret.inferred_value[0]
    return mgb.opr.elemwise(*inputs, mode=mode, **kwargs)


class Elemwise(QATModule):
    r"""
    A :class:`~.QATModule` that wraps an elemwise operator; replace the
    functional operator with this module to support ``qat`` mode and
    ``normal`` mode.

    :param method: the elemwise method, one of the following strings.
        For float tensors it performs the normal elemwise operator.

        * "ADD": x + y
        * "FUSE_ADD_RELU": max(x+y, 0)
        * "MUL": x * y
        * "MIN": min(x, y)
        * "MAX": max(x, y)
        * "SUB": x - y
        * "TRUE_DIV": x / y
        * "FUSE_ADD_SIGMOID": sigmoid(x + y)
        * "FUSE_ADD_TANH": tanh(x + y)
        * "RELU": x > 0 ? x : 0
        * "ABS": x > 0 ? x : -x
        * "SIGMOID": sigmoid(x)
        * "EXP": exp(x)
        * "TANH": tanh(x)
        * "FUSE_MUL_ADD3": x * y + z
        * "FAST_TANH": fast_tanh(x)
        * "NEGATE": -x
        * "ACOS": acos(x)
        * "ASIN": asin(x)
        * "CEIL": ceil(x)
        * "COS": cos(x)
        * "EXPM1": expm1(x)
        * "FLOOR": floor(x)
        * "LOG": log(x)
        * "LOG1P": log1p(x)
        * "SIN": sin(x)
        * "ROUND": round(x)
        * "ERF": erf(x)
        * "ERFINV": erfinv(x)
        * "ERFC": erfc(x)
        * "ERFCINV": erfcinv(x)
        * "ABS_GRAD": abs_grad
        * "FLOOR_DIV": floor_div
        * "MOD": mod
        * "SIGMOID_GRAD": sigmoid_grad
        * "SWITCH_GT0": switch_gt0
        * "TANH_GRAD": tanh_grad
        * "LT": lt
        * "LEQ": leq
        * "EQ": eq
        * "POW": pow
        * "LOG_SUM_EXP": log_sum_exp
        * "FAST_TANH_GRAD": fast_tanh_grad
        * "ATAN2": atan2
        * "COND_LEQ_MOV": cond_leq_mov
        * "H_SWISH": h_swish
        * "FUSE_ADD_H_SWISH": h_swish(x+y)
        * "H_SWISH_GRAD": h_swish_grad
    """

    _elemwise_mode_type = mgb.opr_param_defs.Elemwise.Mode

    def __init__(self, method):
        super().__init__()
        self.method = self._elemwise_mode_type.convert(method)

    def forward(self, *inps):
        return _elemwise_func(self.method, *inps)

    def forward_qat(self, *inps):
        return self.apply_fakequant_with_observer(
            self.forward(*inps), self.act_fake_quant, self.act_observer,
        )
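A quick usage sketch (not part of the change) of the fused add + ReLU mode:

import numpy as np

from megengine import tensor
from megengine.module import Elemwise

add_relu = Elemwise("FUSE_ADD_RELU")  # fused max(x + y, 0)
x = tensor(np.array([-1.0, 2.0], dtype=np.float32))
y = tensor(np.array([0.5, -3.0], dtype=np.float32))
out = add_relu(x, y)  # [0., 0.]; additionally fake-quantized in QAT mode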
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -8,6 +7,7 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from enum import Enum
from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union

import numpy as np
@@ -442,3 +442,95 @@ class Module(metaclass=ABCMeta):
            loaded.append(k)

        return set(loaded), set(skipped)
class QATModule(Module):
    r"""
    Base class of quantization related Module. Adds extra forward methods
    :meth:`~.QATModule.forward_qat` and :meth:`~.QATModule.forward_quantized` for
    ``qat`` (quantization aware training) mode and ``quantized`` mode respectively.

    Use :meth:`~.QATModule.set_qat_mode` to switch between ``QAT`` and
    ``DISABLED`` mode, and use :meth:`~.QATModule.to_quantized` to switch to
    ``quantized`` mode, which is irreversible.

    If you want to recursively switch mode for all QATModule in a network, use
    the functions in :mod:`~.quantization.quantize`.
    """

    class QATMode(Enum):
        DISABLED = 1
        QAT = 2
        CALIBRATION = 3

    def __init__(self):
        from ..quantization import (
            QConfig,
            FakeQuantize,
            Observer,
        )  # pylint: disable=all

        super().__init__()
        self.quantizing = self.QATMode.DISABLED
        self.scale = None

        self.inp_observer = None  # type: Observer
        self.weight_observer = None  # type: Observer
        self.act_observer = None  # type: Observer

        self.weight_fake_quant = None  # type: FakeQuantize
        self.bias_fake_quant = None  # type: FakeQuantize
        self.act_fake_quant = None  # type: FakeQuantize

    def set_qconfig(self, qconfig: "QConfig"):
        self.inp_observer = qconfig.inp_observer()
        self.weight_observer = qconfig.weight_observer()
        self.act_observer = qconfig.act_observer()

        self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype)
        self.bias_fake_quant = qconfig.bias_fake_quant()
        self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype)

    def apply_observer(self, target: Tensor, obs: "Observer"):
        return obs(target)

    def apply_fakequant_with_observer(
        self, target: Tensor, fq: "FakeQuantize", obs: "Observer"
    ):
        # first update the observer's statistics, then fake-quantize with
        # the scale and zero_point the observer derives
        oup = self.apply_observer(target, obs)
        return fq(oup, obs.scale, obs.zero_point)

    def set_qat_mode(self, mode: QATMode):
        r"""
        Change the ``self.quantizing`` mode; available values are
        ``QATMode.DISABLED``, ``QATMode.QAT`` and ``QATMode.CALIBRATION``.
        """
        if not isinstance(mode, self.QATMode):
            raise TypeError("mode must be QATMode Enum type")
        self.quantizing = mode

    def to_quantized(self):
        r"""
        Return a new :class:`~.Module` with quantized parameters of ``self``
        according to scale and zero_point in ``self.xxx_observer``.
        """
        raise NotImplementedError(
            "Use megengine.quantization.quantize to register the method."
        )

    @abstractmethod
    def forward_qat(self, *args, **kwargs):
        r"""
        Forward method for ``qat`` mode.
        """

    def __call__(self, *args, **kwargs):
        if self.quantizing == self.QATMode.QAT:
            return self.forward_qat(*args, **kwargs)
        elif self.quantizing == self.QATMode.CALIBRATION:
            # TODO: implement CALIBRATION mode
            raise NotImplementedError("CALIBRATION mode is not implemented yet")
        else:
            return self.forward(*args, **kwargs)
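To make the observe-then-fake-quantize composition concrete, a standalone sketch (not part of the change) using the Observer and FakeQuantize classes added later in this diff; that MinMaxObserver works fully outside a QATModule is an assumption here:

import numpy as np

from megengine import tensor
from megengine.quantization import FakeQuantize
from megengine.quantization.observer import MinMaxObserver

obs = MinMaxObserver(dtype="qint8")
fq = FakeQuantize(dtype="qint8")

x = tensor(np.random.randn(16).astype(np.float32))
x = obs(x)                          # update running min/max statistics
scale, zero_point = obs.get_qparams()
y = fq(x, scale, zero_point)        # quantize, clip to [qmin, qmax], dequantize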
@@ -0,0 +1,34 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .module import QATModule


class QuantStub(QATModule):
    r"""
    A helper :class:`~.QATModule` that quantizes its input, marking the entry
    point of the quantized region of a network.
    """

    def forward(self, inp):
        return inp

    def forward_qat(self, inp):
        return self.apply_fakequant_with_observer(
            inp, self.act_fake_quant, self.act_observer
        )


class DequantStub(QATModule):
    r"""
    A helper :class:`~.QATModule` that de-quantizes its input, marking the exit
    point of the quantized region of a network.
    """

    def forward(self, inp):
        return inp

    def forward_qat(self, inp):
        return inp
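The stubs are meant to bracket the quantized part of a network; a minimal sketch of the conventional pattern (the network itself is hypothetical, not part of this change):

from megengine.module import ConvBnRelu2d, DequantStub, Module, QuantStub

class QATNet(Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()                      # float -> (fake-)quantized
        self.body = ConvBnRelu2d(3, 8, 3, padding=1)  # the quantized region
        self.dequant = DequantStub()                  # back to float at the end

    def forward(self, x):
        return self.dequant(self.body(self.quant(x)))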
@@ -0,0 +1,11 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .concat import Concat
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .elemwise import Elemwise
from .quant_dequant import DequantStub, QuantStub
@@ -0,0 +1,45 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable

from ... import _internal as mgb
from ... import functional as F
from ... import module as Float
from ...core.tensor import Tensor
from ...quantization.utils import register_method_to_class
from ..module import Module


class Concat(Module):
    r"""
    A :class:`~.Module` to do quantized concat, inference only.
    """

    def __init__(self):
        super().__init__()
        self.scale = 1.0
        self.zero_point = 0.0
        self.output_dtype = mgb.dtype.qint8(self.scale)

    def forward(self, inps: Iterable[Tensor], axis: int = 0):
        if self.training:
            raise ValueError("quantized module only supports inference.")
        # re-quantize every input to the common output dtype before concat
        new_inps = (x.astype(self.output_dtype) for x in inps)
        return F.concat(new_inps, axis)


@register_method_to_class(Float.Concat)
def to_quantized(float_module):
    r"""
    Replaces :class:`~.module.QATModule`'s ``to_quantized`` method;
    implemented here to avoid a circular import.
    """
    qmod = Concat()
    qmod.output_dtype = float_module.act_observer.get_dtype()
    qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
    return qmod
@@ -0,0 +1,114 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from functools import partial
from typing import Tuple, Union

import megengine._internal as mgb

from ... import module as Float
from ...core import Parameter
from ...functional import conv_bias_activation
from ...module import Conv2d
from ...quantization.utils import register_method_to_class


class _ConvBnActivation2d(Conv2d):
    r"""Applies a 2D convolution over a quantized input tensor, inference only.

    The parameters are the same as :class:`~.Conv2d`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            True,
            conv_mode,
            compute_mode,
        )
        self.scale = 1.0
        self.zero_point = 0.0
        self.output_dtype = mgb.dtype.qint8(self.scale)
        self.weight = self.weight.astype(self.output_dtype)
        self.bias = self.bias.astype(mgb.dtype.qint32(self.scale))

    def calc_conv_quantized(self, inp, nonlinear_mode="IDENTITY"):
        inp_scale = mgb.dtype.get_scale(inp.dtype)
        w_scale = mgb.dtype.get_scale(self.weight.dtype)
        # the int32 bias scale is the product of the input and weight scales
        bias_scale = inp_scale * w_scale
        return conv_bias_activation(
            inp,
            self.weight,
            self.bias.astype(mgb.dtype.qint32(bias_scale)),
            self.output_dtype,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            conv_mode=self.conv_mode,
            compute_mode=self.compute_mode,
            nonlinear_mode=nonlinear_mode,
        )


class ConvBn2d(_ConvBnActivation2d):
    def forward(self, inp):
        if self.training:
            raise ValueError("quantized module only supports inference.")
        return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")


class ConvBnRelu2d(_ConvBnActivation2d):
    def forward(self, inp):
        if self.training:
            raise ValueError("quantized module only supports inference.")
        return self.calc_conv_quantized(inp, nonlinear_mode="RELU")


def to_quantized(quantized_class, float_module):
    qconv = quantized_class(
        float_module.conv.in_channels,
        float_module.conv.out_channels,
        float_module.conv.kernel_size,
        float_module.conv.stride,
        float_module.conv.padding,
        float_module.conv.dilation,
        float_module.conv.groups,
    )
    w_fold, b_fold = float_module.fold_weight_bias(
        float_module.bn.running_mean, float_module.bn.running_var
    )
    weight = w_fold.astype(float_module.weight_observer.get_dtype())
    qconv.output_dtype = float_module.act_observer.get_dtype()
    qconv.weight = Parameter(weight.numpy())
    qconv.bias = Parameter(b_fold.numpy())
    qconv.scale, qconv.zero_point = float_module.act_observer.get_qparams()
    return qconv


# Replaces :class:`~.module.QATModule`'s ``to_quantized`` method;
# implemented here to avoid a circular import.
register_method_to_class(Float.ConvBn2d)(partial(to_quantized, ConvBn2d))
register_method_to_class(Float.ConvBnRelu2d)(partial(to_quantized, ConvBnRelu2d))
@@ -0,0 +1,59 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import _internal as mgb
from ... import module as Float
from ...core import Tensor, wrap_io_tensor
from ...core.graph import _use_default_if_none
from ...quantization.utils import register_method_to_class
from ..module import Module


@wrap_io_tensor
def _elemwise_multi_type(mode, *inputs, **kwargs) -> Tensor:
    if all(isinstance(i, (int, float)) for i in inputs):
        device, comp_graph = _use_default_if_none(None, None)
        ret = mgb.opr.elemwise_multi_type(
            *inputs, mode=mode, comp_node=device, comp_graph=comp_graph, **kwargs,
        )
        return ret.inferred_value[0]
    return mgb.opr.elemwise_multi_type(*inputs, mode=mode, **kwargs)


class Elemwise(Module):
    r"""
    Quantized module for elemwise operators, inference only.

    :param method: the elemwise method; see :class:`~.module.elemwise.Elemwise`
        for the supported strings. It performs the quantized operator with the
        specified output quantized dtype.
    """

    _elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode

    def __init__(self, method):
        super().__init__()
        # quantized modes are the float mode names prefixed with "Q"
        self.method = self._elemwise_multi_type_mode.convert("Q" + method)
        self.scale = 1.0
        self.zero_point = 0.0
        self.output_dtype = mgb.dtype.qint8(self.scale)

    def forward(self, *inps):
        if self.training:
            raise ValueError("quantized module only supports inference.")
        return _elemwise_multi_type(self.method, *inps, dtype=self.output_dtype)


@register_method_to_class(Float.Elemwise)
def to_quantized(float_module):
    r"""
    Replaces :class:`~.module.QATModule`'s ``to_quantized`` method;
    implemented here to avoid a circular import.
    """
    qmod = Elemwise(float_module.method.name)
    qmod.output_dtype = float_module.act_observer.get_dtype()
    qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
    return qmod
@@ -0,0 +1,61 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import _internal as mgb
from ... import module as Float
from ...quantization.utils import register_method_to_class
from ..module import Module


class QuantStub(Module):
    r"""
    A helper module that quantizes its input, inference only.
    """

    def __init__(self):
        super().__init__()
        self.scale = 1.0
        self.zero_point = 0.0
        self.output_dtype = mgb.dtype.qint8(self.scale)

    def forward(self, inp):
        if self.training:
            raise ValueError("quantized module only supports inference.")
        return inp.astype(self.output_dtype)


class DequantStub(Module):
    r"""
    A helper module that de-quantizes its input, inference only.
    """

    def forward(self, inp):
        if self.training:
            raise ValueError("quantized module only supports inference.")
        return inp.astype("float32")


@register_method_to_class(Float.QuantStub)
def to_quantized(float_module):
    r"""
    Replaces :class:`~.module.QATModule`'s ``to_quantized`` method;
    implemented here to avoid a circular import.
    """
    qmod = QuantStub()
    qmod.output_dtype = float_module.act_observer.get_dtype()
    qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
    return qmod


@register_method_to_class(Float.DequantStub)
def to_quantized(float_module):  # noqa: F811  registration happens at decoration time
    r"""
    Replaces :class:`~.module.QATModule`'s ``to_quantized`` method;
    implemented here to avoid a circular import.
    """
    qmod = DequantStub()
    return qmod
@@ -68,6 +68,7 @@ class Sequential(Module):
    def __setitem__(self, idx, module):
        key = self.layer_keys[idx]
        self.layer_values[idx] = module
        return setattr(self, key, module)

    def __delitem__(self, idx):
@@ -0,0 +1,11 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .fake_quant import FakeQuantize
from .observer import Observer
from .qconfig import QConfig, ema_fakequant_qconfig, min_max_fakequant_qconfig
from .quantize import quantize, quantize_qat
@@ -0,0 +1,48 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .. import functional as F
from .._internal.dtype import _metadata_dict
from ..module import Module
from .observer import Round


class FakeQuantize(Module):
    r"""
    A module to do quant and dequant according to an observer's scale and zero_point.
    """

    def __init__(self, dtype: str, enable: bool = True):
        super().__init__()
        if dtype not in _metadata_dict.keys():
            raise ValueError(
                "unknown dtype: {}, only support {}".format(
                    dtype, _metadata_dict.keys()
                )
            )
        self.dtype = dtype
        self.qmin = _metadata_dict[dtype].qmin
        self.qmax = _metadata_dict[dtype].qmax
        self.enabled = enable

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False

    def forward(self, inp, scale, zero_point):
        if self.enabled:
            # quantize
            oup = Round()(inp / scale) + zero_point
            # clip to the representable range
            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
            # dequantize
            oup = (oup - zero_point) * scale
            return oup
        return inp
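The forward pass above is a plain quantize/clip/dequantize round trip. In numpy, for qint8 (qmin=-128, qmax=127), a sketch not part of the change:

import numpy as np

def fake_quant(x, scale, zero_point, qmin=-128, qmax=127):
    q = np.round(x / scale) + zero_point  # quantize
    q = np.clip(q, qmin, qmax)            # saturate to the dtype's range
    return (q - zero_point) * scale       # dequantize back to float

x = np.array([-2.0, 0.034, 1.3])
print(fake_quant(x, scale=0.01, zero_point=0))  # [-1.28  0.03  1.27]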
@@ -0,0 +1,193 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod

import numpy as np

from .. import functional as F
from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, ones, tensor, zeros
from ..module import Module


class Round(Function):
    # rounding with a straight-through gradient
    def forward(self, x):
        return x.round()

    def backward(self, output_grads):
        return output_grads


class Observer(Module):
    r"""
    A base class for Observer Module.

    :param dtype: a string indicating which dtype to collect scale and zero_point for.
    """

    def __init__(self, dtype="qint8"):
        super().__init__()
        if dtype not in _metadata_dict.keys():
            raise ValueError(
                "unknown dtype: {}, only support {}".format(
                    dtype, _metadata_dict.keys()
                )
            )
        self.dtype = dtype
        self.qmin = _metadata_dict[dtype].qmin
        self.qmax = _metadata_dict[dtype].qmax
        self.zero_point, self.scale = None, None
        self.enabled = True

    def get_dtype(self):
        scale, zero_point = self.get_qparams()
        numpy_scale = None if scale is None else scale.numpy()[0]
        numpy_zero_point = None if zero_point is None else zero_point.numpy()[0]
        return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False

    @abstractmethod
    def forward(self, x):
        pass

    @abstractmethod
    def get_qparams(self, **kwargs):
        pass


class IdentityObserver(Observer):
    r"""
    A test Observer that always returns scale 1 and zero_point 0.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # scale 1 and zero_point 0, matching the docstring above
        self.scale = ones((1), dtype="float32")
        self.zero_point = zeros((1), dtype="float32")

    def forward(self, x):
        return x

    def get_qparams(self):
        return self.scale, self.zero_point


class MinMaxObserver(Observer):
    def __init__(self, symmetric=True, eps=0.00001, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.symmetric = symmetric
        if self.symmetric:
            # assert qmin + qmax == -1, 'when reduce_range, qmin + qmax should equal -1'
            self.zero_point = tensor((self.qmin + self.qmax + 1) // 2)

        self.min_val = Buffer(0.0, dtype=np.float32)
        self.max_val = Buffer(0.0, dtype=np.float32)
        self.scale_limit = eps
        # these flags are used by cond_take: the first run selects the raw
        # min/max (first_flag), after which the flag is set to not_flag
        self.first_flag = Buffer(np.array([1, 0], dtype=np.int32))
        self.not_flag = Buffer(np.array([0, 1], dtype=np.int32))

    def set_min_max(self, tmp_min, tmp_max):
        # FIXME: cond_take will destroy the shape, use reshape to restore it
        tmp_min = tmp_min.reshape(1)
        tmp_max = tmp_max.reshape(1)
        if self.training:
            F.zero_grad(
                F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0)
            )
            F.zero_grad(
                F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0)
            )
            F.zero_grad(
                F.add_update(
                    self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0
                )
            )

        # FIXME: add_update is applied after the whole trace procedure in
        # `symbolic=True` mode, so use tmp_min/tmp_max to calculate and save
        # scale/zero_point for further calculation in FakeQuant.
        self.set_scale_zero_point(tmp_min, tmp_max)

    def set_scale_zero_point(self, tmp_min, tmp_max):
        if self.symmetric:
            symmetric_max_vals = F.maximum(-tmp_min, tmp_max)
            # use maximum to avoid the scale being too small at the beginning
            self.scale = F.maximum(
                symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
            )
            # zero_point = self.zero_point
        else:
            # use maximum to avoid the scale being too small at the beginning
            self.scale = F.maximum(
                (tmp_max - tmp_min) / (self.qmax - self.qmin), self.scale_limit
            )
            # calculate zero_point
            self.zero_point = self.qmin - Round()((tmp_min / self.scale))

    def get_qparams(self):
        # scale and zero_point are runtime tensors rather than Buffers,
        # so they need to be re-calculated if min_val and max_val are loaded.
        if self.scale is None:
            self.set_scale_zero_point(self.min_val, self.max_val)

        return self.scale, self.zero_point

    def forward(self, x_orig):
        if self.enabled:
            # stop gradient
            x = F.zero_grad(x_orig)
            # find max and min
            tmp_min, _ = F.cond_take(
                self.first_flag, F.concat([x.min(), F.minimum(self.min_val, x.min())])
            )
            tmp_max, _ = F.cond_take(
                self.first_flag, F.concat([x.max(), F.maximum(self.max_val, x.max())])
            )
            self.set_min_max(tmp_min, tmp_max)
        return x_orig


class ExponentialMovingAverageObserver(MinMaxObserver):
    def __init__(self, momentum=0.9, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.momentum = Buffer(momentum)

    def set_momentum(self, momentum):
        self.momentum.set_value(momentum)

    def forward(self, x_orig):
        if self.enabled:
            # stop gradient
            x = F.zero_grad(x_orig)
            # exponential moving average of the min/max statistics
            tmp_min, _ = F.cond_take(
                self.first_flag,
                F.concat(
                    [
                        x.min(),
                        self.momentum * self.min_val + (1 - self.momentum) * x.min(),
                    ]
                ),
            )
            tmp_max, _ = F.cond_take(
                self.first_flag,
                F.concat(
                    [
                        x.max(),
                        self.momentum * self.max_val + (1 - self.momentum) * x.max(),
                    ]
                ),
            )
            self.set_min_max(tmp_min, tmp_max)
        return x_orig
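A numpy rendering (not part of the change) of the symmetric qparam rule in set_scale_zero_point: the scale must cover the larger of |min| and |max|, while the zero point is pinned to mid-range:

qmin, qmax = -128, 127                       # qint8 range
tmp_min, tmp_max = -0.9, 2.5                 # observed statistics (illustrative)
eps = 1e-5                                   # scale_limit

scale = max(max(-tmp_min, tmp_max) / ((qmax - qmin) / 2), eps)
zero_point = (qmin + qmax + 1) // 2          # 0 for qint8
print(scale, zero_point)                     # ~0.0196 0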
@@ -0,0 +1,82 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from functools import partial

from ..module import Module
from .fake_quant import FakeQuantize
from .observer import ExponentialMovingAverageObserver, MinMaxObserver


class QConfig:
    """
    A config class indicating how to quantize a :class:`~.QATModule`'s
    ``activation``, ``weight`` and ``bias``, with ``fake_quant`` indicating how
    the fake-quant calculation is performed.

    See :meth:`~.QATModule.set_qconfig` for detailed usage.

    :param inp_observer: interface to instantiate an :class:`~.Observer` indicating
        how to collect scales and zero_point of input.
    :param weight_observer: similar to ``inp_observer`` but toward weight.
    :param act_observer: similar to ``inp_observer`` but toward activation.
    :param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
        how to do the fake_quant calculation. It can be invoked multiple times to
        get a different instance for each target tensor, for better control of
        enabling and disabling.
    :param bias_fake_quant: similar to ``fake_quant``, but usually needs ``dtype``
        set in advance, since bias's dtype cannot be inferred from an observer.

    Examples:

    .. code-block::

        # Default EMA QConfig for QAT.
        ema_fakequant_qconfig = QConfig(
            inp_observer=ExponentialMovingAverageObserver,
            weight_observer=MinMaxObserver,
            act_observer=ExponentialMovingAverageObserver,
            fake_quant=FakeQuantize,
            bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
        )
    """

    def __init__(
        self, act_observer, weight_observer, inp_observer, fake_quant, bias_fake_quant,
    ):
        if (
            isinstance(act_observer, Module)
            or isinstance(weight_observer, Module)
            or isinstance(inp_observer, Module)
        ):
            raise ValueError(
                "QConfig must not receive observer instance, please pass observer"
                " class generator using `partial(Observer, ...)` instead. Use"
                " partial(MyObserver, x=1) to override arguments to constructor if needed"
            )
        self.act_observer = act_observer
        self.weight_observer = weight_observer
        self.inp_observer = inp_observer
        self.fake_quant = fake_quant
        self.bias_fake_quant = bias_fake_quant


# Default QAT QConfigs
min_max_fakequant_qconfig = QConfig(
    inp_observer=MinMaxObserver,
    weight_observer=MinMaxObserver,
    act_observer=MinMaxObserver,
    fake_quant=FakeQuantize,
    bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
)

ema_fakequant_qconfig = QConfig(
    inp_observer=ExponentialMovingAverageObserver,
    weight_observer=MinMaxObserver,
    act_observer=ExponentialMovingAverageObserver,
    fake_quant=FakeQuantize,
    bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
)
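Building a custom QConfig follows the same shape as the defaults above; observer and fake-quant arguments are classes or partials, never instances, as enforced in __init__. A sketch (the asymmetric variant here is illustrative, not part of the change):

from functools import partial

from megengine.quantization import FakeQuantize, QConfig
from megengine.quantization.observer import MinMaxObserver

asym_qconfig = QConfig(
    inp_observer=partial(MinMaxObserver, symmetric=False),
    weight_observer=MinMaxObserver,
    act_observer=partial(MinMaxObserver, symmetric=False),
    fake_quant=FakeQuantize,
    bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
)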
@@ -0,0 +1,113 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy

# importing ..module.quantized registers the ``to_quantized`` implementations
from ..module import Module, QATModule, Sequential, quantized
from .qconfig import QConfig, ema_fakequant_qconfig


def quantize(module: Module, inplace=True):
    r"""
    Recursively convert ``module`` to ``quantized`` mode through :meth:`~.Module.apply`.

    :param module: root module to convert recursively.
    """
    if not inplace:
        module = deepcopy(module)

    def is_qat_module(obj):
        return isinstance(obj, QATModule)

    # no need to pass prefix and get pure key of parent Module.
    for key, submodule, parent in module._flatten(
        with_key=True, with_parent=True, predicate=is_qat_module
    ):
        if isinstance(parent, Sequential):
            # cannot use setattr; must stay compatible with Sequential's ``__setitem__``
            parent[int(key.split(".")[-1])] = submodule.to_quantized()
        else:
            setattr(parent, key.split(".")[-1], submodule.to_quantized())


def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
    r"""
    Recursively convert ``module`` to ``qat`` mode through :meth:`~.Module.apply`
    and set its qconfig accordingly.

    :param module: root module to convert recursively.
    :param qconfig: an instance of :class:`~.QConfig` to be set as submodules' qconfig.
        Default: :any:`~.qconfig.ema_fakequant_qconfig`
    """

    def fn(mod: Module):
        if isinstance(mod, QATModule):
            mod.set_qat_mode(QATModule.QATMode.QAT)
            mod.set_qconfig(qconfig)

    module.apply(fn)


def disable_fake_quant(module: Module):
    r"""
    Recursively disable fake quantization of every QATModule in ``module``
    through :meth:`~.Module.apply`.

    :param module: root module to disable fake quantization for recursively.
    """

    def fn(mod):
        if isinstance(mod, QATModule):
            # QATModule defines act/weight/bias fake-quant attributes
            mod.act_fake_quant.disable()
            mod.weight_fake_quant.disable()
            mod.bias_fake_quant.disable()

    module.apply(fn)


def disable_observer(module: Module):
    r"""
    Recursively disable the observers of every QATModule in ``module``
    through :meth:`~.Module.apply`.

    :param module: root module to disable observers for recursively.
    """

    def fn(mod):
        if isinstance(mod, QATModule):
            mod.act_observer.disable()

    module.apply(fn)


def enable_fake_quant(module: Module):
    r"""
    Recursively enable fake quantization of every QATModule in ``module``
    through :meth:`~.Module.apply`.

    :param module: root module to enable fake quantization for recursively.
    """

    def fn(mod):
        if isinstance(mod, QATModule):
            # QATModule defines act/weight/bias fake-quant attributes
            mod.act_fake_quant.enable()
            mod.weight_fake_quant.enable()
            mod.bias_fake_quant.enable()

    module.apply(fn)


def enable_observer(module: Module):
    r"""
    Recursively enable the observers of every QATModule in ``module``
    through :meth:`~.Module.apply`.

    :param module: root module to enable observers for recursively.
    """

    def fn(mod):
        if isinstance(mod, QATModule):
            mod.act_observer.enable()

    module.apply(fn)
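Put together, the intended flow is QAT first, then the irreversible conversion. A sketch (not part of the change) with a hypothetical ``net`` built from QATModules:

from megengine.quantization import ema_fakequant_qconfig, quantize, quantize_qat

quantize_qat(net, ema_fakequant_qconfig)  # switch to QAT mode and attach qconfig
# ... fine-tune so observers collect weight/activation statistics ...
quantize(net)  # replace each QATModule with its quantized counterpart in-place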
@@ -0,0 +1,23 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from functools import partial, update_wrapper, wraps


def register_method_to_class(cls):
    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            return func(self, *args, **kwargs)

        if isinstance(func, partial):
            # a bare partial has no __name__; borrow metadata from the
            # wrapped function so setattr below can use func.__name__
            update_wrapper(func, func.func)

        setattr(cls, func.__name__, wrapper)
        return func

    return decorator
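A minimal sketch of what the decorator does, namely attach a function to a class as a late-bound method; ``Foo`` and ``greet`` are illustrative names, not part of the change:

from megengine.quantization.utils import register_method_to_class

class Foo:
    pass

@register_method_to_class(Foo)
def greet(self):
    return "hello from {}".format(type(self).__name__)

print(Foo().greet())  # hello from Foo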
@@ -7,10 +7,12 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest
from helpers import opr_test

import megengine._internal as mgb
import megengine.functional as F
from megengine import Buffer, Parameter, is_cuda_available, jit, tensor
from megengine.test import assertTensorClose
@@ -332,3 +334,108 @@ def test_binary_cross_entropy():
        {"input": [data2, label2], "output": expect2,},
    ]
    opr_test(cases, F.binary_cross_entropy, compare_fn=compare_fn)


@pytest.mark.skip
def test_conv_bias():
    inp_scale = 0.01
    w_scale = 0.02
    outp_scale = 0.1
    inp_dtype = mgb.dtype.qint8(inp_scale)
    w_dtype = mgb.dtype.qint8(w_scale)
    b_dtype = mgb.dtype.qint32(inp_scale * w_scale)
    out_dtype = mgb.dtype.qint8(outp_scale)

    def run(
        N,
        IC,
        OC,
        IH,
        IW,
        KH,
        KW,
        PH,
        PW,
        SH,
        SW,
        has_bias=True,
        nonlinear_mode="IDENTITY",
    ):
        inp_v = np.random.normal(size=(N, IC, IH, IW))
        w_v = np.random.normal(size=(OC, IC, KH, KW))
        b_v = np.random.normal(size=(1, OC, 1, 1))
        inp_scale = mgb.dtype.get_scale(inp_dtype)
        w_scale = mgb.dtype.get_scale(w_dtype)
        b_scale = mgb.dtype.get_scale(b_dtype)

        inpv = mgb.dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
        wv = mgb.dtype.convert_to_qint8(w_v * w_scale, w_dtype)
        bv = mgb.dtype.convert_to_qint32(b_v * b_scale, b_dtype)

        inp_int8 = tensor(inpv, dtype=inp_dtype)
        w_int8 = Parameter(wv, dtype=w_dtype)
        b_int32 = Parameter(bv, dtype=b_dtype)

        inp_fp32 = inp_int8.astype("float32")
        w_fp32 = w_int8.astype("float32")
        b_fp32 = b_int32.astype("float32")

        jit.trace.enabled = True
        b_symbolic = True

        def convert_to_nchw4(var):
            return var.reshape(
                var.shapeof(0), var.shapeof(1) // 4, 4, var.shapeof(2), var.shapeof(3)
            ).dimshuffle(0, 1, 3, 4, 2)

        @jit.trace(symbolic=b_symbolic)
        def run_conv2d(inp, w, b):
            O = F.conv2d(
                inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
            )
            if nonlinear_mode == "RELU":
                return F.relu(O)
            else:
                return O

        @jit.trace(symbolic=b_symbolic)
        def run_conv_bias(inp, w, b, format="NCHW"):
            b = b if has_bias else np.zeros_like(b)
            if format == "NCHW4":
                inp = convert_to_nchw4(inp)
                w = convert_to_nchw4(w)
                b = F.flatten(b)
            return F.conv_bias_activation(
                inp,
                w,
                b,
                stride=(SH, SW),
                padding=(PH, PW),
                dtype=out_dtype,
                nonlinear_mode=nonlinear_mode,
            )

        format = "NCHW4" if is_cuda_available() else "NCHW"

        expected = run_conv2d(inp_fp32, w_fp32, b_fp32)
        expected = expected.astype(out_dtype).astype("float32")
        result = run_conv_bias(inp_int8, w_int8, b_int32, format=format).astype(
            "float32"
        )
        if format == "NCHW4":
            result = result.dimshuffle(0, 1, 4, 2, 3)
        expected = F.flatten(expected)
        result = F.flatten(result)
        assertTensorClose(result.numpy(), expected.numpy())

    if not is_cuda_available():
        run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False)
        run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False)
        run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False)

        run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1)
        run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
        run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)

        run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "RELU")
        run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")