GitOrigin-RevId: 9cd668d97b
tags/v0.5.0
@@ -9,8 +9,8 @@
 from .activation import LeakyReLU, PReLU, ReLU, Sigmoid, Softmax
 from .batchnorm import BatchNorm1d, BatchNorm2d, SyncBatchNorm
 from .concat import Concat
-from .conv import Conv2d, ConvTranspose2d, LocalConv2d
-from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
+from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, LocalConv2d
+from .conv_bn import ConvBn2d, ConvBnRelu2d
 from .dropout import Dropout
 from .elemwise import Elemwise
 from .embedding import Embedding
@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 #
 # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
@@ -13,8 +12,8 @@ import numpy as np
 import megengine._internal as mgb
+from .. import functional as F
 from ..core import Parameter
-from ..functional import conv2d, conv_transpose2d, local_conv2d
 from ..utils.types import _pair, _pair_nonzero
 from . import init
 from .module import Module
@@ -183,7 +182,7 @@ class Conv2d(_ConvNd):
         return (1, self.out_channels, 1, 1)

     def calc_conv(self, inp, weight, bias):
-        return conv2d(
+        return F.conv2d(
             inp,
             weight,
             bias,
@@ -295,7 +294,7 @@ class ConvTranspose2d(_ConvNd):
         return (1, self.out_channels, 1, 1)

     def forward(self, inp):
-        return conv_transpose2d(
+        return F.conv_transpose2d(
             inp,
             self.weight,
             self.bias,
@@ -324,7 +323,7 @@ class LocalConv2d(Conv2d):
         spatial dimensions. Only zero-padding is supported. Default: 0
     :param groups: number of groups to divide input and output channels into,
         so as to perform a "grouped convolution". When ``groups`` is not 1,
-        ``in_channels`` and ``out_channels`` must be divisible by ``groups``.
+        ``in_channels`` and ``out_channels`` must be divisible by ``groups``.
         The shape of weight is ``(groups, output_height, output_width,
         in_channels // groups, *kernel_size, out_channels // groups)``.
     """
@@ -377,6 +376,17 @@ class LocalConv2d(Conv2d):
         )

     def forward(self, inp):
-        return local_conv2d(
+        return F.local_conv2d(
             inp, self.weight, self.stride, self.padding, self.dilation, self.conv_mode
         )
+
+
+class ConvRelu2d(Conv2d):
+    r"""
+    A fused :class:`~.Module` including Conv2d and relu. Could be replaced
+    with :class:`~.QATModule` version :class:`~.qat.conv.ConvRelu2d` using
+    :func:`~.quantize.quantize_qat`.
+    """
+
+    def forward(self, inp):
+        return F.relu(self.calc_conv(inp, self.weight, self.bias))
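For reference, a minimal sketch of how the new fused float module is used (shapes are illustrative; assumes the MegEngine 0.5 API shown in this diff):

import numpy as np
from megengine import tensor
from megengine.module import ConvRelu2d

# ConvRelu2d computes conv followed by relu in a single module, so
# quantize_qat can later swap it for the fused QAT/quantized versions.
conv_relu = ConvRelu2d(in_channels=3, out_channels=8, kernel_size=3)
x = tensor(np.random.randn(1, 3, 32, 32).astype(np.float32))
y = conv_relu(x)  # equivalent to F.relu(conv(x)); shape (1, 8, 30, 30)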
@@ -50,7 +50,7 @@ class _ConvBnActivation2d(Module):
 class ConvBn2d(_ConvBnActivation2d):
     r"""
     A fused :class:`~.Module` including Conv2d, BatchNorm2d. Could be replaced
-    with :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBn2d` using
+    with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBn2d` using
     :func:`~.quantize.quantize_qat`.
     """
@@ -61,7 +61,7 @@ class ConvBn2d(_ConvBnActivation2d):
 class ConvBnRelu2d(_ConvBnActivation2d):
     r"""
     A fused :class:`~.Module` including Conv2d, BatchNorm2d and relu. Could be replaced
-    with :class:`~.QATModule` version :class:`~.qat.conv_bn_relu.ConvBnRelu2d` using
+    with :class:`~.QATModule` version :class:`~.qat.conv_bn.ConvBnRelu2d` using
     :func:`~.quantize.quantize_qat`.
     """
@@ -6,7 +6,8 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from .concat import Concat
-from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
+from .conv import Conv2d, ConvRelu2d
+from .conv_bn import ConvBn2d, ConvBnRelu2d
 from .elemwise import Elemwise
 from .linear import Linear
 from .module import QATModule
@@ -0,0 +1,57 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from ... import functional as F
+from .. import conv as Float
+from .module import QATModule
+
+
+class Conv2d(Float.Conv2d, QATModule):
+    r"""
+    A :class:`~.QATModule` Conv2d with QAT support.
+    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
+    """
+
+    def calc_conv_qat(self, inp):
+        w_qat = self.apply_quant_weight(self.weight)
+        conv = self.calc_conv(inp, w_qat, self.bias)
+        return conv
+
+    @classmethod
+    def from_float_module(cls, float_module: Float.Conv2d):
+        r"""
+        Return a :class:`~.QATModule` instance converted from
+        a float :class:`~.Module` instance.
+        """
+        qat_module = cls(
+            float_module.in_channels,
+            float_module.out_channels,
+            float_module.kernel_size,
+            float_module.stride,
+            float_module.padding,
+            float_module.dilation,
+            float_module.groups,
+            float_module.bias is not None,
+            float_module.conv_mode.name,
+            float_module.compute_mode.name,
+        )
+        qat_module.weight = float_module.weight
+        qat_module.bias = float_module.bias
+        return qat_module
+
+    def forward(self, inp):
+        return self.apply_quant_activation(self.calc_conv_qat(inp))
+
+
+class ConvRelu2d(Conv2d):
+    r"""
+    A :class:`~.QATModule` including Conv2d and relu, with QAT support.
+    Could be applied with :class:`~.Observer` and :class:`~.FakeQuantize`.
+    """
+
+    def forward(self, inp):
+        return self.apply_quant_activation(F.relu(self.calc_conv_qat(inp)))
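A sketch of the conversion path these new QAT modules enable (it mirrors the tests below; quantize_qat falls back to a default qconfig when none is given):

import numpy as np
from megengine import tensor
from megengine.module import ConvRelu2d
from megengine.quantization.quantize import quantize_qat

float_mod = ConvRelu2d(3, 8, kernel_size=3, bias=True)

# quantize_qat looks up the float class in _float2qat_dict and calls
# from_float_module, which copies hyper-parameters and reuses the float
# module's weight and bias Parameters, so both start numerically equal.
qat_mod = quantize_qat(float_mod, inplace=False)

x = tensor(np.random.randn(1, 3, 32, 32).astype(np.float32))
out = qat_mod(x)  # calc_conv_qat -> F.relu -> apply_quant_activation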
@@ -7,7 +7,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from ...core import ones, zeros
 from ...functional import add_update, relu, sqrt, sum, zero_grad
-from .. import conv_bn_relu as Float
+from .. import conv_bn as Float
 from .module import QATModule
@@ -163,7 +163,7 @@ class _ConvBnActivation2d(Float._ConvBnActivation2d, QATModule):
             float_module.conv.padding,
             float_module.conv.dilation,
             float_module.conv.groups,
-            bool(float_module.conv.bias),
+            float_module.conv.bias is not None,
             float_module.conv.conv_mode.name,
             float_module.conv.compute_mode.name,
         )
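The switch from ``bool(bias)`` to ``bias is not None`` matters because ``bias`` here is a tensor Parameter, not a plain value: truth-testing a tensor reports (or refuses to report) its contents rather than its presence. A small illustration, with numpy as a stand-in for tensor semantics:

import numpy as np

bias = np.zeros(8, dtype=np.float32)  # a bias that exists but is all zero

try:
    has_bias = bool(bias)  # ambiguous for multi-element arrays: raises
except ValueError as err:
    print("bool(bias) failed:", err)

has_bias = bias is not None  # True: tests presence, not contents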
@@ -6,7 +6,8 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from .concat import Concat
-from .conv_bn_relu import ConvBn2d, ConvBnRelu2d
+from .conv import Conv2d, ConvRelu2d
+from .conv_bn import ConvBn2d, ConvBnRelu2d
 from .elemwise import Elemwise
 from .linear import Linear
 from .module import QuantizedModule
@@ -7,16 +7,19 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from typing import Tuple, Union

+import numpy as np
+
 import megengine._internal as mgb

 from ... import module as Float
 from ...core import Parameter
 from ...functional import conv_bias_activation
-from ..qat import conv_bn_relu as QAT
+from ..qat import conv as QAT
 from .module import QuantizedModule


-class _ConvBnActivation2d(Float.Conv2d, QuantizedModule):
-    r"""Applies a 2D convolution over an quantized input tensor, inference only.
-    The parameter is same with :class: `~.Conv2d`
+class Conv2d(Float.Conv2d, QuantizedModule):
+    r"""quantized version of :class:`~.qat.conv.Conv2d`."""
@@ -68,40 +71,38 @@ class _ConvBnActivation2d(Float.Conv2d, QuantizedModule):
         )

     @classmethod
-    def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d):
+    def from_qat_module(cls, qat_module: QAT.Conv2d):
         r"""
         return a :class:`~.QuantizedModule` instance converted from a
         :class:`~.QATModule` instance.
         """
         output_dtype = qat_module.get_activation_dtype()
         qconv = cls(
-            qat_module.conv.in_channels,
-            qat_module.conv.out_channels,
-            qat_module.conv.kernel_size,
-            qat_module.conv.stride,
-            qat_module.conv.padding,
-            qat_module.conv.dilation,
-            qat_module.conv.groups,
+            qat_module.in_channels,
+            qat_module.out_channels,
+            qat_module.kernel_size,
+            qat_module.stride,
+            qat_module.padding,
+            qat_module.dilation,
+            qat_module.groups,
             dtype=output_dtype,
         )
-        w_fold, b_fold = qat_module.fold_weight_bias(
-            qat_module.bn.running_mean, qat_module.bn.running_var
-        )
-        weight = w_fold.astype(qat_module.get_weight_dtype())
+        weight = qat_module.weight.astype(qat_module.get_weight_dtype())
         qconv.weight = Parameter(weight.numpy())
-        qconv.bias = Parameter(b_fold.numpy())
+        if qat_module.bias is not None:
+            qconv.bias = Parameter(qat_module.bias.numpy())
+        else:
+            qconv.bias = Parameter(
+                np.zeros(qat_module._infer_bias_shape(), dtype=np.float32)
+            )
         return qconv


-class ConvBn2d(_ConvBnActivation2d):
-    r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBn2d`."""
-
-    def forward(self, inp):
-        return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")
-
-
-class ConvBnRelu2d(_ConvBnActivation2d):
-    r"""quantized version of :class:`~.qat.conv_bn_relu.ConvBnRelu2d`."""
+class ConvRelu2d(Conv2d):
+    r"""quantized version of :class:`~.qat.conv.ConvRelu2d`."""

     def forward(self, inp):
         return self.calc_conv_quantized(inp, nonlinear_mode="RELU")
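The zero-filled bias branch exists because the quantized kernel behind calc_conv_quantized (conv_bias_activation) always consumes a bias operand. A sketch of the shape involved, assuming _infer_bias_shape returns (1, out_channels, 1, 1) as in the float Conv2d earlier in this diff:

import numpy as np

out_channels = 8
bias_shape = (1, out_channels, 1, 1)  # Conv2d._infer_bias_shape above

# A QAT module converted without a bias still hands the quantized op a
# well-formed, all-zero float32 bias of the right shape.
zero_bias = np.zeros(bias_shape, dtype=np.float32)
print(zero_bias.shape)  # (1, 8, 1, 1)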
@@ -0,0 +1,56 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from ...core import Parameter
+from ..qat import conv_bn as QAT
+from .conv import Conv2d
+
+
+class _ConvBnActivation2d(Conv2d):
+    r"""Applies a 2D convolution over a quantized input tensor, inference only.
+
+    The parameters are the same as :class:`~.Conv2d`.
+    """
+
+    @classmethod
+    def from_qat_module(cls, qat_module: QAT._ConvBnActivation2d):
+        r"""
+        return a :class:`~.QuantizedModule` instance converted from a
+        :class:`~.QATModule` instance.
+        """
+        output_dtype = qat_module.get_activation_dtype()
+        qconv = cls(
+            qat_module.conv.in_channels,
+            qat_module.conv.out_channels,
+            qat_module.conv.kernel_size,
+            qat_module.conv.stride,
+            qat_module.conv.padding,
+            qat_module.conv.dilation,
+            qat_module.conv.groups,
+            dtype=output_dtype,
+        )
+        w_fold, b_fold = qat_module.fold_weight_bias(
+            qat_module.bn.running_mean, qat_module.bn.running_var
+        )
+        weight = w_fold.astype(qat_module.get_weight_dtype())
+        qconv.weight = Parameter(weight.numpy())
+        qconv.bias = Parameter(b_fold.numpy())
+        return qconv
+
+
+class ConvBn2d(_ConvBnActivation2d):
+    r"""quantized version of :class:`~.qat.conv_bn.ConvBn2d`."""
+
+    def forward(self, inp):
+        return self.calc_conv_quantized(inp, nonlinear_mode="IDENTITY")
+
+
+class ConvBnRelu2d(_ConvBnActivation2d):
+    r"""quantized version of :class:`~.qat.conv_bn.ConvBnRelu2d`."""
+
+    def forward(self, inp):
+        return self.calc_conv_quantized(inp, nonlinear_mode="RELU")
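For context, fold_weight_bias bakes the BatchNorm statistics into the convolution parameters. A sketch of the standard folding algebra (the exact MegEngine implementation may differ in detail, e.g. in epsilon handling or grouped-weight layout):

import numpy as np

def fold_bn_into_conv(w, b, gamma, beta, mean, var, eps=1e-5):
    # BN(conv(x) + b) = gamma * (conv(x) + b - mean) / sqrt(var + eps) + beta
    # is itself a convolution with rescaled weights and a shifted bias.
    scale = gamma / np.sqrt(var + eps)       # per-output-channel scale
    w_fold = w * scale.reshape(-1, 1, 1, 1)  # assumes w: (out, in, kh, kw)
    b_fold = beta + (b - mean) * scale
    return w_fold, b_fold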
@@ -104,6 +104,10 @@ def quantize_qat(
     for key, submodule, parent in module._flatten(
         with_key=True, with_parent=True, predicate=is_quantable
     ):
+        # only convert the top-most quantable module.
+        if is_quantable(parent):
+            continue
+
         new_mod = _float2qat_dict[type(submodule)].from_float_module(submodule)
         if isinstance(parent, Float.Sequential):
             # cannot use setattr to be compatible with Sequential's ``__setitem__``
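The new guard is what makes the fused modules above safe: ConvBn2d is itself quantable and also contains a now-quantable Conv2d child, so without the parent check _flatten would try to convert that child a second time. A sketch of the resulting behavior:

from megengine.module import ConvBn2d
from megengine.quantization.quantize import quantize_qat

# Only the outer ConvBn2d is converted; its inner Conv2d and BatchNorm2d
# children are absorbed by the fused QAT module rather than being
# converted individually.
net = ConvBn2d(3, 8, kernel_size=3)
qat_net = quantize_qat(net, inplace=False)
print(type(qat_net))  # megengine.module.qat.conv_bn.ConvBn2d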
@@ -1,37 +0,0 @@ | |||
import copy | |||
from itertools import product | |||
import numpy as np | |||
from megengine import tensor | |||
from megengine.module import ConvBn2d | |||
from megengine.quantization.quantize import disable_fake_quant, quantize_qat | |||
from megengine.test import assertTensorClose | |||
def test_convbn2d(): | |||
in_channels = 32 | |||
out_channels = 64 | |||
kernel_size = 3 | |||
for groups, bias in product([1, 4], [True, False]): | |||
module = ConvBn2d( | |||
in_channels, out_channels, kernel_size, groups=groups, bias=bias | |||
) | |||
module.train() | |||
qat_module = quantize_qat(module, inplace=False) | |||
disable_fake_quant(qat_module) | |||
inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) | |||
normal_outputs = module(inputs) | |||
qat_outputs = qat_module(inputs) | |||
assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6) | |||
assertTensorClose( | |||
module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8 | |||
) | |||
assertTensorClose( | |||
module.bn.running_var, qat_module.bn.running_var, max_err=5e-7 | |||
) | |||
module.eval() | |||
normal_outputs = module(inputs) | |||
qat_module.eval() | |||
qat_outputs = qat_module(inputs) | |||
assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6) |
@@ -0,0 +1,85 @@
+from itertools import product
+
+import numpy as np
+
+from megengine import tensor
+from megengine.module import (
+    Conv2d,
+    ConvBn2d,
+    ConvRelu2d,
+    DequantStub,
+    Module,
+    QuantStub,
+)
+from megengine.quantization.quantize import disable_fake_quant, quantize_qat
+from megengine.test import assertTensorClose
+
+
+def test_qat_convbn2d():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+    for groups, bias in product([1, 4], [True, False]):
+        module = ConvBn2d(
+            in_channels, out_channels, kernel_size, groups=groups, bias=bias
+        )
+        module.train()
+        qat_module = quantize_qat(module, inplace=False)
+        disable_fake_quant(qat_module)
+        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+        normal_outputs = module(inputs)
+        qat_outputs = qat_module(inputs)
+        assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
+        assertTensorClose(
+            module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8
+        )
+        assertTensorClose(
+            module.bn.running_var, qat_module.bn.running_var, max_err=5e-7
+        )
+        module.eval()
+        normal_outputs = module(inputs)
+        qat_module.eval()
+        qat_outputs = qat_module(inputs)
+        assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
+
+
+def test_qat_conv():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+
+    class TestNet(Module):
+        def __init__(self, groups, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.conv = Conv2d(
+                in_channels, out_channels, kernel_size, groups=groups, bias=bias
+            )
+            self.conv_relu = ConvRelu2d(
+                out_channels, in_channels, kernel_size, groups=groups, bias=bias
+            )
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.conv(out)
+            out = self.conv_relu(out)
+            out = self.dequant(out)
+            return out
+
+    inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+    for groups, bias in product([1, 4], [True, False]):
+        net = TestNet(groups, bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
+        assertTensorClose(normal_outputs, qat_outputs)
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
+        assertTensorClose(normal_outputs, qat_outputs)