From da7f250c483431f5a3d4c32ce306e773095e57e7 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 18 Aug 2022 15:34:16 +0800 Subject: [PATCH] feat(mge/module): deconv fuse bn and relu GitOrigin-RevId: 5619b397a4686edec3f98f02c66cf3e70b197092 --- imperative/python/megengine/module/__init__.py | 2 + imperative/python/megengine/module/conv.py | 9 ++ .../python/megengine/module/conv_transpose_bn.py | 62 ++++++++ imperative/python/megengine/module/qat/__init__.py | 3 +- imperative/python/megengine/module/qat/conv.py | 13 +- .../megengine/module/qat/conv_transpose_bn.py | 163 +++++++++++++++++++++ .../python/megengine/module/quantized/__init__.py | 3 +- .../python/megengine/module/quantized/conv.py | 23 ++- .../module/quantized/conv_transpose_bn.py | 53 +++++++ imperative/python/megengine/utils/bn_fusion.py | 94 ++++++++++-- imperative/python/test/unit/module/test_qat.py | 61 +++++++- 11 files changed, 463 insertions(+), 23 deletions(-) create mode 100644 imperative/python/megengine/module/conv_transpose_bn.py create mode 100644 imperative/python/megengine/module/qat/conv_transpose_bn.py create mode 100644 imperative/python/megengine/module/quantized/conv_transpose_bn.py diff --git a/imperative/python/megengine/module/__init__.py b/imperative/python/megengine/module/__init__.py index 73de9058..fe375839 100644 --- a/imperative/python/megengine/module/__init__.py +++ b/imperative/python/megengine/module/__init__.py @@ -12,11 +12,13 @@ from .conv import ( ConvRelu2d, ConvTranspose2d, ConvTranspose3d, + ConvTransposeRelu2d, DeformableConv2d, LocalConv2d, RegionRestrictedConv, ) from .conv_bn import ConvBn2d, ConvBnRelu2d +from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d from .deformable_psroi_pooling import DeformablePSROIPooling from .dropout import Dropout from .elemwise import Elemwise diff --git a/imperative/python/megengine/module/conv.py b/imperative/python/megengine/module/conv.py index 97a6f202..5170023f 100644 --- a/imperative/python/megengine/module/conv.py +++ b/imperative/python/megengine/module/conv.py @@ -773,6 +773,15 @@ class ConvRelu2d(Conv2d): return relu(self.calc_conv(inp, self.weight, self.bias)) +class ConvTransposeRelu2d(ConvTranspose2d): + r"""A fused :class:`~.Module` including :class:`~.module.ConvTranspose2d` and :func:`~.relu`. + Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeRelu2d` using :func:`~.quantize.quantize_qat`. + """ + + def forward(self, inp): + return relu(self.calc_conv_transpose2d(inp, self.weight, self.bias)) + + class DeformableConv2d(_ConvNd): r"""Deformable Convolution. diff --git a/imperative/python/megengine/module/conv_transpose_bn.py b/imperative/python/megengine/module/conv_transpose_bn.py new file mode 100644 index 00000000..3c433620 --- /dev/null +++ b/imperative/python/megengine/module/conv_transpose_bn.py @@ -0,0 +1,62 @@ +from typing import Tuple, Union + +from ..functional import relu +from .batchnorm import BatchNorm2d +from .conv import ConvTranspose2d +from .module import Module + + +class _ConvTransposeBnActivation2d(Module): + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int, int]], + stride: Union[int, Tuple[int, int]] = 1, + padding: Union[int, Tuple[int, int]] = 0, + output_padding: Union[int, Tuple[int, int]] = 0, + dilation: Union[int, Tuple[int, int]] = 1, + groups: int = 1, + bias: bool = True, + conv_mode: str = "cross_correlation", + compute_mode: str = "default", + eps=1e-5, + momentum=0.9, + affine=True, + track_running_stats=True, + **kwargs + ): + super().__init__(**kwargs) + self.conv_transpose2d = ConvTranspose2d( + in_channels, + out_channels, + kernel_size, + stride, + padding, + output_padding, + dilation, + groups, + bias, + conv_mode, + compute_mode, + **kwargs, + ) + self.bn = BatchNorm2d(out_channels, eps, momentum, affine, track_running_stats) + + +class ConvTransposeBn2d(_ConvTransposeBnActivation2d): + r"""A fused :class:`~.Module` including :class:`~.module.ConvTranspose2d` and :class:`~.module.BatchNorm2d`. + Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeBn2d` using:func:`~.quantize.quantize_qat`. + """ + + def forward(self, inp): + return self.bn(self.conv_transpose2d(inp)) + + +class ConvTransposeBnRelu2d(_ConvTransposeBnActivation2d): + r"""A fused :class:`~.Module` including :class:`~.module.ConvTranspose2d`, :class:`~.module.BatchNorm2d` and :func:`~.relu`. + Could be replaced with :class:`~.QATModule` version :class:`~.qat.ConvTransposeBnRelu2d` using :func:`~.quantize.quantize_qat`. + """ + + def forward(self, inp): + return relu(self.bn(self.conv_transpose2d(inp))) diff --git a/imperative/python/megengine/module/qat/__init__.py b/imperative/python/megengine/module/qat/__init__.py index 027b5861..2a95dabf 100644 --- a/imperative/python/megengine/module/qat/__init__.py +++ b/imperative/python/megengine/module/qat/__init__.py @@ -1,7 +1,8 @@ from .batch_matmul_activation import BatchMatMulActivation from .concat import Concat -from .conv import Conv2d, ConvRelu2d, ConvTranspose2d +from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, ConvTransposeRelu2d from .conv_bn import ConvBn2d, ConvBnRelu2d +from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d from .elemwise import Elemwise from .linear import Linear from .module import QATModule diff --git a/imperative/python/megengine/module/qat/conv.py b/imperative/python/megengine/module/qat/conv.py index e3281106..2ee69ea8 100644 --- a/imperative/python/megengine/module/qat/conv.py +++ b/imperative/python/megengine/module/qat/conv.py @@ -59,8 +59,8 @@ class ConvTranspose2d(Float.ConvTranspose2d, QATModule): def calc_conv_transpose2d_qat(self, inp): w_qat = self.apply_quant_weight(self.weight) b_qat = self.apply_quant_bias(self.bias, inp, w_qat) - conv = self.calc_conv_transpose2d(inp, w_qat, b_qat) - return conv + conv_transpose2d = self.calc_conv_transpose2d(inp, w_qat, b_qat) + return conv_transpose2d @classmethod def from_float_module(cls, float_module: Float.ConvTranspose2d): @@ -88,3 +88,12 @@ class ConvTranspose2d(Float.ConvTranspose2d, QATModule): def forward(self, inp): return self.apply_quant_activation(self.calc_conv_transpose2d_qat(inp)) + + +class ConvTransposeRelu2d(ConvTranspose2d): + r"""A :class:`~.QATModule` include :class:`~.module.ConvTranspose2d` and :func:`~.relu` with QAT support. + Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`. + """ + + def forward(self, inp): + return self.apply_quant_activation(F.relu(self.calc_conv_transpose2d_qat(inp))) diff --git a/imperative/python/megengine/module/qat/conv_transpose_bn.py b/imperative/python/megengine/module/qat/conv_transpose_bn.py new file mode 100644 index 00000000..3ff6b704 --- /dev/null +++ b/imperative/python/megengine/module/qat/conv_transpose_bn.py @@ -0,0 +1,163 @@ +from ...functional import ones, relu, sqrt, sum, zeros +from .. import conv_transpose_bn as Float +from .module import QATModule + + +class _ConvTransposeBnActivation2d(Float._ConvTransposeBnActivation2d, QATModule): + def get_batch_mean_var(self, inp): + def _sum_channel(inp, axis=0, keepdims=True): + if isinstance(axis, int): + out = sum(inp, axis=axis, keepdims=keepdims) + elif isinstance(axis, tuple): + for idx, elem in enumerate(axis): + out = sum(inp if idx == 0 else out, axis=elem, keepdims=keepdims) + return out + + sum1 = _sum_channel(inp, (0, 2, 3)) + sum2 = _sum_channel(inp ** 2, (0, 2, 3)) + reduce_size = inp.size / inp.shape[1] + batch_mean = sum1 / reduce_size + batch_var = (sum2 - sum1 ** 2 / reduce_size) / reduce_size + return batch_mean, batch_var + + def fold_weight_bias(self, bn_mean, bn_var): + # get fold bn conv_transpose2d param + gamma = self.bn.weight + if gamma is None: + gamma = ones((self.bn.num_features), dtype="float32") + gamma = gamma.reshape(1, -1, 1, 1) + beta = self.bn.bias + if beta is None: + beta = zeros((1, self.bn.num_features, 1, 1), dtype="float32") + + if bn_mean is None: + bn_mean = zeros((1, self.bn.num_features, 1, 1), dtype="float32") + if bn_var is None: + bn_var = ones((1, self.bn.num_features, 1, 1), dtype="float32") + + conv_transpose2d_bias = self.conv_transpose2d.bias + if conv_transpose2d_bias is None: + conv_transpose2d_bias = zeros( + self.conv_transpose2d._infer_bias_shape(), dtype="float32" + ) + + bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) + scale_factor = gamma * bn_istd + if self.conv_transpose2d.groups == 1: + w_fold = self.conv_transpose2d.weight * scale_factor.reshape(-1, 1, 1, 1) + else: + w_fold = self.conv_transpose2d.weight * scale_factor.reshape( + self.conv_transpose2d.groups, -1, 1, 1, 1 + ) + + w_fold = self.apply_quant_weight(w_fold) + b_fold = beta + gamma * (conv_transpose2d_bias - bn_mean) * bn_istd + return w_fold, b_fold + + def update_running_mean_and_running_var( + self, bn_mean, bn_var, num_elements_per_channel + ): + # update running mean and running var. no grad, use unbiased bn var + bn_mean = bn_mean.detach() + bn_var = ( + bn_var.detach() * num_elements_per_channel / (num_elements_per_channel - 1) + ) + exponential_average_factor = 1 - self.bn.momentum + self.bn.running_mean *= self.bn.momentum + self.bn.running_mean += exponential_average_factor * bn_mean + self.bn.running_var *= self.bn.momentum + self.bn.running_var += exponential_average_factor * bn_var + + def calc_conv_transpose2d_bn_qat(self, inp, approx=True): + if self.training and not approx: + conv_transpose2d = self.conv_transpose2d(inp) + bn_mean, bn_var = self.get_batch_mean_var(conv_transpose2d) + num_elements_per_channel = conv_transpose2d.size / conv_transpose2d.shape[1] + self.update_running_mean_and_running_var( + bn_mean, bn_var, num_elements_per_channel + ) + else: + bn_mean, bn_var = self.bn.running_mean, self.bn.running_var + + # get gamma and beta in BatchNorm + gamma = self.bn.weight + if gamma is None: + gamma = ones((self.bn.num_features), dtype="float32") + gamma = gamma.reshape(1, -1, 1, 1) + beta = self.bn.bias + if beta is None: + beta = zeros((self.bn.num_features), dtype="float32") + beta = beta.reshape(1, -1, 1, 1) + # conv_transpose2d_bias + conv_transpose2d_bias = self.conv_transpose2d.bias + if conv_transpose2d_bias is None: + conv_transpose2d_bias = zeros( + self.conv_transpose2d._infer_bias_shape(), dtype="float32" + ) + + bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) + scale_factor = gamma * bn_istd + if self.conv_transpose2d.groups == 1: + w_fold = self.conv_transpose2d.weight * scale_factor.reshape(1, -1, 1, 1) + else: + w_fold = self.conv_transpose2d.weight * scale_factor.reshape( + self.conv_transpose2d.groups, 1, -1, 1, 1 + ) + b_fold = None + if not (self.training and approx): + b_fold = beta + gamma * (conv_transpose2d_bias - bn_mean) * bn_istd + + w_qat = self.apply_quant_weight(w_fold) + b_qat = self.apply_quant_bias(b_fold, inp, w_qat) + conv_transpose2d = self.conv_transpose2d.calc_conv_transpose2d( + inp, w_qat, b_qat + ) + if not (self.training and approx): + return conv_transpose2d + + # rescale conv_transpose2d to get original conv_transpose2d output + orig_conv_transpose2d = conv_transpose2d / scale_factor.reshape(1, -1, 1, 1) + if self.conv_transpose2d.bias is not None: + orig_conv_transpose2d = orig_conv_transpose2d + self.conv_transpose2d.bias + # calculate batch norm + conv_transpose2d = self.bn(orig_conv_transpose2d) + return conv_transpose2d + + @classmethod + def from_float_module(cls, float_module: Float._ConvTransposeBnActivation2d): + qat_module = cls( + float_module.conv_transpose2d.in_channels, + float_module.conv_transpose2d.out_channels, + float_module.conv_transpose2d.kernel_size, + float_module.conv_transpose2d.stride, + float_module.conv_transpose2d.padding, + float_module.conv_transpose2d.output_padding, + float_module.conv_transpose2d.dilation, + float_module.conv_transpose2d.groups, + float_module.conv_transpose2d.bias is not None, + float_module.conv_transpose2d.conv_mode, + float_module.conv_transpose2d.compute_mode, + name=float_module.name, + ) + qat_module.conv_transpose2d.weight = float_module.conv_transpose2d.weight + qat_module.conv_transpose2d.bias = float_module.conv_transpose2d.bias + qat_module.bn = float_module.bn + return qat_module + + +class ConvTransposeBn2d(_ConvTransposeBnActivation2d): + r"""A fused :class:`~.QATModule` including :class:`~.module.ConvTranspose2d` and :class:`~.module.BatchNorm2d` with QAT support. + Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`. + """ + + def forward(self, inp): + return self.apply_quant_activation(self.calc_conv_transpose2d_bn_qat(inp)) + + +class ConvTransposeBnRelu2d(_ConvTransposeBnActivation2d): + r"""A fused :class:`~.QATModule` including :class:`~.module.ConvTranspose2d`, :class:`~.module.BatchNorm2d` and :func:`~.relu` with QAT support. + Could be applied with :class:`~.Observer` and :class:`~.quantization.fake_quant.FakeQuantize`. + """ + + def forward(self, inp): + return self.apply_quant_activation(relu(self.calc_conv_transpose2d_bn_qat(inp))) diff --git a/imperative/python/megengine/module/quantized/__init__.py b/imperative/python/megengine/module/quantized/__init__.py index 930deabe..11ff807e 100644 --- a/imperative/python/megengine/module/quantized/__init__.py +++ b/imperative/python/megengine/module/quantized/__init__.py @@ -1,7 +1,8 @@ from .batch_matmul_activation import BatchMatMulActivation from .concat import Concat -from .conv import Conv2d, ConvRelu2d, ConvTranspose2d +from .conv import Conv2d, ConvRelu2d, ConvTranspose2d, ConvTransposeRelu2d from .conv_bn import ConvBn2d, ConvBnRelu2d +from .conv_transpose_bn import ConvTransposeBn2d, ConvTransposeBnRelu2d from .elemwise import Elemwise from .linear import Linear from .module import QuantizedModule diff --git a/imperative/python/megengine/module/quantized/conv.py b/imperative/python/megengine/module/quantized/conv.py index 0fc8142e..3915c9dd 100644 --- a/imperative/python/megengine/module/quantized/conv.py +++ b/imperative/python/megengine/module/quantized/conv.py @@ -178,7 +178,7 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): :class:`~.QATModule` instance. """ output_dtype = qat_module.get_activation_dtype() - qconv = cls( + qconv_transpose2d = cls( qat_module.in_channels, qat_module.out_channels, qat_module.kernel_size, @@ -194,15 +194,19 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): name=qat_module.name, ) weight = qat_module.weight.astype(qat_module.get_weight_dtype()) - qconv.weight = Parameter(weight.numpy(), name=qat_module.weight.name) - qconv.bias = ( + qconv_transpose2d.weight = Parameter( + weight.numpy(), name=qat_module.weight.name + ) + qconv_transpose2d.bias = ( Parameter(qat_module.bias.numpy(), name=qat_module.bias.name) if qat_module.bias is not None else None ) - return qconv + return qconv_transpose2d + + def calc_conv_transpose2d_quantized(self, inp, nonlinear_mode): + assert nonlinear_mode == "identity", "nonlinear_mode shoule be 'identity'" - def calc_conv_transpose2d_quantized(self, inp): if self.bias is not None: inp_scale = dtype.get_scale(inp.dtype) w_scale = dtype.get_scale(self.weight.dtype) @@ -225,4 +229,11 @@ class ConvTranspose2d(Float.ConvTranspose2d, QuantizedModule): ) def forward(self, inp): - return self.calc_conv_transpose2d_quantized(inp) + return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="identity") + + +class ConvTransposeRelu2d(ConvTranspose2d): + r"""Quantized version of :class:`~.qat.ConvTransposeRelu2d`.""" + + def forward(self, inp): + return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="relu") diff --git a/imperative/python/megengine/module/quantized/conv_transpose_bn.py b/imperative/python/megengine/module/quantized/conv_transpose_bn.py new file mode 100644 index 00000000..0c73b2d2 --- /dev/null +++ b/imperative/python/megengine/module/quantized/conv_transpose_bn.py @@ -0,0 +1,53 @@ +from ...tensor import Parameter +from ..qat import conv_transpose_bn as QAT +from .conv import ConvTranspose2d + + +class _ConvTransposeBnActivation2d(ConvTranspose2d): + r"""Applies a 2D deconvolution over a quantized input tensor, used for inference only. + """ + + @classmethod + def from_qat_module(cls, qat_module: QAT._ConvTransposeBnActivation2d): + r""" + Return a :class:`~.QuantizedModule` instance converted from a + :class:`~.QATModule` instance. + """ + output_dtype = qat_module.get_activation_dtype() + qconv_transpose2d = cls( + qat_module.conv_transpose2d.in_channels, + qat_module.conv_transpose2d.out_channels, + qat_module.conv_transpose2d.kernel_size, + qat_module.conv_transpose2d.stride, + qat_module.conv_transpose2d.padding, + qat_module.conv_transpose2d.output_padding, + qat_module.conv_transpose2d.dilation, + qat_module.conv_transpose2d.groups, + dtype=output_dtype, + name=qat_module.name, + ) + w_fold, b_fold = qat_module.fold_weight_bias( + qat_module.bn.running_mean, qat_module.bn.running_var + ) + weight = w_fold.astype(qat_module.get_weight_dtype()) + qconv_transpose2d.weight = Parameter( + weight.numpy(), name=qat_module.conv_transpose2d.weight.name + ) + qconv_transpose2d.bias = Parameter(b_fold.numpy()) + if qat_module.conv_transpose2d.bias is not None: + qconv_transpose2d.bias.name = qat_module.conv_transpose2d.bias.name + return qconv_transpose2d + + +class ConvTransposeBn2d(_ConvTransposeBnActivation2d): + r"""Quantized version of :class:`~.qat.ConvTransposeBn2d`.""" + + def forward(self, inp): + return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="identity") + + +class ConvTransposeBnRelu2d(_ConvTransposeBnActivation2d): + r"""Quantized version of :class:`~.qat.ConvTransposeBnRelu2d`.""" + + def forward(self, inp): + return self.calc_conv_transpose2d_quantized(inp, nonlinear_mode="relu") diff --git a/imperative/python/megengine/utils/bn_fusion.py b/imperative/python/megengine/utils/bn_fusion.py index 41f08055..b8b99457 100644 --- a/imperative/python/megengine/utils/bn_fusion.py +++ b/imperative/python/megengine/utils/bn_fusion.py @@ -1,48 +1,70 @@ from copy import deepcopy from ..functional import ones, sqrt, zeros -from ..module import BatchNorm2d, Conv2d, ConvBn2d, ConvBnRelu2d, ConvRelu2d, ReLU +from ..module import ( + BatchNorm2d, + Conv2d, + ConvBn2d, + ConvBnRelu2d, + ConvRelu2d, + ConvTranspose2d, + ConvTransposeBn2d, + ConvTransposeBnRelu2d, + ConvTransposeRelu2d, + ReLU, +) from ..tensor import Parameter _MAP_TO_FUSED_MODULE = { (Conv2d, BatchNorm2d, ReLU, False): ConvRelu2d, (Conv2d, BatchNorm2d, ReLU, True): ConvBnRelu2d, + (ConvTranspose2d, BatchNorm2d, ReLU, False): ConvTransposeRelu2d, + (ConvTranspose2d, BatchNorm2d, ReLU, True): ConvTransposeBnRelu2d, (Conv2d, BatchNorm2d, False): Conv2d, (Conv2d, BatchNorm2d, True): ConvBn2d, (Conv2d, ReLU): ConvRelu2d, + (ConvTranspose2d, BatchNorm2d, False): ConvTranspose2d, + (ConvTranspose2d, BatchNorm2d, True): ConvTransposeBn2d, + (ConvTranspose2d, ReLU): ConvTransposeRelu2d, } -def fold_weight_bias(weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5): - # get fold bn conv param +def fold_weight_bias( + weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False +): + shape = (1, -1, 1, 1) + if transpose: + shape = (-1, 1, 1, 1) + kernel_shape = weight.shape if len(kernel_shape) == 5: groups, num_features = kernel_shape[0], kernel_shape[1] else: groups, num_features = 1, kernel_shape[0] + out_channels = groups * num_features if gamma is None: - gamma = ones((num_features), dtype="float32") + gamma = ones((out_channels,), dtype="float32") gamma = gamma.reshape(1, -1, 1, 1) if beta is None: - beta = zeros((num_features), dtype="float32") + beta = zeros((out_channels,), dtype="float32") beta = beta.reshape(1, -1, 1, 1) if bn_mean is None: - bn_mean = zeros((1, num_features, 1, 1), dtype="float32") + bn_mean = zeros((1, out_channels, 1, 1), dtype="float32") if bn_var is None: - bn_var = ones((1, num_features, 1, 1), dtype="float32") + bn_var = ones((1, out_channels, 1, 1), dtype="float32") if bias is None: - bias = zeros((1, num_features, 1, 1), dtype="float32") + bias = zeros((1, out_channels, 1, 1), dtype="float32") bn_istd = 1.0 / sqrt(bn_var + eps) scale_factor = gamma * bn_istd if groups == 1: - w_fold = weight * scale_factor.reshape(-1, 1, 1, 1) + w_fold = weight * scale_factor.reshape(*shape) else: - w_fold = weight * scale_factor.reshape(groups, -1, 1, 1, 1) + w_fold = weight * scale_factor.reshape(groups, *shape) b_fold = beta + gamma * (bias - bn_mean) * bn_istd return w_fold, b_fold @@ -84,3 +106,55 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU): module.bn = deepcopy(bn) new_conv.training = conv.training return module + + +def fuse_conv_transpose2d_bn_relu_module( + conv_transpose2d: ConvTranspose2d, bn: BatchNorm2d, relu: ReLU +): + module_key = tuple([type(m) for m in [conv_transpose2d, bn, relu] if m]) + if bn: + assert ( + conv_transpose2d.training == bn.training + ), "ConvTranspose2d and BN both must be in the same mode (train or eval)." + assert ( + bn.num_features == conv_transpose2d.out_channels + ), "Output channel of ConvTranspose2d must match num_features of BatchNorm2d" + module_key = module_key + (conv_transpose2d.training,) + module = _MAP_TO_FUSED_MODULE[module_key]( + in_channels=conv_transpose2d.in_channels, + out_channels=conv_transpose2d.out_channels, + kernel_size=conv_transpose2d.kernel_size, + stride=conv_transpose2d.stride, + padding=conv_transpose2d.padding, + output_padding=conv_transpose2d.output_padding, + dilation=conv_transpose2d.dilation, + groups=conv_transpose2d.groups, + bias=conv_transpose2d.bias is not None, + conv_mode=conv_transpose2d.conv_mode, + compute_mode=conv_transpose2d.compute_mode, + name=conv_transpose2d.name, + ) + new_conv_transpose2d = ( + module + if bn is None or not conv_transpose2d.training + else module.conv_transpose2d + ) + weight, bias = conv_transpose2d.weight, conv_transpose2d.bias + if not conv_transpose2d.training and bn is not None: + weight, bias = fold_weight_bias( + weight, + bias, + bn.weight, + bn.bias, + bn.running_mean, + bn.running_var, + bn.eps, + transpose=False, + ) + new_conv_transpose2d.weight = Parameter(weight) + if bias is not None: + new_conv_transpose2d.bias = Parameter(bias) + if bn is not None and conv_transpose2d.training: + module.bn = deepcopy(bn) + new_conv_transpose2d.training = conv_transpose2d.training + return module diff --git a/imperative/python/test/unit/module/test_qat.py b/imperative/python/test/unit/module/test_qat.py index 1d81a17d..df081982 100644 --- a/imperative/python/test/unit/module/test_qat.py +++ b/imperative/python/test/unit/module/test_qat.py @@ -5,7 +5,9 @@ import numpy as np import pytest import megengine.utils.comp_graph_tools as cgtools -from megengine import jit, tensor +from megengine import jit +from megengine import module as M +from megengine import tensor from megengine.device import get_device_count from megengine.functional import expand_dims from megengine.module import ( @@ -14,6 +16,8 @@ from megengine.module import ( ConvBn2d, ConvRelu2d, ConvTranspose2d, + ConvTransposeBn2d, + ConvTransposeRelu2d, DequantStub, Module, QuantStub, @@ -34,6 +38,49 @@ def test_qat_convbn2d(): module = ConvBn2d( in_channels, out_channels, kernel_size, groups=groups, bias=bias ) + M.init.normal_(module.bn.weight) + M.init.normal_(module.bn.bias) + module.train() + qat_module = quantize_qat(module, inplace=False) + disable_fake_quant(qat_module) + inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) + normal_outputs = module(inputs) + qat_outputs = qat_module(inputs) + np.testing.assert_allclose( + normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6 + ) + np.testing.assert_allclose( + module.bn.running_mean.numpy(), + qat_module.bn.running_mean.numpy(), + atol=5e-8, + ) + np.testing.assert_allclose( + module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7, + ) + module.eval() + normal_outputs = module(inputs) + qat_module.eval() + qat_outputs = qat_module(inputs) + np.testing.assert_allclose( + normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6 + ) + + +def test_qat_convtransposebn2d(): + in_channels = 32 + out_channels = 64 + kernel_size = 3 + for groups, bias in product([1, 4], [True, False]): + module = ConvTransposeBn2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + output_padding=0, + groups=groups, + bias=bias, + ) + M.init.normal_(module.bn.weight) + M.init.normal_(module.bn.bias) module.train() qat_module = quantize_qat(module, inplace=False) disable_fake_quant(qat_module) @@ -235,10 +282,14 @@ def test_qat_conv_transpose2d(): self.conv = ConvTranspose2d( in_channels, out_channels, kernel_size, bias=bias ) + self.conv_transpose2d_relu = ConvTransposeRelu2d( + out_channels, in_channels, kernel_size, bias=bias + ) def forward(self, inp): out = self.quant(inp) out = self.conv(out) + out = self.conv_transpose2d_relu(out) out = self.dequant(out) return out @@ -250,10 +301,14 @@ def test_qat_conv_transpose2d(): disable_fake_quant(qat_net) normal_outputs = net(inputs) qat_outputs = qat_net(inputs) - np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy()) + np.testing.assert_allclose( + normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6 + ) net.eval() normal_outputs = net(inputs) qat_net.eval() qat_outputs = qat_net(inputs) - np.testing.assert_allclose(normal_outputs.numpy(), qat_outputs.numpy()) + np.testing.assert_allclose( + normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-6 + )