From 2694ff81c1fa7ba857ea9adb1d848cd1f34eb02b Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 25 Oct 2022 19:11:11 +0800 Subject: [PATCH] fix(mge/module): fix some deconv fuse bn problem GitOrigin-RevId: e88a63328065234c8060317d7c7e558c95bfbd4b --- .../megengine/module/qat/conv_transpose_bn.py | 7 +- imperative/python/megengine/utils/bn_fusion.py | 99 +++++++----------- imperative/python/test/unit/module/test_qat.py | 112 ++++++++++++--------- .../python/test/unit/quantization/test_quantize.py | 92 +++++++++++++++++ 4 files changed, 199 insertions(+), 111 deletions(-) diff --git a/imperative/python/megengine/module/qat/conv_transpose_bn.py b/imperative/python/megengine/module/qat/conv_transpose_bn.py index 3ff6b704..b9ca4004 100644 --- a/imperative/python/megengine/module/qat/conv_transpose_bn.py +++ b/imperative/python/megengine/module/qat/conv_transpose_bn.py @@ -24,8 +24,7 @@ class _ConvTransposeBnActivation2d(Float._ConvTransposeBnActivation2d, QATModule # get fold bn conv_transpose2d param gamma = self.bn.weight if gamma is None: - gamma = ones((self.bn.num_features), dtype="float32") - gamma = gamma.reshape(1, -1, 1, 1) + gamma = ones((1, self.bn.num_features, 1, 1), dtype="float32") beta = self.bn.bias if beta is None: beta = zeros((1, self.bn.num_features, 1, 1), dtype="float32") @@ -44,10 +43,10 @@ class _ConvTransposeBnActivation2d(Float._ConvTransposeBnActivation2d, QATModule bn_istd = 1.0 / sqrt(bn_var + self.bn.eps) scale_factor = gamma * bn_istd if self.conv_transpose2d.groups == 1: - w_fold = self.conv_transpose2d.weight * scale_factor.reshape(-1, 1, 1, 1) + w_fold = self.conv_transpose2d.weight * scale_factor.reshape(1, -1, 1, 1) else: w_fold = self.conv_transpose2d.weight * scale_factor.reshape( - self.conv_transpose2d.groups, -1, 1, 1, 1 + self.conv_transpose2d.groups, 1, -1, 1, 1 ) w_fold = self.apply_quant_weight(w_fold) diff --git a/imperative/python/megengine/utils/bn_fusion.py b/imperative/python/megengine/utils/bn_fusion.py index b8b99457..01e435f4 100644 --- a/imperative/python/megengine/utils/bn_fusion.py +++ b/imperative/python/megengine/utils/bn_fusion.py @@ -32,15 +32,21 @@ _MAP_TO_FUSED_MODULE = { def fold_weight_bias( weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False ): - shape = (1, -1, 1, 1) + shape = (-1, 1, 1, 1) if transpose: - shape = (-1, 1, 1, 1) + shape = (1, -1, 1, 1) kernel_shape = weight.shape if len(kernel_shape) == 5: - groups, num_features = kernel_shape[0], kernel_shape[1] + if transpose: + groups, num_features = kernel_shape[0], kernel_shape[2] + else: + groups, num_features = kernel_shape[0], kernel_shape[1] else: - groups, num_features = 1, kernel_shape[0] + if transpose: + groups, num_features = 1, kernel_shape[1] + else: + groups, num_features = 1, kernel_shape[0] out_channels = groups * num_features if gamma is None: @@ -93,12 +99,37 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU): compute_mode=conv.compute_mode, name=conv.name, ) - new_conv = module if bn is None or not conv.training else module.conv + if isinstance(conv, ConvTranspose2d): + module.output_padding = conv.output_padding + new_conv = ( + module if bn is None or not conv.training else module.conv_transpose2d + ) + else: + new_conv = module if bn is None or not conv.training else module.conv + weight, bias = conv.weight, conv.bias if not conv.training and bn is not None: - weight, bias = fold_weight_bias( - weight, bias, bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps, - ) + if isinstance(conv, 
ConvTranspose2d): + weight, bias = fold_weight_bias( + weight, + bias, + bn.weight, + bn.bias, + bn.running_mean, + bn.running_var, + bn.eps, + transpose=True, + ) + else: + weight, bias = fold_weight_bias( + weight, + bias, + bn.weight, + bn.bias, + bn.running_mean, + bn.running_var, + bn.eps, + ) new_conv.weight = Parameter(weight) if bias is not None: new_conv.bias = Parameter(bias) @@ -106,55 +137,3 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU): module.bn = deepcopy(bn) new_conv.training = conv.training return module - - -def fuse_conv_transpose2d_bn_relu_module( - conv_transpose2d: ConvTranspose2d, bn: BatchNorm2d, relu: ReLU -): - module_key = tuple([type(m) for m in [conv_transpose2d, bn, relu] if m]) - if bn: - assert ( - conv_transpose2d.training == bn.training - ), "ConvTranspose2d and BN both must be in the same mode (train or eval)." - assert ( - bn.num_features == conv_transpose2d.out_channels - ), "Output channel of ConvTranspose2d must match num_features of BatchNorm2d" - module_key = module_key + (conv_transpose2d.training,) - module = _MAP_TO_FUSED_MODULE[module_key]( - in_channels=conv_transpose2d.in_channels, - out_channels=conv_transpose2d.out_channels, - kernel_size=conv_transpose2d.kernel_size, - stride=conv_transpose2d.stride, - padding=conv_transpose2d.padding, - output_padding=conv_transpose2d.output_padding, - dilation=conv_transpose2d.dilation, - groups=conv_transpose2d.groups, - bias=conv_transpose2d.bias is not None, - conv_mode=conv_transpose2d.conv_mode, - compute_mode=conv_transpose2d.compute_mode, - name=conv_transpose2d.name, - ) - new_conv_transpose2d = ( - module - if bn is None or not conv_transpose2d.training - else module.conv_transpose2d - ) - weight, bias = conv_transpose2d.weight, conv_transpose2d.bias - if not conv_transpose2d.training and bn is not None: - weight, bias = fold_weight_bias( - weight, - bias, - bn.weight, - bn.bias, - bn.running_mean, - bn.running_var, - bn.eps, - transpose=False, - ) - new_conv_transpose2d.weight = Parameter(weight) - if bias is not None: - new_conv_transpose2d.bias = Parameter(bias) - if bn is not None and conv_transpose2d.training: - module.bn = deepcopy(bn) - new_conv_transpose2d.training = conv_transpose2d.training - return module diff --git a/imperative/python/test/unit/module/test_qat.py b/imperative/python/test/unit/module/test_qat.py index df081982..7b99fe0f 100644 --- a/imperative/python/test/unit/module/test_qat.py +++ b/imperative/python/test/unit/module/test_qat.py @@ -34,35 +34,49 @@ def test_qat_convbn2d(): in_channels = 32 out_channels = 64 kernel_size = 3 + + class TestNet(Module): + def __init__(self, groups, bias): + super().__init__() + self.quant = QuantStub() + self.dequant = DequantStub() + self.conv_bn = ConvBn2d( + in_channels, out_channels, kernel_size, groups=groups, bias=bias, + ) + + def forward(self, inp): + out = self.quant(inp) + out = self.conv_bn(out) + out = self.dequant(out) + return out + + inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) for groups, bias in product([1, 4], [True, False]): - module = ConvBn2d( - in_channels, out_channels, kernel_size, groups=groups, bias=bias - ) - M.init.normal_(module.bn.weight) - M.init.normal_(module.bn.bias) - module.train() - qat_module = quantize_qat(module, inplace=False) - disable_fake_quant(qat_module) - inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) - normal_outputs = module(inputs) - qat_outputs = qat_module(inputs) + net = TestNet(groups, bias) + 
net.train() + qat_net = quantize_qat(net, inplace=False) + disable_fake_quant(qat_net) + normal_outputs = net(inputs) + qat_outputs = qat_net(inputs) np.testing.assert_allclose( - normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6 + normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-4, ) np.testing.assert_allclose( - module.bn.running_mean.numpy(), - qat_module.bn.running_mean.numpy(), + net.conv_bn.bn.running_mean.numpy(), + qat_net.conv_bn.bn.running_mean.numpy(), atol=5e-8, ) np.testing.assert_allclose( - module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7, + net.conv_bn.bn.running_var.numpy(), + qat_net.conv_bn.bn.running_var.numpy(), + atol=5e-7, ) - module.eval() - normal_outputs = module(inputs) - qat_module.eval() - qat_outputs = qat_module(inputs) + net.eval() + normal_outputs = net(inputs) + qat_net.eval() + qat_outputs = qat_net(inputs) np.testing.assert_allclose( - normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6 + normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-4, ) @@ -70,40 +84,44 @@ def test_qat_convtransposebn2d(): in_channels = 32 out_channels = 64 kernel_size = 3 + + class TestNet(Module): + def __init__(self, groups, bias): + super().__init__() + self.quant = QuantStub() + self.dequant = DequantStub() + self.conv_transpose_bn = ConvTransposeBn2d( + in_channels, out_channels, kernel_size, groups=groups, bias=bias, + ) + + def forward(self, inp): + out = self.quant(inp) + out = self.conv_transpose_bn(out) + out = self.dequant(out) + return out + for groups, bias in product([1, 4], [True, False]): - module = ConvTransposeBn2d( - in_channels=in_channels, - out_channels=out_channels, - kernel_size=kernel_size, - output_padding=0, - groups=groups, - bias=bias, - ) - M.init.normal_(module.bn.weight) - M.init.normal_(module.bn.bias) - module.train() - qat_module = quantize_qat(module, inplace=False) - disable_fake_quant(qat_module) + net = TestNet(groups, bias) + net.train() + qat_net = quantize_qat(net, inplace=False) + disable_fake_quant(qat_net) inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) - normal_outputs = module(inputs) - qat_outputs = qat_module(inputs) - np.testing.assert_allclose( - normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6 - ) + normal_outputs = net(inputs) + qat_outputs = qat_net(inputs) np.testing.assert_allclose( - module.bn.running_mean.numpy(), - qat_module.bn.running_mean.numpy(), - atol=5e-8, + normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-5, ) np.testing.assert_allclose( - module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7, + net.conv_transpose_bn.bn.running_var.numpy(), + qat_net.conv_transpose_bn.bn.running_var.numpy(), + atol=5e-7, ) - module.eval() - normal_outputs = module(inputs) - qat_module.eval() - qat_outputs = qat_module(inputs) + net.eval() + normal_outputs = net(inputs) + qat_net.eval() + qat_outputs = qat_net(inputs) np.testing.assert_allclose( - normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6 + normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-5, ) diff --git a/imperative/python/test/unit/quantization/test_quantize.py b/imperative/python/test/unit/quantization/test_quantize.py index 3ee9ae4c..82661bc2 100644 --- a/imperative/python/test/unit/quantization/test_quantize.py +++ b/imperative/python/test/unit/quantization/test_quantize.py @@ -3,6 +3,15 @@ import pytest from megengine import Parameter, Tensor from megengine import module as Float +from megengine.functional import ones, zeros +from megengine.module import ( + 
BatchNorm2d, + Conv2d, + ConvBn2d, + ConvTranspose2d, + ConvTransposeBn2d, + ReLU, +) from megengine.module import qat as QAT from megengine.module import quantized as Q from megengine.quantization import ( @@ -24,6 +33,7 @@ from megengine.quantization.quantize import ( quantize_qat, reset_qconfig, ) +from megengine.utils.bn_fusion import fuse_conv_bn_relu_module class FloatNet(Float.Module): @@ -291,3 +301,85 @@ def test_convert_with_custom_mapping(): net = Net() qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample}) assert isinstance(qat_net.example, QATExample) + + +def test_ConvBn2d_fold_weight_bias(): + in_channels = 32 + out_channels = 64 + kernel_size = 3 + + conv = Conv2d(in_channels, out_channels, kernel_size) + bn = BatchNorm2d(out_channels) + relu = ReLU() + + fused_conv = fuse_conv_bn_relu_module(conv, bn, relu) + bn.eval() + fused_conv.eval() + inputs = Tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) + expected_result = relu(bn(conv(inputs))) + actual_result = fused_conv(inputs) + np.testing.assert_allclose( + expected_result.numpy(), actual_result.numpy(), atol=1e-4 + ) + + conv.eval() + bn.eval() + relu.eval() + fused_conv = fuse_conv_bn_relu_module(conv, bn, relu) + fused_conv.eval() + expected_result = relu(conv(inputs)) + actual_result = fused_conv(inputs) + np.testing.assert_allclose( + expected_result.numpy(), actual_result.numpy(), atol=1e-4 + ) + + conv.train() + bn.train() + fused_conv = fuse_conv_bn_relu_module(conv, bn, None) + fused_conv.train() + expected_result = bn(conv(inputs)) + actual_result = fused_conv(inputs) + np.testing.assert_allclose( + expected_result.numpy(), actual_result.numpy(), atol=1e-4 + ) + + +def test_ConvTransposeBn2d_fold_weight_bias(): + in_channels = 32 + out_channels = 64 + kernel_size = 3 + + conv = ConvTranspose2d(in_channels, out_channels, kernel_size) + bn = BatchNorm2d(out_channels) + relu = ReLU() + + fused_conv = fuse_conv_bn_relu_module(conv, bn, relu) + bn.eval() + fused_conv.eval() + inputs = Tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32)) + expected_result = relu(bn(conv(inputs))) + actual_result = fused_conv(inputs) + np.testing.assert_allclose( + expected_result.numpy(), actual_result.numpy(), atol=1e-4 + ) + + conv.eval() + bn.eval() + relu.eval() + fused_conv = fuse_conv_bn_relu_module(conv, bn, relu) + fused_conv.eval() + expected_result = relu(conv(inputs)) + actual_result = fused_conv(inputs) + np.testing.assert_allclose( + expected_result.numpy(), actual_result.numpy(), atol=1e-4 + ) + + conv.train() + bn.train() + fused_conv = fuse_conv_bn_relu_module(conv, bn, None) + fused_conv.train() + expected_result = bn(conv(inputs)) + actual_result = fused_conv(inputs) + np.testing.assert_allclose( + expected_result.numpy(), actual_result.numpy(), atol=1e-4 + )
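
The heart of the fix is which kernel axis the per-output-channel BatchNorm scale is broadcast over: judging by the reshape axes in this patch, Conv2d kernels keep output channels on axis 0, while ConvTranspose2d kernels keep them on axis 1 (axis 2 of the 5-D grouped layout). Below is a minimal numpy sketch, independent of MegEngine, of the folding arithmetic under those layout assumptions; the helper name and the 1x1-kernel sanity check are illustrative only, not part of the patch.

    # Minimal numpy sketch (not MegEngine code) of the BN-folding arithmetic
    # this patch corrects. Layout assumptions, taken from the patch's reshapes:
    #   Conv2d kernels are (out, in, kh, kw)        -> scale broadcasts on axis 0
    #   ConvTranspose2d kernels are (in, out, kh, kw) -> scale broadcasts on axis 1
    import numpy as np

    def fold_weight_bias(weight, bias, gamma, beta, mean, var, eps=1e-5, transpose=False):
        # scale = gamma / sqrt(var + eps), one value per output channel
        scale = gamma / np.sqrt(var + eps)
        # Broadcast the scale over the output-channel axis of the kernel.
        shape = (1, -1, 1, 1) if transpose else (-1, 1, 1, 1)
        w_fold = weight * scale.reshape(shape)
        if bias is None:
            bias = np.zeros_like(mean)
        b_fold = beta + (bias - mean) * scale
        return w_fold, b_fold

    # Sanity check against a direct "convolution output, then BN" computation on a
    # 1x1 kernel, where a (transposed) convolution reduces to a matrix product.
    rng = np.random.default_rng(0)
    cin, cout = 3, 4
    x = rng.standard_normal((8, cin))                 # 8 spatial positions
    w_t = rng.standard_normal((cin, cout, 1, 1))      # ConvTranspose2d layout
    gamma, beta = rng.standard_normal(cout), rng.standard_normal(cout)
    mean, var = rng.standard_normal(cout), rng.random(cout) + 0.5

    y_ref = (x @ w_t[:, :, 0, 0] - mean) / np.sqrt(var + 1e-5) * gamma + beta
    w_f, b_f = fold_weight_bias(w_t, None, gamma, beta, mean, var, transpose=True)
    y_fold = x @ w_f[:, :, 0, 0] + b_f
    np.testing.assert_allclose(y_ref, y_fold, atol=1e-6)

Running the sketch with transpose=False on the same ConvTranspose-layout kernel scales the wrong axis and the check fails, which is exactly the mismatch the previous scale_factor.reshape(-1, 1, 1, 1) produced in _ConvTransposeBnActivation2d.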