GitOrigin-RevId: 4fe1233ec3
tags/v0.4.0
@@ -74,12 +74,14 @@ from .nn import (
    softmax,
    warp_perspective,
)
from .quantized import conv_bias_activation
from .sort import argsort, sort, top_k
from .tensor import (
    add_axis,
    arange,
    broadcast_to,
    concat,
    cond_take,
    dimshuffle,
    gather,
    linspace,
@@ -0,0 +1,84 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# pylint: disable=too-many-lines
from typing import Tuple, Union

from .. import _internal as mgb
from ..core import Tensor, wrap_io_tensor
from ..utils.types import _pair, _pair_nonzero
from .debug_param import get_conv_execution_strategy
@wrap_io_tensor
def conv_bias_activation(
    inp: Tensor,
    weight: Tensor,
    bias: Tensor,
    dtype=None,
    stride: Union[int, Tuple[int, int]] = 1,
    padding: Union[int, Tuple[int, int]] = 0,
    dilation: Union[int, Tuple[int, int]] = 1,
    groups: int = 1,
    nonlinear_mode="IDENTITY",
    conv_mode="CROSS_CORRELATION",
    compute_mode="DEFAULT",
) -> Tensor:
""" convolution bias with activation operation, only for inference. | |||
:param inp: The feature map of the convolution operation | |||
:param weight: The convolution kernel | |||
:param bias: The bias added to the result of convolution | |||
:param stride: Stride of the 2D convolution operation. Default: 1 | |||
:param padding: Size of the paddings added to the input on both sides of its | |||
spatial dimensions. Only zero-padding is supported. Default: 0 | |||
:param dilation: Dilation of the 2D convolution operation. Default: 1 | |||
:param groups: number of groups to divide input and output channels into, | |||
so as to perform a "grouped convolution". When ``groups`` is not 1, | |||
``in_channels`` and ``out_channels`` must be divisible by ``groups``, | |||
and the shape of weight should be ``(groups, out_channel // groups, | |||
in_channels // groups, height, width)``. | |||
:type conv_mode: string or :class:`mgb.opr_param_defs.Convolution.Mode` | |||
:param conv_mode: Supports 'CROSS_CORRELATION' or 'CONVOLUTION'. Default: | |||
'CROSS_CORRELATION'. | |||
:param dtype: Support for np.dtype, Default: | |||
np.int8. | |||
:param scale: scale if use quantization, Default: | |||
0.0. | |||
:param zero_point: scale if use quantization quint8, Default: | |||
0.0. | |||
:type compute_mode: string or | |||
:class:`mgb.opr_param_defs.Convolution.ComputeMode` | |||
:param compute_mode: When set to 'DEFAULT', no special requirements will be | |||
placed on the precision of intermediate results. When set to 'FLOAT32', | |||
Float32 would be used for accumulator and intermediate result, but only | |||
effective when input and output are of Float16 dtype. | |||
""" | |||
    ph, pw = _pair(padding)
    sh, sw = _pair_nonzero(stride)
    dh, dw = _pair_nonzero(dilation)
    sparse_type = "DENSE" if groups == 1 else "GROUP"
    res = mgb.opr.conv_bias_activation(
        inp,
        weight,
        bias,
        compute_mode=compute_mode,
        dtype=dtype,
        strategy=get_conv_execution_strategy(),
        nonlineMode=nonlinear_mode,
        sparse=sparse_type,
        format="NCHW",
        pad_h=ph,
        pad_w=pw,
        stride_h=sh,
        stride_w=sw,
        dilate_h=dh,
        dilate_w=dw,
        mode=conv_mode,
    )
    return res
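
For reference (not part of the diff): a minimal usage sketch of the new `conv_bias_activation`, mirroring the test at the bottom of this patch. Shapes and scales here are arbitrary illustration values.

```python
import numpy as np

import megengine._internal as mgb
import megengine.functional as F
from megengine import Parameter, tensor

# Arbitrary scales, chosen only for illustration.
inp_dtype = mgb.dtype.qint8(0.01)
w_dtype = mgb.dtype.qint8(0.02)
out_dtype = mgb.dtype.qint8(0.1)
b_dtype = mgb.dtype.qint32(0.01 * 0.02)  # bias scale = inp_scale * w_scale

inp = tensor(
    mgb.dtype.convert_to_qint8(np.random.normal(size=(1, 4, 8, 8)), inp_dtype),
    dtype=inp_dtype,
)
weight = Parameter(
    mgb.dtype.convert_to_qint8(np.random.normal(size=(4, 4, 1, 1)), w_dtype),
    dtype=w_dtype,
)
bias = Parameter(
    mgb.dtype.convert_to_qint32(np.random.normal(size=(1, 4, 1, 1)), b_dtype),
    dtype=b_dtype,
)

# quantized conv + bias + ReLU in one fused call, inference only
out = F.conv_bias_activation(
    inp, weight, bias, dtype=out_dtype, stride=1, padding=0, nonlinear_mode="RELU"
)
```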
@@ -359,6 +359,41 @@ def where(mask: Tensor, x: Tensor, y: Tensor) -> Tensor:
    return out


@wrap_io_tensor
def cond_take(mask: Tensor, x: Tensor, val=1) -> Tensor:
    r"""
    Take elements from data if the specific condition is satisfied on the mask.
    This operator has two outputs: the first is the elements taken, and the second
    is the indices corresponding to those elements; both are 1-dimensional.
    High-dimensional input would first be flattened.

    :param mask: condition param; must have the same shape as ``x``
    :param x: input tensor from which to take elements
    :param val: the value the mask is compared against; elements where
        ``mask == val`` are taken

    Examples:

    .. testcode::

        import numpy as np
        from megengine import tensor
        import megengine.functional as F

        mask = tensor(np.array([[1, 0], [0, 1]], dtype=np.int32))
        x = tensor(np.array([[1, np.inf], [np.nan, 4]],
            dtype=np.float32))
        v, index = F.cond_take(mask, x, 1)
        print(v, index)

    Outputs:

    .. testoutput::

        Tensor([1. 4.]) Tensor([0 3], dtype=int32)

    """
    v, index = mgb.opr.cond_take(
        x, mask, mode=mgb.opr_param_defs.CondTake.Mode.EQ, val=val
    )
    return v, index
def shapeof(x: Tensor, axis=None):
    r"""
    The shape of input tensor.

@@ -8,12 +8,16 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
from .batchnorm import BatchNorm1d, BatchNorm2d
from .concat import Concat
from .conv import Conv2d, ConvTranspose2d
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .dropout import Dropout
from .elemwise import Elemwise
from .embedding import Embedding
from .identity import Identity
from .linear import Linear
from .module import Module
from .module import Module, QATModule
from .parampack import ParamPack
from .pooling import AvgPool2d, MaxPool2d
from .quant_dequant import DequantStub, QuantStub
from .sequential import Sequential
@@ -0,0 +1,27 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable

from .. import functional as F
from ..core.tensor import Tensor
from .module import QATModule


class Concat(QATModule):
    r"""
    A :class:`~.QATModule` that wraps functional concat; replace calls to concat
    with this module to support ``qat`` mode and ``quantized`` mode.
    """

    def forward(self, inps: Iterable[Tensor], axis: int = 0):
        return F.concat(inps, axis)

    def forward_qat(self, inps: Iterable[Tensor], axis: int = 0):
        return self.apply_fakequant_with_observer(
            self.forward(inps, axis), self.act_fake_quant, self.act_observer
        )
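
A quick usage sketch (illustrative only): in the default ``DISABLED`` mode, calling the module simply runs functional concat.

```python
import numpy as np
from megengine import tensor
from megengine.module import Concat

cat = Concat()  # behaves like F.concat until switched to qat/quantized mode
x = tensor(np.ones((2, 2), dtype=np.float32))
y = tensor(np.zeros((2, 2), dtype=np.float32))
out = cat([x, y], axis=1)  # shape (2, 4)
```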
@@ -182,11 +182,11 @@ class Conv2d(_ConvNd):
        # Assume format is NCHW
        return (1, self.out_channels, 1, 1)

    def forward(self, inp):
    def calc_conv(self, inp, weight, bias):
        return conv2d(
            inp,
            self.weight,
            self.bias,
            weight,
            bias,
            self.stride,
            self.padding,
            self.dilation,
@@ -195,6 +195,9 @@ class Conv2d(_ConvNd):
            self.compute_mode,
        )

    def forward(self, inp):
        return self.calc_conv(inp, self.weight, self.bias)


class ConvTranspose2d(_ConvNd):
    r"""Applies a 2D transposed convolution over an input tensor.
@@ -0,0 +1,168 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Tuple, Union

from ..core import ones, zeros
from ..functional import flatten, relu, sqrt, sum
from .batchnorm import BatchNorm2d
from .conv import Conv2d
from .module import QATModule


class _ConvBn2d(QATModule):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
        eps=1e-5,
        momentum=0.9,
        affine=True,
        track_running_stats=True,
        freeze_bn=False,
    ):
        super().__init__()
        self.conv = Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            bias,
            conv_mode,
            compute_mode,
        )
        self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats)
        self.freeze_bn = freeze_bn
    def update_bn_stats(self):
        self.freeze_bn = False
        return self

    def freeze_bn_stats(self):
        self.freeze_bn = True
        return self

    def get_bn_gamma_beta(self):
        if self.bn.weight is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        else:
            gamma = self.bn.weight

        if self.bn.bias is None:
            beta = zeros((self.bn.num_features), dtype="float32")
        else:
            beta = self.bn.bias

        return gamma, beta

    def get_batch_mean_var(self, inp):
        def _sum_channel(inp, axis=0, keepdims=True):
            if isinstance(axis, int):
                out = sum(inp, axis=axis, keepdims=keepdims)
            elif isinstance(axis, tuple):
                for idx, elem in enumerate(axis):
                    out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims)
            return out

        sum1 = _sum_channel(inp, (0, 2, 3))
        sum2 = _sum_channel(inp ** 2, (0, 2, 3))
        reduce_size = inp.shapeof().prod() / inp.shapeof(1)
        batch_mean = sum1 / reduce_size
        batch_var = (sum2 - sum1 ** 2 / reduce_size) / (reduce_size - 1)
        return batch_mean, batch_var
    def fold_weight_bias(self, bn_mean, bn_var):
        # get the folded conv parameters from the bn ones:
        #   bn_istd = 1 / bn_std
        #   w_fold = gamma / bn_std * W
        #   b_fold = gamma * (b - bn_mean) / bn_std + beta
        gamma, beta = self.get_bn_gamma_beta()
        b = self.conv.bias
        if b is None:
            b = zeros(self.conv._infer_bias_shape(), dtype="float32")
        if bn_mean is None:
            bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
        if bn_var is None:
            bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32")

        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        if self.conv.groups == 1:
            w_fold = (
                self.conv.weight
                * gamma.reshape(-1, 1, 1, 1)
                * bn_istd.reshape(-1, 1, 1, 1)
            )
        else:
            w_fold = (
                self.conv.weight
                * gamma.reshape(self.conv.groups, -1, 1, 1, 1)
                * bn_istd.reshape(self.conv.groups, -1, 1, 1, 1)
            )
        b_fold = flatten(beta) + (
            flatten(gamma) * (flatten(b) - flatten(bn_mean)) * flatten(bn_istd)
        )
        b_fold = b_fold.reshape(self.conv._infer_bias_shape())

        return w_fold, b_fold

    def calc_conv_bn_qat(self, inp):
        # TODO: use pytorch method as
        conv = self.conv(inp)
        self.bn(conv)  # run BN once to update its running statistics

        if self.training:
            bn_mean, bn_var = self.get_batch_mean_var(conv)
        else:
            bn_mean, bn_var = self.bn.running_mean, self.bn.running_var

        w_fold, b_fold = self.fold_weight_bias(bn_mean, bn_var)
        w_qat = self.apply_fakequant_with_observer(
            w_fold, self.weight_fake_quant, self.weight_observer
        )
        return self.conv.calc_conv(inp, w_qat, b_fold)
class ConvBn2d(_ConvBn2d):
    r"""
    A fused :class:`~.QATModule` including Conv2d and BatchNorm2d, supporting
    ``qat`` mode and ``normal`` mode.
    """

    def forward_qat(self, inp):
        return self.apply_fakequant_with_observer(
            self.calc_conv_bn_qat(inp), self.act_fake_quant, self.act_observer
        )

    def forward(self, inp):
        return self.bn(self.conv(inp))


class ConvBnRelu2d(_ConvBn2d):
    r"""
    A fused :class:`~.QATModule` including Conv2d, BatchNorm2d and relu, supporting
    ``qat`` mode and ``normal`` mode.
    """

    def forward_qat(self, inp):
        return self.apply_fakequant_with_observer(
            relu(self.calc_conv_bn_qat(inp)), self.act_fake_quant, self.act_observer
        )

    def forward(self, inp):
        return relu(self.bn(self.conv(inp)))
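
The folding formulas in ``fold_weight_bias`` can be sanity-checked with a standalone numpy sketch (illustrative only; per-channel gamma/beta and a dense, non-grouped convolution assumed):

```python
import numpy as np

# Verify w_fold/b_fold against the formulas in the comment above:
#   bn_istd = 1 / bn_std
#   w_fold  = gamma / bn_std * W
#   b_fold  = gamma * (b - bn_mean) / bn_std + beta
oc, ic, kh, kw = 8, 4, 3, 3
W = np.random.randn(oc, ic, kh, kw).astype("float32")
b = np.random.randn(oc).astype("float32")
gamma = np.random.randn(oc).astype("float32")
beta = np.random.randn(oc).astype("float32")
bn_mean = np.random.randn(oc).astype("float32")
bn_var = np.random.rand(oc).astype("float32")
eps = 1e-5

bn_istd = 1.0 / np.sqrt(bn_var + eps)
# broadcast the per-output-channel factor over (IC, KH, KW)
w_fold = W * (gamma * bn_istd).reshape(-1, 1, 1, 1)
b_fold = beta + gamma * (b - bn_mean) * bn_istd
```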
@@ -0,0 +1,95 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .. import _internal as mgb
from ..core import Tensor, wrap_io_tensor
from ..core.graph import _use_default_if_none
from .module import QATModule


@wrap_io_tensor
def _elemwise_func(mode, *inputs, **kwargs) -> Tensor:
    if all(isinstance(i, (int, float)) for i in inputs):
        device, comp_graph = _use_default_if_none(None, None)
        ret = mgb.opr.elemwise(
            *inputs, mode=mode, comp_node=device, comp_graph=comp_graph, **kwargs
        )
        return ret.inferred_value[0]
    return mgb.opr.elemwise(*inputs, mode=mode, **kwargs)
class Elemwise(QATModule):
    r"""
    A :class:`~.QATModule` that wraps an elemwise operator; replace the functional
    operator with this module to support ``qat`` mode and ``normal`` mode.

    :param method: the elemwise method, supporting the strings below.
        It performs the normal elemwise operator for float inputs.
* "ADD": a + b | |||
* "FUSE_ADD_RELU": max(x+y, 0) | |||
* "MUL": x * y | |||
* "MIN": min(x, y) | |||
* "MAX": max(x, y) | |||
* "SUB": x - y | |||
* "TRUE_DIV": x / y | |||
* "FUSE_ADD_SIGMOID": sigmoid(x + y) | |||
* "FUSE_ADD_TANH": tanh(x + y) | |||
* "RELU": x > 0 ? x : 0 | |||
* "ABS": x > 0 ? x : -x | |||
* "SIGMOID": sigmoid(x) | |||
* "EXP": exp(x) | |||
* "TANH": tanh(x) | |||
* "FUSE_MUL_ADD3": x * y + z | |||
* "FAST_TANH": fast_tanh(x) | |||
* "NEGATE": -x | |||
* "ACOS": acos(x) | |||
* "ASIN": asin(x) | |||
* "CEIL": ceil(x) | |||
* "COS": cos(x) | |||
* "EXPM1": expm1(x) | |||
* "FLOOR": floor(x) | |||
* "LOG": log(x) | |||
* "LOG1P": log1p(x) | |||
* "SIN": sin(x) | |||
* "ROUND": round(x) | |||
* "ERF": erf(x) | |||
* "ERFINV": erfinv(x) | |||
* "ERFC": erfc(x) | |||
* "ERFCINV": erfcinv(x) | |||
* "ABS_GRAD": abs_grad | |||
* "FLOOR_DIV": floor_div | |||
* "MOD": mod | |||
* "SIGMOID_GRAD": sigmoid_grad | |||
* "SWITCH_GT0": switch_gt0 | |||
* "TANH_GRAD": tanh_grad | |||
* "LT": lt | |||
* "LEQ": leq | |||
* "EQ": eq | |||
* "POW": pow | |||
* "LOG_SUM_EXP": log_sum_exp | |||
* "FAST_TANH_GRAD": fast_tanh_grad | |||
* "ATAN2": atan2 | |||
* "COND_LEQ_MOV": cond_leq_mov | |||
* "H_SWISH": h_swish | |||
* "FUSE_ADD_H_SWISH": h_swish(x+y) | |||
* "H_SWISH_GRAD": h_swish_grad | |||
""" | |||
_elemwise_mode_type = mgb.opr_param_defs.Elemwise.Mode | |||
def __init__(self, method): | |||
super().__init__() | |||
self.method = self._elemwise_mode_type.convert(method) | |||
def forward(self, *inps): | |||
return _elemwise_func(self.method, *inps) | |||
def forward_qat(self, *inps): | |||
return self.apply_fakequant_with_observer( | |||
self.forward(*inps), self.act_fake_quant, self.act_observer, | |||
) |
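
A usage sketch (illustrative only) of the float path in the default ``DISABLED`` mode:

```python
import numpy as np
from megengine import tensor
from megengine.module import Elemwise

add_relu = Elemwise("FUSE_ADD_RELU")  # fused max(x + y, 0)
x = tensor(np.array([-1.0, 2.0], dtype=np.float32))
y = tensor(np.array([0.5, 1.0], dtype=np.float32))
out = add_relu(x, y)  # [0.0, 3.0]
```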
@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -8,6 +7,7 @@
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from enum import Enum
from typing import Any, Callable, Iterable, Optional, Set, Tuple, Union

import numpy as np
@@ -442,3 +442,95 @@ class Module(metaclass=ABCMeta):
            loaded.append(k)

        return set(loaded), set(skipped)


class QATModule(Module):
    r"""
    Base class of quantization related Modules. It adds the extra forward methods
    :meth:`~.QATModule.forward_qat` and :meth:`~.QATModule.forward_quantized` for
    ``qat`` (quantization aware training) mode and ``quantized`` mode respectively.

    Use :meth:`~.QATModule.set_qat_mode` to switch between ``QAT`` and ``DISABLED``
    mode, and use :meth:`~.QATModule.to_quantized` to switch to ``quantized`` mode,
    which is irreversible.

    If you want to recursively switch the mode for all QATModules in a network, use
    the functions in :mod:`~.quantization.quantize`.
    """
    class QATMode(Enum):
        DISABLED = 1
        QAT = 2
        CALIBRATION = 3

    def __init__(self):
        from ..quantization import (
            QConfig,
            FakeQuantize,
            Observer,
        )  # pylint: disable=all

        super().__init__()
        self.quantizing = self.QATMode.DISABLED
        self.scale = None

        self.inp_observer = None  # type: Observer
        self.weight_observer = None  # type: Observer
        self.act_observer = None  # type: Observer

        self.weight_fake_quant = None  # type: FakeQuantize
        self.bias_fake_quant = None  # type: FakeQuantize
        self.act_fake_quant = None  # type: FakeQuantize

    def set_qconfig(self, qconfig: "QConfig"):
        self.inp_observer = qconfig.inp_observer()
        self.weight_observer = qconfig.weight_observer()
        self.act_observer = qconfig.act_observer()

        self.weight_fake_quant = qconfig.fake_quant(self.weight_observer.dtype)
        self.bias_fake_quant = qconfig.bias_fake_quant()
        self.act_fake_quant = qconfig.fake_quant(self.act_observer.dtype)

    def apply_observer(self, target: Tensor, obs: "Observer"):
        return obs(target)

    def apply_fakequant_with_observer(
        self, target: Tensor, fq: "FakeQuantize", obs: "Observer"
    ):
        oup = self.apply_observer(target, obs)
        return fq(oup, obs.scale, obs.zero_point)
    def set_qat_mode(self, mode: QATMode):
        r"""
        Change the ``self.quantizing`` mode; available values are
        ``QATMode.DISABLED``, ``QATMode.QAT`` and ``QATMode.CALIBRATION``.
        """
        if not isinstance(mode, self.QATMode):
            raise TypeError("mode must be a QATMode Enum type")
        self.quantizing = mode

    def to_quantized(self):
        r"""
        Return a new :class:`~.Module` with quantized parameters of ``self``
        according to the scale and zero_point in ``self.xxx_observer``.
        """
        raise NotImplementedError(
            "Use megengine.quantization.quantize to register the method."
        )

    @abstractmethod
    def forward_qat(self, *args, **kwargs):
        r"""
        Forward method for ``qat`` mode.
        """

    def __call__(self, *args, **kwargs):
        if self.quantizing == self.QATMode.QAT:
            return self.forward_qat(*args, **kwargs)
        elif self.quantizing == self.QATMode.CALIBRATION:
            # TODO: implement CALIBRATION mode
            raise NotImplementedError("CALIBRATION mode is not implemented yet")
        else:
            return self.forward(*args, **kwargs)
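
For clarity, a hypothetical minimal subclass showing the contract (``QATReLU`` is not part of MegEngine; the built-in modules such as ``Elemwise`` and ``ConvBn2d`` above follow the same pattern):

```python
import megengine.functional as F
from megengine.module import QATModule


class QATReLU(QATModule):
    # Hypothetical example module, written only to illustrate the contract.
    def forward(self, inp):
        # normal (DISABLED) mode: plain float computation
        return F.relu(inp)

    def forward_qat(self, inp):
        # QAT mode: fake-quantize the activation output, like the built-in modules
        return self.apply_fakequant_with_observer(
            self.forward(inp), self.act_fake_quant, self.act_observer
        )
```

Note that ``set_qconfig`` must be called before switching to QAT mode with ``set_qat_mode(QATModule.QATMode.QAT)``, so that the observers and fake-quant instances exist.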
@@ -0,0 +1,34 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .module import QATModule


class QuantStub(QATModule):
    r"""
    A helper :class:`~.QATModule` doing the quantize operation on input.
    """

    def forward(self, inp):
        return inp

    def forward_qat(self, inp):
        return self.apply_fakequant_with_observer(
            inp, self.act_fake_quant, self.act_observer
        )


class DequantStub(QATModule):
    r"""
    A helper :class:`~.QATModule` doing the de-quantize operation on input.
    """

    def forward(self, inp):
        return inp

    def forward_qat(self, inp):
        return inp
@@ -0,0 +1,11 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .concat import Concat
from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
from .elemwise import Elemwise
from .quant_dequant import DequantStub, QuantStub
@@ -0,0 +1,45 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable

from ... import _internal as mgb
from ... import functional as F
from ... import module as Float
from ...core.tensor import Tensor
from ...quantization.utils import register_method_to_class
from ..module import Module


class Concat(Module):
    r"""
    A :class:`~.Module` to do quantized concat, inference only.
    """

    def __init__(self):
        super().__init__()
        self.scale = 1.0
        self.zero_point = 0.0
        self.output_dtype = mgb.dtype.qint8(self.scale)

    def forward(self, inps: Iterable[Tensor], axis: int = 0):
        if self.training:
            raise ValueError("quantized module only supports inference.")
        new_inps = (x.astype(self.output_dtype) for x in inps)
        return F.concat(new_inps, axis)


@register_method_to_class(Float.Concat)
def to_quantized(float_module):
    r"""
    Replace :class:`~.module.QATModule`'s ``to_quantized`` method;
    implemented here to avoid a circular import.
    """
    qmod = Concat()
    qmod.output_dtype = float_module.act_observer.get_dtype()
    qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
    return qmod
@@ -0,0 +1,114 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from functools import partial
from typing import Tuple, Union

import megengine._internal as mgb

from ... import module as Float
from ...core import Parameter
from ...functional import conv_bias_activation
from ...module import Conv2d
from ...quantization.utils import register_method_to_class


class _ConvBnActivation2d(Conv2d):
    r"""Applies a 2D convolution over a quantized input tensor, inference only.

    The parameters are the same as :class:`~.Conv2d`.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]],
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        conv_mode: str = "CROSS_CORRELATION",
        compute_mode: str = "DEFAULT",
    ):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            dilation,
            groups,
            True,
            conv_mode,
            compute_mode,
        )
        self.scale = 1.0
        self.zero_point = 0.0
        self.output_dtype = mgb.dtype.qint8(self.scale)
        self.weight = self.weight.astype(self.output_dtype)
        self.bias = self.bias.astype(mgb.dtype.qint32(self.scale))
    def calc_conv_quantized(self, inp, nonlinear_mode="IDENTITY"):
        inp_scale = mgb.dtype.get_scale(inp.dtype)
        w_scale = mgb.dtype.get_scale(self.weight.dtype)
        bias_scale = inp_scale * w_scale
        return conv_bias_activation(
            inp,
            self.weight,
            self.bias.astype(mgb.dtype.qint32(bias_scale)),
            self.output_dtype,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
            conv_mode=self.conv_mode,
            compute_mode=self.compute_mode,
            nonlinear_mode=nonlinear_mode,
        )


class ConvBn2d(_ConvBnActivation2d):
    def forward(self, inp):
        if self.training:
            raise ValueError("quantized module only supports inference.")
        return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")


class ConvBnRelu2d(_ConvBnActivation2d):
    def forward(self, inp):
        if self.training:
            raise ValueError("quantized module only supports inference.")
        return self.calc_conv_quantized(inp, nonlinear_mode="RELU")


def to_quantized(quantized_class, float_module):
    qconv = quantized_class(
        float_module.conv.in_channels,
        float_module.conv.out_channels,
        float_module.conv.kernel_size,
        float_module.conv.stride,
        float_module.conv.padding,
        float_module.conv.dilation,
        float_module.conv.groups,
    )
    w_fold, b_fold = float_module.fold_weight_bias(
        float_module.bn.running_mean, float_module.bn.running_var
    )
    weight = w_fold.astype(float_module.weight_observer.get_dtype())
    qconv.output_dtype = float_module.act_observer.get_dtype()
    qconv.weight = Parameter(weight.numpy())
    qconv.bias = Parameter(b_fold.numpy())
    qconv.scale, qconv.zero_point = float_module.act_observer.get_qparams()

    return qconv


# Replace :class:`~.module.QATModule`'s ``to_quantized`` method;
# implemented here to avoid a circular import.
register_method_to_class(Float.ConvBn2d)(partial(to_quantized, ConvBn2d))
register_method_to_class(Float.ConvBnRelu2d)(partial(to_quantized, ConvBnRelu2d))
@@ -0,0 +1,59 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import _internal as mgb
from ... import module as Float
from ...core import Tensor, wrap_io_tensor
from ...core.graph import _use_default_if_none
from ...quantization.utils import register_method_to_class
from ..module import Module


@wrap_io_tensor
def _elemwise_multi_type(mode, *inputs, **kwargs) -> Tensor:
    if all(isinstance(i, (int, float)) for i in inputs):
        device, comp_graph = _use_default_if_none(None, None)
        ret = mgb.opr.elemwise_multi_type(
            *inputs, mode=mode, comp_node=device, comp_graph=comp_graph, **kwargs,
        )
        return ret.inferred_value[0]
    return mgb.opr.elemwise_multi_type(*inputs, mode=mode, **kwargs)


class Elemwise(Module):
    r"""
    Quantized module for an elemwise operator, inference only.

    :param method: the elemwise method; for the supported strings refer to
        :class:`~.module.elemwise.Elemwise`. It performs the quantized operator
        with the specified output quantized dtype.
    """

    _elemwise_multi_type_mode = mgb.opr_param_defs.ElemwiseMultiType.Mode

    def __init__(self, method):
        super().__init__()
        self.method = self._elemwise_multi_type_mode.convert("Q" + method)
        self.scale = 1.0
        self.zero_point = 0.0
        self.output_dtype = mgb.dtype.qint8(self.scale)

    def forward(self, *inps):
        if self.training:
            raise ValueError("quantized module only supports inference.")
        return _elemwise_multi_type(self.method, *inps, dtype=self.output_dtype)


@register_method_to_class(Float.Elemwise)
def to_quantized(float_module):
    r"""
    Replace :class:`~.module.QATModule`'s ``to_quantized`` method;
    implemented here to avoid a circular import.
    """
    qmod = Elemwise(float_module.method.name)
    qmod.output_dtype = float_module.act_observer.get_dtype()
    qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
    return qmod
@@ -0,0 +1,61 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from ... import _internal as mgb
from ... import module as Float
from ...quantization.utils import register_method_to_class
from ..module import Module


class QuantStub(Module):
    r"""
    A helper module that quantizes its input, inference only.
    """

    def __init__(self):
        super().__init__()
        self.scale = 1.0
        self.zero_point = 0.0
        self.output_dtype = mgb.dtype.qint8(self.scale)

    def forward(self, inp):
        if self.training:
            raise ValueError("quantized module only supports inference.")
        return inp.astype(self.output_dtype)


class DequantStub(Module):
    r"""
    A helper module that de-quantizes its input, inference only.
    """

    def forward(self, inp):
        if self.training:
            raise ValueError("quantized module only supports inference.")
        return inp.astype("float32")


@register_method_to_class(Float.QuantStub)
def to_quantized(float_module):
    r"""
    Replace :class:`~.module.QATModule`'s ``to_quantized`` method;
    implemented here to avoid a circular import.
    """
    qmod = QuantStub()
    qmod.output_dtype = float_module.act_observer.get_dtype()
    qmod.scale, qmod.zero_point = float_module.act_observer.get_qparams()
    return qmod


@register_method_to_class(Float.DequantStub)
def to_quantized(float_module):
    r"""
    Replace :class:`~.module.QATModule`'s ``to_quantized`` method;
    implemented here to avoid a circular import.
    """
    qmod = DequantStub()
    return qmod
@@ -68,6 +68,7 @@ class Sequential(Module):
    def __setitem__(self, idx, module):
        key = self.layer_keys[idx]
        self.layer_values[idx] = module
        return setattr(self, key, module)

    def __delitem__(self, idx):
@@ -0,0 +1,11 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .fake_quant import FakeQuantize
from .observer import Observer
from .qconfig import QConfig, ema_fakequant_qconfig, min_max_fakequant_qconfig
from .quantize import quantize, quantize_qat
@@ -0,0 +1,48 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .. import functional as F
from .._internal.dtype import _metadata_dict
from ..module import Module
from .observer import Round


class FakeQuantize(Module):
    r"""
    A module to do quant and dequant according to the observer's scale and zero_point.
    """

    def __init__(self, dtype: str, enable: bool = True):
        super().__init__()
        if dtype not in _metadata_dict.keys():
            raise ValueError(
                "unknown dtype: {}, only support {}".format(
                    dtype, _metadata_dict.keys()
                )
            )
        self.dtype = dtype
        self.qmin = _metadata_dict[dtype].qmin
        self.qmax = _metadata_dict[dtype].qmax
        self.enabled = enable

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False

    def forward(self, inp, scale, zero_point):
        if self.enabled:
            # quant
            oup = Round()(inp / scale) + zero_point
            # clip
            oup = F.minimum(F.maximum(oup, self.qmin), self.qmax)
            # dequant
            oup = (oup - zero_point) * scale
            return oup
        return inp
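
The forward pass is the standard simulated-quantization round trip (quant, clip, dequant). The same math as a standalone numpy sketch, assuming qint8 bounds of -128..127:

```python
import numpy as np

def fake_quant(x, scale, zero_point, qmin=-128, qmax=127):
    # quant -> clip -> dequant, matching FakeQuantize.forward above
    q = np.round(x / scale) + zero_point
    q = np.clip(q, qmin, qmax)
    return (q - zero_point) * scale

x = np.array([-1.0, 0.013, 0.5], dtype="float32")
print(fake_quant(x, scale=0.01, zero_point=0))  # [-1.0 clipped to -1.28..1.27 grid]
```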
@@ -0,0 +1,193 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from abc import abstractmethod

import numpy as np

from .. import functional as F
from .._internal.dtype import _metadata_dict, get_quantized_dtype
from ..core import Buffer, Function, ones, tensor, zeros
from ..module import Module


class Round(Function):
    def forward(self, x):
        return x.round()

    def backward(self, output_grads):
        # straight-through estimator: pass gradients through the round op
        return output_grads


class Observer(Module):
    r"""
    A base class for Observer Modules.

    :param dtype: a string indicating which dtype to collect scale and zero_point for
    """

    def __init__(self, dtype="qint8"):
        super().__init__()
        if dtype not in _metadata_dict.keys():
            raise ValueError(
                "unknown dtype: {}, only support {}".format(
                    dtype, _metadata_dict.keys()
                )
            )
        self.dtype = dtype
        self.qmin = _metadata_dict[dtype].qmin
        self.qmax = _metadata_dict[dtype].qmax
        self.zero_point, self.scale = None, None
        self.enabled = True

    def get_dtype(self):
        scale, zero_point = self.get_qparams()
        numpy_scale = None if scale is None else scale.numpy()[0]
        numpy_zero_point = None if zero_point is None else zero_point.numpy()[0]
        return get_quantized_dtype(self.dtype, numpy_scale, numpy_zero_point)

    def enable(self):
        self.enabled = True

    def disable(self):
        self.enabled = False

    @abstractmethod
    def forward(self, x):
        pass

    @abstractmethod
    def get_qparams(self, **kwargs):
        pass
class IdentityObserver(Observer):
    r"""
    A test Observer that always returns scale 1 and zero_point 0.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.zero_point = zeros((1), dtype="float32")
        self.scale = ones((1), dtype="float32")

    def forward(self, x):
        return x

    def get_qparams(self):
        return self.scale, self.zero_point
class MinMaxObserver(Observer):
    def __init__(self, symmetric=True, eps=0.00001, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.symmetric = symmetric
        if self.symmetric:
            # assert qmin + qmax == -1, 'when reduce_range, qmin + qmax should equal -1'
            self.zero_point = tensor((self.qmin + self.qmax + 1) // 2)

        self.min_val = Buffer(0.0, dtype=np.float32)
        self.max_val = Buffer(0.0, dtype=np.float32)
        self.scale_limit = eps
        # the flag is used by cond_take: the first run takes the fresh value,
        # after which the flag is set to not_flag and the running value is taken
        self.first_flag = Buffer(np.array([1, 0], dtype=np.int32))
        self.not_flag = Buffer(np.array([0, 1], dtype=np.int32))

    def set_min_max(self, tmp_min, tmp_max):
        # FIXME: cond_take will destroy the shape; use reshape to reset it
        tmp_min = tmp_min.reshape(1)
        tmp_max = tmp_max.reshape(1)
        if self.training:
            F.zero_grad(
                F.add_update(self.min_val, tmp_min, alpha=0.0, beta=1.0, bias=0.0)
            )
            F.zero_grad(
                F.add_update(self.max_val, tmp_max, alpha=0.0, beta=1.0, bias=0.0)
            )
            F.zero_grad(
                F.add_update(
                    self.first_flag, self.not_flag, alpha=0.0, beta=1.0, bias=0.0
                )
            )

        # FIXME: add_update is applied after the whole trace procedure in
        # `symbolic=True` mode, so use tmp_min/tmp_max to calculate and save
        # scale/zero_point for further calculation in FakeQuant.
        self.set_scale_zero_point(tmp_min, tmp_max)

    def set_scale_zero_point(self, tmp_min, tmp_max):
        if self.symmetric:
            symmetric_max_vals = F.maximum(-tmp_min, tmp_max)
            # use maximum to avoid the scale being too small at the beginning
            self.scale = F.maximum(
                symmetric_max_vals / ((self.qmax - self.qmin) / 2), self.scale_limit
            )
            # zero_point = self.zero_point
        else:
            # use maximum to avoid the scale being too small at the beginning
            self.scale = F.maximum(
                (tmp_max - tmp_min) / (self.qmax - self.qmin), self.scale_limit
            )
            # calculate zero_point
            self.zero_point = self.qmin - Round()((tmp_min / self.scale))

    def get_qparams(self):
        # scale and zero_point are runtime tensors rather than Buffers,
        # so they need to be re-calculated if min_val and max_val are loaded.
        if self.scale is None:
            self.set_scale_zero_point(self.min_val, self.max_val)

        return self.scale, self.zero_point
    def forward(self, x_orig):
        if self.enabled:
            # stop gradient
            x = F.zero_grad(x_orig)
            # find max and min
            tmp_min, _ = F.cond_take(
                self.first_flag, F.concat([x.min(), F.minimum(self.min_val, x.min())])
            )
            tmp_max, _ = F.cond_take(
                self.first_flag, F.concat([x.max(), F.maximum(self.max_val, x.max())])
            )
            self.set_min_max(tmp_min, tmp_max)
        return x_orig


class ExponentialMovingAverageObserver(MinMaxObserver):
    def __init__(self, momentum=0.9, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.momentum = Buffer(momentum)

    def set_momentum(self, momentum):
        self.momentum.set_value(momentum)

    def forward(self, x_orig):
        if self.enabled:
            # stop gradient
            x = F.zero_grad(x_orig)
            # exponential moving average
            tmp_min, _ = F.cond_take(
                self.first_flag,
                F.concat(
                    [
                        x.min(),
                        self.momentum * self.min_val + (1 - self.momentum) * x.min(),
                    ]
                ),
            )
            tmp_max, _ = F.cond_take(
                self.first_flag,
                F.concat(
                    [
                        x.max(),
                        self.momentum * self.max_val + (1 - self.momentum) * x.max(),
                    ]
                ),
            )
            self.set_min_max(tmp_min, tmp_max)
        return x_orig
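
A standalone numpy sketch of the scale/zero_point math in ``set_scale_zero_point`` (the helper name is made up for illustration; bounds default to qint8):

```python
import numpy as np

def qparams_from_minmax(tmp_min, tmp_max, qmin=-128, qmax=127,
                        symmetric=True, eps=1e-5):
    # mirrors MinMaxObserver.set_scale_zero_point
    if symmetric:
        scale = max(max(-tmp_min, tmp_max) / ((qmax - qmin) / 2), eps)
        zero_point = (qmin + qmax + 1) // 2
    else:
        scale = max((tmp_max - tmp_min) / (qmax - qmin), eps)
        zero_point = qmin - int(np.round(tmp_min / scale))
    return scale, zero_point

print(qparams_from_minmax(-0.8, 1.2))                 # symmetric qint8
print(qparams_from_minmax(0.0, 2.0, 0, 255, False))   # asymmetric quint8-style
```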
@@ -0,0 +1,82 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from functools import partial

from ..module import Module
from .fake_quant import FakeQuantize
from .observer import ExponentialMovingAverageObserver, MinMaxObserver


class QConfig:
    """
    A config class indicating how to do quantization toward a :class:`~.QATModule`'s
    ``activation``, ``weight`` and ``bias``, with a ``fake_quant`` parameter that
    indicates how to do the fake quantization itself.
    See :meth:`~.QATModule.set_qconfig` for detailed usage.

    :param inp_observer: interface to instantiate an :class:`~.Observer` indicating
        how to collect the scale and zero_point of the input.
    :param weight_observer: similar to ``inp_observer`` but toward weight.
    :param act_observer: similar to ``inp_observer`` but toward activation.
    :param fake_quant: interface to instantiate a :class:`~.FakeQuantize` indicating
        how to do the fake_quant calculation. It can be invoked multiple times to get
        a different instance for each target tensor, for better control of enabling
        and disabling.
    :param bias_fake_quant: similar to ``fake_quant``, but its ``dtype`` usually needs
        to be set in advance, since bias's dtype cannot be inferred from an observer.

    Examples:

    .. code-block::

        # Default EMA QConfig for QAT.
        ema_fakequant_qconfig = QConfig(
            inp_observer=ExponentialMovingAverageObserver,
            weight_observer=MinMaxObserver,
            act_observer=ExponentialMovingAverageObserver,
            fake_quant=FakeQuantize,
            bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
        )
    """

    def __init__(
        self, act_observer, weight_observer, inp_observer, fake_quant, bias_fake_quant,
    ):
        if (
            isinstance(act_observer, Module)
            or isinstance(weight_observer, Module)
            or isinstance(inp_observer, Module)
        ):
            raise ValueError(
                "QConfig must not receive observer instances; please pass an observer"
                " class instead, and use partial(MyObserver, x=1) to override"
                " arguments to the constructor if needed."
            )
        self.act_observer = act_observer
        self.weight_observer = weight_observer
        self.inp_observer = inp_observer
        self.fake_quant = fake_quant
        self.bias_fake_quant = bias_fake_quant


# Default QAT QConfigs
min_max_fakequant_qconfig = QConfig(
    inp_observer=MinMaxObserver,
    weight_observer=MinMaxObserver,
    act_observer=MinMaxObserver,
    fake_quant=FakeQuantize,
    bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
)

ema_fakequant_qconfig = QConfig(
    inp_observer=ExponentialMovingAverageObserver,
    weight_observer=MinMaxObserver,
    act_observer=ExponentialMovingAverageObserver,
    fake_quant=FakeQuantize,
    bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
)
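
A sketch of a user-defined QConfig (illustrative only; the momentum value is arbitrary):

```python
from functools import partial

from megengine.quantization import FakeQuantize, QConfig
from megengine.quantization.observer import (
    ExponentialMovingAverageObserver,
    MinMaxObserver,
)

# Custom QConfig: EMA observers with a non-default momentum for input/activation,
# plain min-max observation for weights.
my_qconfig = QConfig(
    inp_observer=partial(ExponentialMovingAverageObserver, momentum=0.99),
    weight_observer=MinMaxObserver,
    act_observer=partial(ExponentialMovingAverageObserver, momentum=0.99),
    fake_quant=FakeQuantize,
    bias_fake_quant=partial(FakeQuantize, dtype="qint32"),
)
```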
@@ -0,0 +1,113 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from copy import deepcopy

from ..module import Module, QATModule, Sequential, quantized
from .qconfig import QConfig, ema_fakequant_qconfig


def quantize(module: Module, inplace=True):
    r"""
    Recursively convert ``module`` to ``quantized`` mode through :meth:`~.Module.apply`.

    :param module: the root module to convert recursively.
    """
    if not inplace:
        module = deepcopy(module)

    def is_qat_module(obj):
        return isinstance(obj, QATModule)

    # no need to pass prefix and get the pure key of the parent Module.
    for key, submodule, parent in module._flatten(
        with_key=True, with_parent=True, predicate=is_qat_module
    ):
        if isinstance(parent, Sequential):
            # cannot use setattr; stay compatible with Sequential's ``__setitem__``
            parent[int(key.split(".")[-1])] = submodule.to_quantized()
        else:
            setattr(parent, key.split(".")[-1], submodule.to_quantized())


def quantize_qat(module: Module, qconfig: QConfig = ema_fakequant_qconfig):
    r"""
    Recursively convert ``module`` to ``qat`` mode through :meth:`~.Module.apply`
    and set its qconfig accordingly.

    :param module: the root module to convert recursively.
    :param qconfig: an instance of :class:`~.QConfig` to be set as the submodules'
        qconfig. Default: :any:`~.qconfig.ema_fakequant_qconfig`.
    """

    def fn(mod: Module):
        if isinstance(mod, QATModule):
            mod.set_qat_mode(QATModule.QATMode.QAT)
            mod.set_qconfig(qconfig)

    module.apply(fn)


def disable_fake_quant(module: Module):
    r"""
    Recursively disable fake quantization in every QATModule of ``module`` through
    :meth:`~.Module.apply`.

    :param module: the root module to disable fake quantization for recursively.
    """

    def fn(mod):
        if isinstance(mod, QATModule):
            mod.act_fake_quant.disable()
            mod.weight_fake_quant.disable()
            mod.bias_fake_quant.disable()

    module.apply(fn)


def disable_observer(module: Module):
    r"""
    Recursively disable the observers in every QATModule of ``module`` through
    :meth:`~.Module.apply`.

    :param module: the root module to disable observers for recursively.
    """

    def fn(mod):
        if isinstance(mod, QATModule):
            mod.act_observer.disable()

    module.apply(fn)


def enable_fake_quant(module: Module):
    r"""
    Recursively enable fake quantization in every QATModule of ``module`` through
    :meth:`~.Module.apply`.

    :param module: the root module to enable fake quantization for recursively.
    """

    def fn(mod):
        if isinstance(mod, QATModule):
            mod.act_fake_quant.enable()
            mod.weight_fake_quant.enable()
            mod.bias_fake_quant.enable()

    module.apply(fn)


def enable_observer(module: Module):
    r"""
    Recursively enable the observers in every QATModule of ``module`` through
    :meth:`~.Module.apply`.

    :param module: the root module to enable observers for recursively.
    """

    def fn(mod):
        if isinstance(mod, QATModule):
            mod.act_observer.enable()

    module.apply(fn)
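
An end-to-end sketch of the intended workflow, assuming the modules are wired as in this diff (the QAT training loop is elided; quantized modules raise if run in training mode):

```python
from megengine.module import ConvBnRelu2d, DequantStub, Module, QuantStub
from megengine.quantization import ema_fakequant_qconfig, quantize, quantize_qat


class Net(Module):
    def __init__(self):
        super().__init__()
        self.quant = QuantStub()
        self.conv = ConvBnRelu2d(3, 8, 3, padding=1)
        self.dequant = DequantStub()

    def forward(self, x):
        return self.dequant(self.conv(self.quant(x)))


net = Net()
quantize_qat(net, ema_fakequant_qconfig)  # QATModules now run forward_qat
# ... QAT fine-tuning loop here, so the observers can collect statistics ...
quantize(net)  # irreversibly swap QATModules for their quantized versions
```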
@@ -0,0 +1,23 @@
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from functools import partial, update_wrapper, wraps


def register_method_to_class(cls):
    def decorator(func):
        @wraps(func)
        def wrapper(self, *args, **kwargs):
            return func(self, *args, **kwargs)

        if isinstance(func, partial):
            # a bare partial has no __name__; copy it from the wrapped function
            update_wrapper(func, func.func)
        # attach the decorated function to ``cls`` as a method
        setattr(cls, func.__name__, wrapper)
        return func

    return decorator
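
A usage sketch of the decorator (illustrative; ``Foo``/``greet`` are made-up names):

```python
class Foo:
    pass


@register_method_to_class(Foo)
def greet(self):
    return "hello from " + type(self).__name__


print(Foo().greet())  # -> "hello from Foo"
```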
@@ -7,10 +7,12 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest
from helpers import opr_test

import megengine._internal as mgb
import megengine.functional as F
from megengine import Buffer, jit, tensor
from megengine import Buffer, Parameter, is_cuda_available, jit, tensor
from megengine.test import assertTensorClose

@@ -332,3 +334,108 @@ def test_binary_cross_entropy():
        {"input": [data2, label2], "output": expect2,},
    ]
    opr_test(cases, F.binary_cross_entropy, compare_fn=compare_fn)
@pytest.mark.skip
def test_conv_bias():
    inp_scale = 0.01
    w_scale = 0.02
    outp_scale = 0.1
    inp_dtype = mgb.dtype.qint8(inp_scale)
    w_dtype = mgb.dtype.qint8(w_scale)
    b_dtype = mgb.dtype.qint32(inp_scale * w_scale)
    out_dtype = mgb.dtype.qint8(outp_scale)

    def run(
        N,
        IC,
        OC,
        IH,
        IW,
        KH,
        KW,
        PH,
        PW,
        SH,
        SW,
        has_bias=True,
        nonlinear_mode="IDENTITY",
    ):
        inp_v = np.random.normal(size=(N, IC, IH, IW))
        w_v = np.random.normal(size=(OC, IC, KH, KW))
        b_v = np.random.normal(size=(1, OC, 1, 1))
        inp_scale = mgb.dtype.get_scale(inp_dtype)
        w_scale = mgb.dtype.get_scale(w_dtype)
        b_scale = mgb.dtype.get_scale(b_dtype)

        inpv = mgb.dtype.convert_to_qint8(inp_v * inp_scale, inp_dtype)
        wv = mgb.dtype.convert_to_qint8(w_v * w_scale, w_dtype)
        bv = mgb.dtype.convert_to_qint32(b_v * b_scale, b_dtype)

        inp_int8 = tensor(inpv, dtype=inp_dtype)
        w_int8 = Parameter(wv, dtype=w_dtype)
        b_int32 = Parameter(bv, dtype=b_dtype)

        inp_fp32 = inp_int8.astype("float32")
        w_fp32 = w_int8.astype("float32")
        b_fp32 = b_int32.astype("float32")

        jit.trace.enabled = True
        b_symbolic = True

        def convert_to_nchw4(var):
            return var.reshape(
                var.shapeof(0), var.shapeof(1) // 4, 4, var.shapeof(2), var.shapeof(3)
            ).dimshuffle(0, 1, 3, 4, 2)

        @jit.trace(symbolic=b_symbolic)
        def run_conv2d(inp, w, b):
            O = F.conv2d(
                inp, w, b if has_bias else None, stride=(SH, SW), padding=(PH, PW),
            )
            if nonlinear_mode == "RELU":
                return F.relu(O)
            else:
                return O

        @jit.trace(symbolic=b_symbolic)
        def run_conv_bias(inp, w, b, format="NCHW"):
            b = b if has_bias else np.zeros_like(b)
            if format == "NCHW4":
                inp = convert_to_nchw4(inp)
                w = convert_to_nchw4(w)
                b = F.flatten(b)
            return F.conv_bias_activation(
                inp,
                w,
                b,
                stride=(SH, SW),
                padding=(PH, PW),
                dtype=out_dtype,
                nonlinear_mode=nonlinear_mode,
            )

        format = "NCHW4" if is_cuda_available() else "NCHW"

        expected = run_conv2d(inp_fp32, w_fp32, b_fp32)
        expected = expected.astype(out_dtype).astype("float32")
        result = run_conv_bias(inp_int8, w_int8, b_int32, format=format).astype(
            "float32"
        )
        if format == "NCHW4":
            result = result.dimshuffle(0, 1, 4, 2, 3)
        expected = F.flatten(expected)
        result = F.flatten(result)
        assertTensorClose(result.numpy(), expected.numpy())

    if not is_cuda_available():
        run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1, False)
        run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1, False)
        run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False)

    run(1, 4, 4, 24, 33, 1, 1, 2, 3, 1, 1)
    run(10, 12, 24, 46, 46, 1, 1, 2, 1, 3, 1)
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2)

    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, False, "RELU")
    run(10, 36, 8, 46, 26, 2, 2, 2, 1, 1, 2, True, "RELU")