GitOrigin-RevId: 9cd668d97b
@@ -9,8 +9,8 @@
 from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
 from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm
 from .concat import Concat
-from .conv import Conv2d, ConvTranspose2d, LocalConv2d
-from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
+from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d
+from .conv_bn import ConvBn2d, ConvBnRelu2d
 from .dropout import Dropout
 from .elemwise import Elemwise
 from .embedding import Embedding
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 #
 # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -13,8 +12,8 @@ import numpy as np
 import megengine._internal as mgb

+from .. import functional as F
 from ..core import Parameter
-from ..functional import conv2d, conv_transpose2d, local_conv2d
 from ..utils.types import _pair, _pair_nonzero
 from . import init
 from .module import Module
@@ -183,7 +182,7 @@ class Conv2d(_ConvNd):
         return (1, self.out_channels, 1, 1)

     def calc_conv(self, inp, weight, bias):
-        return conv2d(
+        return F.conv2d(
             inp,
             weight,
             bias,
@@ -295,7 +294,7 @@ class ConvTranspose2d(_ConvNd):
         return (1, self.out_channels, 1, 1)

     def forward(self, inp):
-        return conv_transpose2d(
+        return F.conv_transpose2d(
             inp,
             self.weight,
             self.bias,
@@ -324,7 +323,7 @@ class LocalConv2d(Conv2d):
         spatial dimensions. Only zero-padding is supported. Default: 0
     :param groups: number of groups to divide input and output channels into,
         so as to perform a "grouped convolution". When ``groups`` is not 1,
-        ``in_channels`` and ``out_channels`` must be divisible by ``groups``.
+        ``in_channels`` and ``out_channels`` must be divisible by ``groups``.
         The shape of weight is ``(groups, output_height, output_width,
         in_channels // groups, *kernel_size, out_channels // groups)``.
     """
@@ -377,6 +376,17 @@ class LocalConv2d(Conv2d):
         )

     def forward(self, inp):
-        return local_conv2d(
+        return F.local_conv2d(
             inp, self.weight, self.stride, self.padding, self.dilation, self.conv_mode
         )
+
+
+class ConvRelu2d(Conv2d):
+    r"""
+    A fused :class:`~.Module` including Conv2d and relu. Could be replaced
+    with :class:`~.QATModule` version :class:`~.qat.conv.ConvRelu2d` using
+    :func:`~.quantize.quantize_qat`.
+    """
+
+    def forward(self, inp):
+        return F.relu(self.calc_conv(inp, self.weight, self.bias))
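Editor's note: the new fused module reuses Conv2d's constructor and parameters and only overrides `forward`, so it should match an unfused Conv2d followed by relu. A minimal sketch under that assumption (shapes and the weight-sharing assignment are illustrative, not from this diff):

```python
import numpy as np
import megengine.functional as F
from megengine import tensor
from megengine.module import Conv2d, ConvRelu2d

# ConvRelu2d shares Conv2d's constructor; only forward() differs,
# appending F.relu to the convolution output.
fused = ConvRelu2d(16, 32, 3, padding=1)

# An unfused reference sharing the same parameters.
plain = Conv2d(16, 32, 3, padding=1)
plain.weight = fused.weight
plain.bias = fused.bias

x = tensor(np.random.randn(4, 16, 24, 24).astype(np.float32))
# Both paths should agree: fused(x) == relu(plain(x)).
assert np.allclose(fused(x).numpy(), F.relu(plain(x)).numpy())
```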
@@ -50,7 +50,7 @@ class _ConvBnActivation2d(Module):
 class ConvBn2d(_ConvBnActivation2d):
     r"""
     A fused :class:`~.Module` including Conv2d, BatchNorm2d. Could be replaced
-    with :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBn2d` using
+    with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBn2d` using
     :func:`~.quantize.quantize_qat`.
     """
@@ -61,7 +61,7 @@ class ConvBn2d(_ConvBnActivation2d):
 class ConvBnRelu2d(_ConvBnActivation2d):
     r"""
     A fused :class:`~.Module` including Conv2d, BatchNorm2d and relu. Could be replaced
-    with :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBnRelu2d` using
+    with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBnRelu2d` using
     :func:`~.quantize.quantize_qat`.
     """
@@ -6,7 +6,8 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from .concat import Concat
-from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
+from .conv import Conv2d, ConvRelu2d
+from .conv_bn import ConvBn2d, ConvBnRelu2d
 from .elemwise import Elemwise
 from .linear import Linear
 from .module import QATModule
@@ -0,0 +1,57 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from ... import functional as F
+from .. import conv as Float
+from .module import QATModule
+
+
+class Conv2d(Float.Conv2d, QATModule):
+    r"""
+    A :class:`~.QATModule` Conv2d with QAT support.
+    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
+    """
+
+    def calc_conv_qat(self, inp):
+        w_qat = self.apply_quant_weight(self.weight)
+        conv = self.calc_conv(inp, w_qat, self.bias)
+        return conv
+
+    @classmethod
+    def from_float_module(cls, float_module: Float.Conv2d):
+        r"""
+        Return a :class:`~.QATModule` instance converted from
+        a float :class:`~.Module` instance.
+        """
+        qat_module = cls(
+            float_module.in_channels,
+            float_module.out_channels,
+            float_module.kernel_size,
+            float_module.stride,
+            float_module.padding,
+            float_module.dilation,
+            float_module.groups,
+            float_module.bias is not None,
+            float_module.conv_mode.name,
+            float_module.compute_mode.name,
+        )
+        qat_module.weight = float_module.weight
+        qat_module.bias = float_module.bias
+        return qat_module
+
+    def forward(self, inp):
+        return self.apply_quant_activation(self.calc_conv_qat(inp))
+
+
+class ConvRelu2d(Conv2d):
+    r"""
+    A :class:`~.QATModule` including Conv2d and Relu with QAT support.
+    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
+    """
+
+    def forward(self, inp):
+        return self.apply_quant_activation(F.relu(self.calc_conv_qat(inp)))
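A sketch of the float-to-QAT hand-off that `from_float_module` above defines: hyper-parameters are re-passed to the constructor and the parameter tensors are reassigned from the float module, so the QAT module starts from the trained weights. The import paths follow this diff's layout and are assumptions:

```python
import numpy as np
from megengine.module import Conv2d                        # float module
from megengine.module.qat.conv import Conv2d as QATConv2d  # new file above

float_conv = Conv2d(8, 16, 3, bias=True)
qat_conv = QATConv2d.from_float_module(float_conv)

# The conversion reassigns (rather than copies) the float parameters,
# so both modules see the same trained values.
assert np.allclose(qat_conv.weight.numpy(), float_conv.weight.numpy())
assert qat_conv.bias is not None
```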
@@ -7,7 +7,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ...core import ones, zeros
 from ...functional import add_update, relu, sqrt, sum, zero_grad
-from .. import conv_bn_relu as Float
+from .. import conv_bn as Float
 from .module import QATModule
@@ -163,7 +163,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
             float_module.conv.padding,
             float_module.conv.dilation,
             float_module.conv.groups,
-            bool(float_module.conv.bias),
+            float_module.conv.bias is not None,
             float_module.conv.conv_mode.name,
             float_module.conv.compute_mode.name,
         )
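The `bool(...)` to `is not None` change above fixes a genuine pitfall: `conv.bias` is a Parameter, i.e. a tensor, and tensor truthiness is either ambiguous or value-dependent. A plain-numpy stand-in (not MegEngine's Parameter class) showing the failure mode:

```python
import numpy as np

# Hypothetical stand-in for a conv bias Parameter of shape
# (1, out_channels, 1, 1); freshly initialized biases are all zero.
bias = np.zeros((1, 64, 1, 1), dtype=np.float32)

# bool() is the wrong "does a bias exist?" test: multi-element arrays
# refuse to coerce, and a one-element zero array would coerce to False.
try:
    bool(bias)
except ValueError as err:
    print("ambiguous truth value:", err)

# `is not None` asks only whether the attribute is set, which is the
# intended meaning of the constructor's `bias` flag.
print(bias is not None)  # True
```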
@@ -6,7 +6,8 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from .concat import Concat
-from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
+from .conv import Conv2d, ConvRelu2d
+from .conv_bn import ConvBn2d, ConvBnRelu2d
 from .elemwise import Elemwise
 from .linear import Linear
 from .module import QuantizedModule
@@ -7,16 +7,19 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from typing import Tuple, Union

+import numpy as np
+
 import megengine._internal as mgb

 from ... import module as Float
 from ...core import Parameter
 from ...functional import conv_bias_activation
-from ..qat import conv_bn_relu as QAT
+from ..qat import conv as QAT
 from .module import QuantizedModule


-class _ConvBnActivation2d(Float.Conv2d, QuantizedModule):
+class Conv2d(Float.Conv2d, QuantizedModule):
+    r"""quantized version of :class:`~.qat.conv.Conv2d`."""
+
     r"""Applies a 2D convolution over a quantized input tensor, inference only.

     The parameters are the same as :class:`~.Conv2d`
@@ -68,40 +71,38 @@ class _ConvBnActivation2d(Float.Conv2d, QuantizedModule):
         )

     @classmethod
-    def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d):
+    def from_qat_module(cls, qat_module: QAT.Conv2d):
         r"""
         Return a :class:`~.QuantizedModule` instance converted from a
         :class:`~.QATModule` instance.
         """
         output_dtype = qat_module.get_activation_dtype()
         qconv = cls(
-            qat_module.conv.in_channels,
-            qat_module.conv.out_channels,
-            qat_module.conv.kernel_size,
-            qat_module.conv.stride,
-            qat_module.conv.padding,
-            qat_module.conv.dilation,
-            qat_module.conv.groups,
+            qat_module.in_channels,
+            qat_module.out_channels,
+            qat_module.kernel_size,
+            qat_module.stride,
+            qat_module.padding,
+            qat_module.dilation,
+            qat_module.groups,
             dtype=output_dtype,
         )
-        w_fold, b_fold = qat_module.fold_weight_bias(
-            qat_module.bn.running_mean, qat_module.bn.running_var
-        )
-        weight = w_fold.astype(qat_module.get_weight_dtype())
+        weight = qat_module.weight.astype(qat_module.get_weight_dtype())
         qconv.weight = Parameter(weight.numpy())
-        qconv.bias = Parameter(b_fold.numpy())
+        if qat_module.bias is not None:
+            qconv.bias = Parameter(qat_module.bias.numpy())
+        else:
+            qconv.bias = Parameter(
+                np.zeros(qat_module._infer_bias_shape(), dtype=np.float32)
+            )
         return qconv
-
-
-class ConvBn2d(_ConvBnActivation2d):
-    r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBn2d`."""

     def forward(self, inp):
         return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")


-class ConvBnRelu2d(_ConvBnActivation2d):
-    r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBnRelu2d`."""
+class ConvRelu2d(Conv2d):
+    r"""quantized version of :class:`~.qat.conv.ConvRelu2d`."""

     def forward(self, inp):
         return self.calc_conv_quantized(inp, nonlinear_mode="RELU")
@@ -0,0 +1,56 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from ...core import Parameter
+from ..qat import conv_bn as QAT
+from .conv import Conv2d
+
+
+class _ConvBnActivation2d(Conv2d):
+    r"""Applies a 2D convolution over a quantized input tensor, inference only.
+
+    The parameters are the same as :class:`~.Conv2d`
+    """
+
+    @classmethod
+    def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d):
+        r"""
+        Return a :class:`~.QuantizedModule` instance converted from a
+        :class:`~.QATModule` instance.
+        """
+        output_dtype = qat_module.get_activation_dtype()
+        qconv = cls(
+            qat_module.conv.in_channels,
+            qat_module.conv.out_channels,
+            qat_module.conv.kernel_size,
+            qat_module.conv.stride,
+            qat_module.conv.padding,
+            qat_module.conv.dilation,
+            qat_module.conv.groups,
+            dtype=output_dtype,
+        )
+        w_fold, b_fold = qat_module.fold_weight_bias(
+            qat_module.bn.running_mean, qat_module.bn.running_var
+        )
+        weight = w_fold.astype(qat_module.get_weight_dtype())
+        qconv.weight = Parameter(weight.numpy())
+        qconv.bias = Parameter(b_fold.numpy())
+        return qconv
+
+
+class ConvBn2d(_ConvBnActivation2d):
+    r"""quantized version of :class:`~.qat.conv_bn.ConvBn2d`."""
+
+    def forward(self, inp):
+        return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")
+
+
+class ConvBnRelu2d(_ConvBnActivation2d):
+    r"""quantized version of :class:`~.qat.conv_bn.ConvBnRelu2d`."""
+
+    def forward(self, inp):
+        return self.calc_conv_quantized(inp, nonlinear_mode="RELU")
@@ -104,6 +104,10 @@ def quantize_qat(
     for key, submodule, parent in module._flatten(
         with_key=True, with_parent=True, predicate=is_quantable
     ):
+        # only convert top quantable module.
+        if is_quantable(parent):
+            continue
+
         new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule)
         if isinstance(parent, Float.Sequential):
             # cannot use setattr to be compatible with Sequential's ``__setitem__``
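The new guard skips any quantable submodule whose parent is itself quantable, so a fused module such as ConvBn2d is converted as a whole rather than also having its inner Conv2d swapped. A schematic of that traversal with simplified stand-ins (hypothetical helpers, not MegEngine's `_flatten` API):

```python
class Node:
    def __init__(self, name, quantable, children=()):
        self.name, self.quantable, self.children = name, quantable, list(children)

def flatten_with_parent(node):
    # Depth-first walk yielding (submodule, parent) pairs.
    for child in node.children:
        yield child, node
        yield from flatten_with_parent(child)

def is_quantable(m):
    return m.quantable

# ConvBn2d is quantable and owns a quantable Conv2d child.
net = Node("net", False, [Node("conv_bn", True, [Node("conv", True)])])

for sub, parent in flatten_with_parent(net):
    if not is_quantable(sub):
        continue
    # Only convert the top quantable module: skip children whose parent
    # is itself quantable, so "conv" inside "conv_bn" is left alone.
    if is_quantable(parent):
        continue
    print("convert", sub.name)  # prints only: convert conv_bn
```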
@@ -1,37 +0,0 @@
-import copy
-from itertools import product
-
-import numpy as np
-
-from megengine import tensor
-from megengine.module import ConvBn2d
-from megengine.quantization.quantize import disable_fake_quant, quantize_qat
-from megengine.test import assertTensorClose
-
-
-def test_convbn2d():
-    in_channels = 32
-    out_channels = 64
-    kernel_size = 3
-    for groups, bias in product([1, 4], [True, False]):
-        module = ConvBn2d(
-            in_channels, out_channels, kernel_size, groups=groups, bias=bias
-        )
-        module.train()
-        qat_module = quantize_qat(module, inplace=False)
-        disable_fake_quant(qat_module)
-        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
-        normal_outputs = module(inputs)
-        qat_outputs = qat_module(inputs)
-        assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
-        assertTensorClose(
-            module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8
-        )
-        assertTensorClose(
-            module.bn.running_var, qat_module.bn.running_var, max_err=5e-7
-        )
-        module.eval()
-        normal_outputs = module(inputs)
-        qat_module.eval()
-        qat_outputs = qat_module(inputs)
-        assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
@@ -0,0 +1,85 @@
+from itertools import product
+
+import numpy as np
+
+from megengine import tensor
+from megengine.module import (
+    Conv2d,
+    ConvBn2d,
+    ConvRelu2d,
+    DequantStub,
+    Module,
+    QuantStub,
+)
+from megengine.quantization.quantize import disable_fake_quant, quantize_qat
+from megengine.test import assertTensorClose
+
+
+def test_qat_convbn2d():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+    for groups, bias in product([1, 4], [True, False]):
+        module = ConvBn2d(
+            in_channels, out_channels, kernel_size, groups=groups, bias=bias
+        )
+        module.train()
+        qat_module = quantize_qat(module, inplace=False)
+        disable_fake_quant(qat_module)
+        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+        normal_outputs = module(inputs)
+        qat_outputs = qat_module(inputs)
+        assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
+        assertTensorClose(
+            module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8
+        )
+        assertTensorClose(
+            module.bn.running_var, qat_module.bn.running_var, max_err=5e-7
+        )
+        module.eval()
+        normal_outputs = module(inputs)
+        qat_module.eval()
+        qat_outputs = qat_module(inputs)
+        assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
+
+
+def test_qat_conv():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+
+    class TestNet(Module):
+        def __init__(self, groups, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.conv = Conv2d(
+                in_channels, out_channels, kernel_size, groups=groups, bias=bias
+            )
+            self.conv_relu = ConvRelu2d(
+                out_channels, in_channels, kernel_size, groups=groups, bias=bias
+            )
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.conv(out)
+            out = self.conv_relu(out)
+            out = self.dequant(out)
+            return out
+
+    inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+    for groups, bias in product([1, 4], [True, False]):
+        net = TestNet(groups, bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
+        assertTensorClose(normal_outputs, qat_outputs)
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
+        assertTensorClose(normal_outputs, qat_outputs)