@@ -12,11 +12,13 @@ from .conv import ( | |||
ConvRelu2d, | |||
ConvTranspose2d, | |||
ConvTranspose3d, | |||
ConvTransposeRelu2d, | |||
DeformableConv2d, | |||
LocalConv2d, | |||
RegionRestrictedConv, | |||
) | |||
from .conv_bn import ConvBn2d, ConvBnRelu2d | |||
from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d | |||
from .deformable_psroi_pooling import DeformablePSROIPooling | |||
from .dropout import Dropout | |||
from .elemwise import Elemwise | |||
@@ -773,6 +773,15 @@ class ConvRelu2d(Conv2d): | |||
return relu(self.calc_conv(inp, self.weight, self.bias)) | |||
class ConvTransposeRelu2d(ConvTranspose2d): | |||
r"""A fused :class:`~.Module` including :class:`~.module.ConvTranspose2d` and :func:`~.relu`. | |||
Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeRelu2d` using :func:`~.quantize.quantize_qat`. | |||
""" | |||
def forward(self, inp): | |||
return relu(self.calc_conv_transpose2d(inp, self.weight, self.bias)) | |||
class DeformableConv2d(_ConvNd): | |||
r"""Deformable Convolution. | |||
@@ -0,0 +1,62 @@ | |||
from typing import Tuple, Union | |||
from ..functional import relu | |||
from .batchnorm import BatchNorm2d | |||
from .conv import ConvTranspose2d | |||
from .module import Module | |||
class _ConvTransposeBnActivation2d(Module): | |||
def __init__( | |||
self, | |||
in_channels: int, | |||
out_channels: int, | |||
kernel_size: Union[int, Tuple[int, int]], | |||
stride: Union[int, Tuple[int, int]] = 1, | |||
padding: Union[int, Tuple[int, int]] = 0, | |||
output_padding: Union[int, Tuple[int, int]] = 0, | |||
dilation: Union[int, Tuple[int, int]] = 1, | |||
groups: int = 1, | |||
bias: bool = True, | |||
conv_mode: str = "cross_correlation", | |||
compute_mode: str = "default", | |||
eps=1e-5, | |||
momentum=0.9, | |||
affine=True, | |||
track_running_stats=True, | |||
**kwargs | |||
): | |||
super().__init__(**kwargs) | |||
self.conv_transpose2d = ConvTranspose2d( | |||
in_channels, | |||
out_channels, | |||
kernel_size, | |||
stride, | |||
padding, | |||
output_padding, | |||
dilation, | |||
groups, | |||
bias, | |||
conv_mode, | |||
compute_mode, | |||
**kwargs, | |||
) | |||
self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats) | |||
class ConvTransposeBn2d(_ConvTransposeBnActivation2d): | |||
r"""A fused :class:`~.Module` including :class:`~.module.ConvTranspose2d` and :class:`~.module.BatchNorm2d`. | |||
Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeBn2d` using:func:`~.quantize.quantize_qat`. | |||
""" | |||
def forward(self, inp): | |||
return self.bn(self.conv_transpose2d(inp)) | |||
class ConvTransposeBnRelu2d(_ConvTransposeBnActivation2d): | |||
r"""A fused :class:`~.Module` including :class:`~.module.ConvTranspose2d`, :class:`~.module.BatchNorm2d` and :func:`~.relu`. | |||
Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeBnRelu2d` using :func:`~.quantize.quantize_qat`. | |||
""" | |||
def forward(self, inp): | |||
return relu(self.bn(self.conv_transpose2d(inp))) |
@@ -1,7 +1,8 @@ | |||
from .batch_matmul_activation import BatchMatMulActivation | |||
from .concat import Concat | |||
from .conv import Conv2d, ConvRelu2d, ConvTranspose2d | |||
from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, ConvTransposeRelu2d | |||
from .conv_bn import ConvBn2d, ConvBnRelu2d | |||
from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d | |||
from .elemwise import Elemwise | |||
from .linear import Linear | |||
from .module import QATModule | |||
@@ -59,8 +59,8 @@ class ConvTranspose2d(Float.ConvTranspose2d, QATModule): | |||
def calc_conv_transpose2d_qat(self, inp): | |||
w_qat = self.apply_quant_weight(self.weight) | |||
b_qat = self.apply_quant_bias(self.bias, inp, w_qat) | |||
conv = self.calc_conv_transpose2d(inp, w_qat, b_qat) | |||
return conv | |||
conv_transpose2d = self.calc_conv_transpose2d(inp, w_qat, b_qat) | |||
return conv_transpose2d | |||
@classmethod | |||
def from_float_module(cls, float_module: Float.ConvTranspose2d): | |||
@@ -88,3 +88,12 @@ class ConvTranspose2d(Float.ConvTranspose2d, QATModule): | |||
def forward(self, inp): | |||
return self.apply_quant_activation(self.calc_conv_transpose2d_qat(inp)) | |||
class ConvTransposeRelu2d(ConvTranspose2d): | |||
r"""A :class:`~.QATModule` include :class:`~.module.ConvTranspose2d` and :func:`~.relu` with QAT support. | |||
Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`. | |||
""" | |||
def forward(self, inp): | |||
return self.apply_quant_activation(F.relu(self.calc_conv_transpose2d_qat(inp))) |
@@ -0,0 +1,163 @@ | |||
from ...functional import ones, relu, sqrt, sum, zeros | |||
from .. import conv_transpose_bn as Float | |||
from .module import QATModule | |||
class _ConvTransposeBnActivation2d(Float._ConvTransposeBnActivation2d, QATModule): | |||
def get_batch_mean_var(self, inp): | |||
def _sum_channel(inp, axis=0, keepdims=True): | |||
if isinstance(axis, int): | |||
out = sum(inp, axis=axis, keepdims=keepdims) | |||
elif isinstance(axis, tuple): | |||
for idx, elem in enumerate(axis): | |||
out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims) | |||
return out | |||
sum1 = _sum_channel(inp, (0, 2, 3)) | |||
sum2 = _sum_channel(inp ** 2, (0, 2, 3)) | |||
reduce_size = inp.size / inp.shape[1] | |||
batch_mean = sum1 / reduce_size | |||
batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size | |||
return batch_mean, batch_var | |||
def fold_weight_bias(self, bn_mean, bn_var): | |||
# get fold bn conv_transpose2d param | |||
gamma = self.bn.weight | |||
if gamma is None: | |||
gamma = ones((self.bn.num_features), dtype="float32") | |||
gamma = gamma.reshape(1, -1, 1, 1) | |||
beta = self.bn.bias | |||
if beta is None: | |||
beta = zeros((1, self.bn.num_features, 1, 1), dtype="float32") | |||
if bn_mean is None: | |||
bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32") | |||
if bn_var is None: | |||
bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32") | |||
conv_transpose2d_bias = self.conv_transpose2d.bias | |||
if conv_transpose2d_bias is None: | |||
conv_transpose2d_bias = zeros( | |||
self.conv_transpose2d._infer_bias_shape(), dtype="float32" | |||
) | |||
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) | |||
scale_factor = gamma * bn_istd | |||
if self.conv_transpose2d.groups == 1: | |||
w_fold = self.conv_transpose2d.weight * scale_factor.reshape(-1, 1, 1, 1) | |||
else: | |||
w_fold = self.conv_transpose2d.weight * scale_factor.reshape( | |||
self.conv_transpose2d.groups, -1, 1, 1, 1 | |||
) | |||
w_fold = self.apply_quant_weight(w_fold) | |||
b_fold = beta + gamma * (conv_transpose2d_bias - bn_mean) * bn_istd | |||
return w_fold, b_fold | |||
def update_running_mean_and_running_var( | |||
self, bn_mean, bn_var, num_elements_per_channel | |||
): | |||
# update running mean and running var. no grad, use unbiased bn var | |||
bn_mean = bn_mean.detach() | |||
bn_var = ( | |||
bn_var.detach() * num_elements_per_channel / (num_elements_per_channel - 1) | |||
) | |||
exponential_average_factor = 1 - self.bn.momentum | |||
self.bn.running_mean *= self.bn.momentum | |||
self.bn.running_mean += exponential_average_factor * bn_mean | |||
self.bn.running_var *= self.bn.momentum | |||
self.bn.running_var += exponential_average_factor * bn_var | |||
def calc_conv_transpose2d_bn_qat(self, inp, approx=True): | |||
if self.training and not approx: | |||
conv_transpose2d = self.conv_transpose2d(inp) | |||
bn_mean, bn_var = self.get_batch_mean_var(conv_transpose2d) | |||
num_elements_per_channel = conv_transpose2d.size / conv_transpose2d.shape[1] | |||
self.update_running_mean_and_running_var( | |||
bn_mean, bn_var, num_elements_per_channel | |||
) | |||
else: | |||
bn_mean, bn_var = self.bn.running_mean, self.bn.running_var | |||
# get gamma and beta in BatchNorm | |||
gamma = self.bn.weight | |||
if gamma is None: | |||
gamma = ones((self.bn.num_features), dtype="float32") | |||
gamma = gamma.reshape(1, -1, 1, 1) | |||
beta = self.bn.bias | |||
if beta is None: | |||
beta = zeros((self.bn.num_features), dtype="float32") | |||
beta = beta.reshape(1, -1, 1, 1) | |||
# conv_transpose2d_bias | |||
conv_transpose2d_bias = self.conv_transpose2d.bias | |||
if conv_transpose2d_bias is None: | |||
conv_transpose2d_bias = zeros( | |||
self.conv_transpose2d._infer_bias_shape(), dtype="float32" | |||
) | |||
bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) | |||
scale_factor = gamma * bn_istd | |||
if self.conv_transpose2d.groups == 1: | |||
w_fold = self.conv_transpose2d.weight * scale_factor.reshape(1, -1, 1, 1) | |||
else: | |||
w_fold = self.conv_transpose2d.weight * scale_factor.reshape( | |||
self.conv_transpose2d.groups, 1, -1, 1, 1 | |||
) | |||
b_fold = None | |||
if not (self.training and approx): | |||
b_fold = beta + gamma * (conv_transpose2d_bias - bn_mean) * bn_istd | |||
w_qat = self.apply_quant_weight(w_fold) | |||
b_qat = self.apply_quant_bias(b_fold, inp, w_qat) | |||
conv_transpose2d = self.conv_transpose2d.calc_conv_transpose2d( | |||
inp, w_qat, b_qat | |||
) | |||
if not (self.training and approx): | |||
return conv_transpose2d | |||
# rescale conv_transpose2d to get original conv_transpose2d output | |||
orig_conv_transpose2d = conv_transpose2d / scale_factor.reshape(1, -1, 1, 1) | |||
if self.conv_transpose2d.bias is not None: | |||
orig_conv_transpose2d = orig_conv_transpose2d + self.conv_transpose2d.bias | |||
# calculate batch norm | |||
conv_transpose2d = self.bn(orig_conv_transpose2d) | |||
return conv_transpose2d | |||
@classmethod | |||
def from_float_module(cls, float_module: Float._ConvTransposeBnActivation2d): | |||
qat_module = cls( | |||
float_module.conv_transpose2d.in_channels, | |||
float_module.conv_transpose2d.out_channels, | |||
float_module.conv_transpose2d.kernel_size, | |||
float_module.conv_transpose2d.stride, | |||
float_module.conv_transpose2d.padding, | |||
float_module.conv_transpose2d.output_padding, | |||
float_module.conv_transpose2d.dilation, | |||
float_module.conv_transpose2d.groups, | |||
float_module.conv_transpose2d.bias is not None, | |||
float_module.conv_transpose2d.conv_mode, | |||
float_module.conv_transpose2d.compute_mode, | |||
name=float_module.name, | |||
) | |||
qat_module.conv_transpose2d.weight = float_module.conv_transpose2d.weight | |||
qat_module.conv_transpose2d.bias = float_module.conv_transpose2d.bias | |||
qat_module.bn = float_module.bn | |||
return qat_module | |||
class ConvTransposeBn2d(_ConvTransposeBnActivation2d): | |||
r"""A fused :class:`~.QATModule` including :class:`~.module.ConvTranspose2d` and :class:`~.module.BatchNorm2d` with QAT support. | |||
Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`. | |||
""" | |||
def forward(self, inp): | |||
return self.apply_quant_activation(self.calc_conv_transpose2d_bn_qat(inp)) | |||
class ConvTransposeBnRelu2d(_ConvTransposeBnActivation2d): | |||
r"""A fused :class:`~.QATModule` including :class:`~.module.ConvTranspose2d`, :class:`~.module.BatchNorm2d` and :func:`~.relu` with QAT support. | |||
Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`. | |||
""" | |||
def forward(self, inp): | |||
return self.apply_quant_activation(relu(self.calc_conv_transpose2d_bn_qat(inp))) |
@@ -1,7 +1,8 @@ | |||
from .batch_matmul_activation import BatchMatMulActivation | |||
from .concat import Concat | |||
from .conv import Conv2d, ConvRelu2d, ConvTranspose2d | |||
from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, ConvTransposeRelu2d | |||
from .conv_bn import ConvBn2d, ConvBnRelu2d | |||
from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d | |||
from .elemwise import Elemwise | |||
from .linear import Linear | |||
from .module import QuantizedModule | |||
@@ -178,7 +178,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): | |||
:class:`~.QATModule` instance. | |||
""" | |||
output_dtype = qat_module.get_activation_dtype() | |||
qconv = cls( | |||
qconv_transpose2d = cls( | |||
qat_module.in_channels, | |||
qat_module.out_channels, | |||
qat_module.kernel_size, | |||
@@ -194,15 +194,19 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): | |||
name=qat_module.name, | |||
) | |||
weight = qat_module.weight.astype(qat_module.get_weight_dtype()) | |||
qconv.weight = Parameter(weight.numpy(), name=qat_module.weight.name) | |||
qconv.bias = ( | |||
qconv_transpose2d.weight = Parameter( | |||
weight.numpy(), name=qat_module.weight.name | |||
) | |||
qconv_transpose2d.bias = ( | |||
Parameter(qat_module.bias.numpy(), name=qat_module.bias.name) | |||
if qat_module.bias is not None | |||
else None | |||
) | |||
return qconv | |||
return qconv_transpose2d | |||
def calc_conv_transpose2d_quantized(self, inp, nonlinear_mode): | |||
assert nonlinear_mode == "identity", "nonlinear_mode shoule be 'identity'" | |||
def calc_conv_transpose2d_quantized(self, inp): | |||
if self.bias is not None: | |||
inp_scale = dtype.get_scale(inp.dtype) | |||
w_scale = dtype.get_scale(self.weight.dtype) | |||
@@ -225,4 +229,11 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): | |||
) | |||
def forward(self, inp): | |||
return self.calc_conv_transpose2d_quantized(inp) | |||
return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="identity") | |||
class ConvTransposeRelu2d(ConvTranspose2d): | |||
r"""Quantized version of :class:`~.qat.ConvTransposeRelu2d`.""" | |||
def forward(self, inp): | |||
return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="relu") |
@@ -0,0 +1,53 @@ | |||
from ...tensor import Parameter | |||
from ..qat import conv_transpose_bn as QAT | |||
from .conv import ConvTranspose2d | |||
class _ConvTransposeBnActivation2d(ConvTranspose2d): | |||
r"""Applies a 2D deconvolution over a quantized input tensor, used for inference only. | |||
""" | |||
@classmethod | |||
def from_qat_module(cls, qat_module: QAT._ConvTransposeBnActivation2d): | |||
r""" | |||
Return a :class:`~.QuantizedModule` instance converted from a | |||
:class:`~.QATModule` instance. | |||
""" | |||
output_dtype = qat_module.get_activation_dtype() | |||
qconv_transpose2d = cls( | |||
qat_module.conv_transpose2d.in_channels, | |||
qat_module.conv_transpose2d.out_channels, | |||
qat_module.conv_transpose2d.kernel_size, | |||
qat_module.conv_transpose2d.stride, | |||
qat_module.conv_transpose2d.padding, | |||
qat_module.conv_transpose2d.output_padding, | |||
qat_module.conv_transpose2d.dilation, | |||
qat_module.conv_transpose2d.groups, | |||
dtype=output_dtype, | |||
name=qat_module.name, | |||
) | |||
w_fold, b_fold = qat_module.fold_weight_bias( | |||
qat_module.bn.running_mean, qat_module.bn.running_var | |||
) | |||
weight = w_fold.astype(qat_module.get_weight_dtype()) | |||
qconv_transpose2d.weight = Parameter( | |||
weight.numpy(), name=qat_module.conv_transpose2d.weight.name | |||
) | |||
qconv_transpose2d.bias = Parameter(b_fold.numpy()) | |||
if qat_module.conv_transpose2d.bias is not None: | |||
qconv_transpose2d.bias.name = qat_module.conv_transpose2d.bias.name | |||
return qconv_transpose2d | |||
class ConvTransposeBn2d(_ConvTransposeBnActivation2d): | |||
r"""Quantized version of :class:`~.qat.ConvTransposeBn2d`.""" | |||
def forward(self, inp): | |||
return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="identity") | |||
class ConvTransposeBnRelu2d(_ConvTransposeBnActivation2d): | |||
r"""Quantized version of :class:`~.qat.ConvTransposeBnRelu2d`.""" | |||
def forward(self, inp): | |||
return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="relu") |
@@ -1,48 +1,70 @@ | |||
from copy import deepcopy | |||
from ..functional import ones, sqrt, zeros | |||
from ..module import BatchNorm2d, Conv2d, ConvBn2d, ConvBnRelu2d, ConvRelu2d, ReLU | |||
from ..module import ( | |||
BatchNorm2d, | |||
Conv2d, | |||
ConvBn2d, | |||
ConvBnRelu2d, | |||
ConvRelu2d, | |||
ConvTranspose2d, | |||
ConvTransposeBn2d, | |||
ConvTransposeBnRelu2d, | |||
ConvTransposeRelu2d, | |||
ReLU, | |||
) | |||
from ..tensor import Parameter | |||
_MAP_TO_FUSED_MODULE = { | |||
(Conv2d, BatchNorm2d, ReLU, False): ConvRelu2d, | |||
(Conv2d, BatchNorm2d, ReLU, True): ConvBnRelu2d, | |||
(ConvTranspose2d, BatchNorm2d, ReLU, False): ConvTransposeRelu2d, | |||
(ConvTranspose2d, BatchNorm2d, ReLU, True): ConvTransposeBnRelu2d, | |||
(Conv2d, BatchNorm2d, False): Conv2d, | |||
(Conv2d, BatchNorm2d, True): ConvBn2d, | |||
(Conv2d, ReLU): ConvRelu2d, | |||
(ConvTranspose2d, BatchNorm2d, False): ConvTranspose2d, | |||
(ConvTranspose2d, BatchNorm2d, True): ConvTransposeBn2d, | |||
(ConvTranspose2d, ReLU): ConvTransposeRelu2d, | |||
} | |||
def fold_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5): | |||
# get fold bn conv param | |||
def fold_weight_bias( | |||
weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False | |||
): | |||
shape = (1, -1, 1, 1) | |||
if transpose: | |||
shape = (-1, 1, 1, 1) | |||
kernel_shape = weight.shape | |||
if len(kernel_shape) == 5: | |||
groups, num_features = kernel_shape[0], kernel_shape[1] | |||
else: | |||
groups, num_features = 1, kernel_shape[0] | |||
out_channels = groups * num_features | |||
if gamma is None: | |||
gamma = ones((num_features), dtype="float32") | |||
gamma = ones((out_channels,), dtype="float32") | |||
gamma = gamma.reshape(1, -1, 1, 1) | |||
if beta is None: | |||
beta = zeros((num_features), dtype="float32") | |||
beta = zeros((out_channels,), dtype="float32") | |||
beta = beta.reshape(1, -1, 1, 1) | |||
if bn_mean is None: | |||
bn_mean = zeros((1, num_features, 1, 1), dtype="float32") | |||
bn_mean = zeros((1, out_channels, 1, 1), dtype="float32") | |||
if bn_var is None: | |||
bn_var = ones((1, num_features, 1, 1), dtype="float32") | |||
bn_var = ones((1, out_channels, 1, 1), dtype="float32") | |||
if bias is None: | |||
bias = zeros((1, num_features, 1, 1), dtype="float32") | |||
bias = zeros((1, out_channels, 1, 1), dtype="float32") | |||
bn_istd = 1.0 / sqrt(bn_var + eps) | |||
scale_factor = gamma * bn_istd | |||
if groups == 1: | |||
w_fold = weight * scale_factor.reshape(-1, 1, 1, 1) | |||
w_fold = weight * scale_factor.reshape(*shape) | |||
else: | |||
w_fold = weight * scale_factor.reshape(groups, -1, 1, 1, 1) | |||
w_fold = weight * scale_factor.reshape(groups, *shape) | |||
b_fold = beta + gamma * (bias - bn_mean) * bn_istd | |||
return w_fold, b_fold | |||
@@ -84,3 +106,55 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU): | |||
module.bn = deepcopy(bn) | |||
new_conv.training = conv.training | |||
return module | |||
def fuse_conv_transpose2d_bn_relu_module( | |||
conv_transpose2d: ConvTranspose2d, bn: BatchNorm2d, relu: ReLU | |||
): | |||
module_key = tuple([type(m) for m in [conv_transpose2d, bn, relu] if m]) | |||
if bn: | |||
assert ( | |||
conv_transpose2d.training == bn.training | |||
), "ConvTranspose2d and BN both must be in the same mode (train or eval)." | |||
assert ( | |||
bn.num_features == conv_transpose2d.out_channels | |||
), "Output channel of ConvTranspose2d must match num_features of BatchNorm2d" | |||
module_key = module_key + (conv_transpose2d.training,) | |||
module = _MAP_TO_FUSED_MODULE[module_key]( | |||
in_channels=conv_transpose2d.in_channels, | |||
out_channels=conv_transpose2d.out_channels, | |||
kernel_size=conv_transpose2d.kernel_size, | |||
stride=conv_transpose2d.stride, | |||
padding=conv_transpose2d.padding, | |||
output_padding=conv_transpose2d.output_padding, | |||
dilation=conv_transpose2d.dilation, | |||
groups=conv_transpose2d.groups, | |||
bias=conv_transpose2d.bias is not None, | |||
conv_mode=conv_transpose2d.conv_mode, | |||
compute_mode=conv_transpose2d.compute_mode, | |||
name=conv_transpose2d.name, | |||
) | |||
new_conv_transpose2d = ( | |||
module | |||
if bn is None or not conv_transpose2d.training | |||
else module.conv_transpose2d | |||
) | |||
weight, bias = conv_transpose2d.weight, conv_transpose2d.bias | |||
if not conv_transpose2d.training and bn is not None: | |||
weight, bias = fold_weight_bias( | |||
weight, | |||
bias, | |||
bn.weight, | |||
bn.bias, | |||
bn.running_mean, | |||
bn.running_var, | |||
bn.eps, | |||
transpose=False, | |||
) | |||
new_conv_transpose2d.weight = Parameter(weight) | |||
if bias is not None: | |||
new_conv_transpose2d.bias = Parameter(bias) | |||
if bn is not None and conv_transpose2d.training: | |||
module.bn = deepcopy(bn) | |||
new_conv_transpose2d.training = conv_transpose2d.training | |||
return module |
@@ -5,7 +5,9 @@ import numpy as np | |||
import pytest | |||
import megengine.utils.comp_graph_tools as cgtools | |||
from megengine import jit, tensor | |||
from megengine import jit | |||
from megengine import module as M | |||
from megengine import tensor | |||
from megengine.device import get_device_count | |||
from megengine.functional import expand_dims | |||
from megengine.module import ( | |||
@@ -14,6 +16,8 @@ from megengine.module import ( | |||
ConvBn2d, | |||
ConvRelu2d, | |||
ConvTranspose2d, | |||
ConvTransposeBn2d, | |||
ConvTransposeRelu2d, | |||
DequantStub, | |||
Module, | |||
QuantStub, | |||
@@ -34,6 +38,49 @@ def test_qat_convbn2d(): | |||
module = ConvBn2d( | |||
in_channels, out_channels, kernel_size, groups=groups, bias=bias | |||
) | |||
M.init.normal_(module.bn.weight) | |||
M.init.normal_(module.bn.bias) | |||
module.train() | |||
qat_module = quantize_qat(module, inplace=False) | |||
disable_fake_quant(qat_module) | |||
inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) | |||
normal_outputs = module(inputs) | |||
qat_outputs = qat_module(inputs) | |||
np.testing.assert_allclose( | |||
normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6 | |||
) | |||
np.testing.assert_allclose( | |||
module.bn.running_mean.numpy(), | |||
qat_module.bn.running_mean.numpy(), | |||
atol=5e-8, | |||
) | |||
np.testing.assert_allclose( | |||
module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7, | |||
) | |||
module.eval() | |||
normal_outputs = module(inputs) | |||
qat_module.eval() | |||
qat_outputs = qat_module(inputs) | |||
np.testing.assert_allclose( | |||
normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6 | |||
) | |||
def test_qat_convtransposebn2d(): | |||
in_channels = 32 | |||
out_channels = 64 | |||
kernel_size = 3 | |||
for groups, bias in product([1, 4], [True, False]): | |||
module = ConvTransposeBn2d( | |||
in_channels=in_channels, | |||
out_channels=out_channels, | |||
kernel_size=kernel_size, | |||
output_padding=0, | |||
groups=groups, | |||
bias=bias, | |||
) | |||
M.init.normal_(module.bn.weight) | |||
M.init.normal_(module.bn.bias) | |||
module.train() | |||
qat_module = quantize_qat(module, inplace=False) | |||
disable_fake_quant(qat_module) | |||
@@ -235,10 +282,14 @@ def test_qat_conv_transpose2d(): | |||
self.conv = ConvTranspose2d( | |||
in_channels, out_channels, kernel_size, bias=bias | |||
) | |||
self.conv_transpose2d_relu = ConvTransposeRelu2d( | |||
out_channels, in_channels, kernel_size, bias=bias | |||
) | |||
def forward(self, inp): | |||
out = self.quant(inp) | |||
out = self.conv(out) | |||
out = self.conv_transpose2d_relu(out) | |||
out = self.dequant(out) | |||
return out | |||
@@ -250,10 +301,14 @@ def test_qat_conv_transpose2d(): | |||
disable_fake_quant(qat_net) | |||
normal_outputs = net(inputs) | |||
qat_outputs = qat_net(inputs) | |||
np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy()) | |||
np.testing.assert_allclose( | |||
normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6 | |||
) | |||
net.eval() | |||
normal_outputs = net(inputs) | |||
qat_net.eval() | |||
qat_outputs = qat_net(inputs) | |||
np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy()) | |||
np.testing.assert_allclose( | |||
normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6 | |||
) |