
fix(mge/module): fix some deconv fuse bn problem

GitOrigin-RevId: e88a633280
Branch: master
Author: Megvii Engine Team, 2 years ago
Parent commit: 2694ff81c1
4 changed files with 199 additions and 111 deletions
  1. imperative/python/megengine/module/qat/conv_transpose_bn.py (+3, -4)
  2. imperative/python/megengine/utils/bn_fusion.py (+39, -60)
  3. imperative/python/test/unit/module/test_qat.py (+65, -47)
  4. imperative/python/test/unit/quantization/test_quantize.py (+92, -0)

imperative/python/megengine/module/qat/conv_transpose_bn.py (+3, -4)

@@ -24,8 +24,7 @@ class _ConvTransposeBnActivation2d(Float._ConvTransposeBnActivation2d, QATModule
         # get fold bn conv_transpose2d param
         gamma = self.bn.weight
         if gamma is None:
-            gamma = ones((self.bn.num_features), dtype="float32")
-            gamma = gamma.reshape(1, -1, 1, 1)
+            gamma = ones((1, self.bn.num_features, 1, 1), dtype="float32")
         beta = self.bn.bias
         if beta is None:
             beta = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
@@ -44,10 +43,10 @@ class _ConvTransposeBnActivation2d(Float._ConvTransposeBnActivation2d, QATModule
         bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
         scale_factor = gamma * bn_istd
         if self.conv_transpose2d.groups == 1:
-            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(-1, 1, 1, 1)
+            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(1, -1, 1, 1)
         else:
            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(
-                self.conv_transpose2d.groups, -1, 1, 1, 1
+                self.conv_transpose2d.groups, 1, -1, 1, 1
            )

        w_fold = self.apply_quant_weight(w_fold)
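
Reviewer note on the hunk above: the new reshape targets follow the deconv weight layout implied by this diff, (in_channels, out_channels, kh, kw) for a dense ConvTranspose2d and (groups, in_channels // groups, out_channels // groups, kh, kw) for a grouped one, so the per-output-channel BN scale must broadcast along axis 1 (dense) or axis 2 (grouped) rather than axis 0 as for a plain Conv2d. A minimal NumPy sketch of the intended broadcasting, shapes only, with illustrative array names:

import numpy as np

in_channels, out_channels, kh, kw, groups = 32, 64, 3, 3, 4
scale = np.random.rand(out_channels).astype("float32")  # gamma * istd, one entry per output channel

# dense deconv: output channels live on axis 1
w_dense = np.random.rand(in_channels, out_channels, kh, kw).astype("float32")
w_fold = w_dense * scale.reshape(1, -1, 1, 1)            # broadcasts over axis 1

# grouped deconv: output channels live on axis 2 of the 5-d kernel
w_group = np.random.rand(
    groups, in_channels // groups, out_channels // groups, kh, kw
).astype("float32")
w_fold_g = w_group * scale.reshape(groups, 1, -1, 1, 1)  # broadcasts over axis 2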


imperative/python/megengine/utils/bn_fusion.py (+39, -60)

@@ -32,15 +32,21 @@ _MAP_TO_FUSED_MODULE = {
 def fold_weight_bias(
     weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False
 ):
-    shape = (1, -1, 1, 1)
+    shape = (-1, 1, 1, 1)
     if transpose:
-        shape = (-1, 1, 1, 1)
+        shape = (1, -1, 1, 1)

     kernel_shape = weight.shape
     if len(kernel_shape) == 5:
-        groups, num_features = kernel_shape[0], kernel_shape[1]
+        if transpose:
+            groups, num_features = kernel_shape[0], kernel_shape[2]
+        else:
+            groups, num_features = kernel_shape[0], kernel_shape[1]
     else:
-        groups, num_features = 1, kernel_shape[0]
+        if transpose:
+            groups, num_features = 1, kernel_shape[1]
+        else:
+            groups, num_features = 1, kernel_shape[0]

     out_channels = groups * num_features
     if gamma is None:
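
For reference, the folding that these shapes feed into is the standard BatchNorm algebra: with istd = 1 / sqrt(running_var + eps), the fused weight is weight * (gamma * istd) broadcast over the output-channel axis selected above, and the fused bias is beta + (bias - running_mean) * gamma * istd. A hedged sketch of that algebra for the dense (groups == 1) case; the helper name and NumPy types are illustrative, not the module's API:

import numpy as np

def fold_dense(weight, bias, gamma, beta, mean, var, eps=1e-5, transpose=False):
    # output channels sit on axis 0 for Conv2d and on axis 1 for ConvTranspose2d
    shape = (1, -1, 1, 1) if transpose else (-1, 1, 1, 1)
    istd = 1.0 / np.sqrt(var + eps)
    w_fold = weight * (gamma * istd).reshape(shape)
    b_fold = beta + (bias - mean) * gamma * istd
    return w_fold, b_fold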
@@ -93,12 +99,37 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
         compute_mode=conv.compute_mode,
         name=conv.name,
     )
-    new_conv = module if bn is None or not conv.training else module.conv
+    if isinstance(conv, ConvTranspose2d):
+        module.output_padding = conv.output_padding
+        new_conv = (
+            module if bn is None or not conv.training else module.conv_transpose2d
+        )
+    else:
+        new_conv = module if bn is None or not conv.training else module.conv
+
     weight, bias = conv.weight, conv.bias
     if not conv.training and bn is not None:
-        weight, bias = fold_weight_bias(
-            weight, bias, bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps,
-        )
+        if isinstance(conv, ConvTranspose2d):
+            weight, bias = fold_weight_bias(
+                weight,
+                bias,
+                bn.weight,
+                bn.bias,
+                bn.running_mean,
+                bn.running_var,
+                bn.eps,
+                transpose=True,
+            )
+        else:
+            weight, bias = fold_weight_bias(
+                weight,
+                bias,
+                bn.weight,
+                bn.bias,
+                bn.running_mean,
+                bn.running_var,
+                bn.eps,
+            )
     new_conv.weight = Parameter(weight)
     if bias is not None:
         new_conv.bias = Parameter(bias)
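
With this change the single entry point fuse_conv_bn_relu_module covers both Conv2d and ConvTranspose2d, which is what the new tests in test_quantize.py below exercise; a trimmed usage sketch mirroring those tests (variable names illustrative):

import numpy as np
from megengine import Tensor
from megengine.module import BatchNorm2d, ConvTranspose2d, ReLU
from megengine.utils.bn_fusion import fuse_conv_bn_relu_module

deconv, bn, relu = ConvTranspose2d(32, 64, 3), BatchNorm2d(64), ReLU()
fused = fuse_conv_bn_relu_module(deconv, bn, relu)  # fuse while still in training mode
bn.eval()
fused.eval()

x = Tensor(np.random.randn(4, 32, 32, 32).astype(np.float32))
np.testing.assert_allclose(relu(bn(deconv(x))).numpy(), fused(x).numpy(), atol=1e-4)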
@@ -106,55 +137,3 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
         module.bn = deepcopy(bn)
     new_conv.training = conv.training
     return module
-
-
-def fuse_conv_transpose2d_bn_relu_module(
-    conv_transpose2d: ConvTranspose2d, bn: BatchNorm2d, relu: ReLU
-):
-    module_key = tuple([type(m) for m in [conv_transpose2d, bn, relu] if m])
-    if bn:
-        assert (
-            conv_transpose2d.training == bn.training
-        ), "ConvTranspose2d and BN both must be in the same mode (train or eval)."
-        assert (
-            bn.num_features == conv_transpose2d.out_channels
-        ), "Output channel of ConvTranspose2d must match num_features of BatchNorm2d"
-        module_key = module_key + (conv_transpose2d.training,)
-    module = _MAP_TO_FUSED_MODULE[module_key](
-        in_channels=conv_transpose2d.in_channels,
-        out_channels=conv_transpose2d.out_channels,
-        kernel_size=conv_transpose2d.kernel_size,
-        stride=conv_transpose2d.stride,
-        padding=conv_transpose2d.padding,
-        output_padding=conv_transpose2d.output_padding,
-        dilation=conv_transpose2d.dilation,
-        groups=conv_transpose2d.groups,
-        bias=conv_transpose2d.bias is not None,
-        conv_mode=conv_transpose2d.conv_mode,
-        compute_mode=conv_transpose2d.compute_mode,
-        name=conv_transpose2d.name,
-    )
-    new_conv_transpose2d = (
-        module
-        if bn is None or not conv_transpose2d.training
-        else module.conv_transpose2d
-    )
-    weight, bias = conv_transpose2d.weight, conv_transpose2d.bias
-    if not conv_transpose2d.training and bn is not None:
-        weight, bias = fold_weight_bias(
-            weight,
-            bias,
-            bn.weight,
-            bn.bias,
-            bn.running_mean,
-            bn.running_var,
-            bn.eps,
-            transpose=False,
-        )
-    new_conv_transpose2d.weight = Parameter(weight)
-    if bias is not None:
-        new_conv_transpose2d.bias = Parameter(bias)
-    if bn is not None and conv_transpose2d.training:
-        module.bn = deepcopy(bn)
-    new_conv_transpose2d.training = conv_transpose2d.training
-    return module

imperative/python/test/unit/module/test_qat.py (+65, -47)

@@ -34,35 +34,49 @@ def test_qat_convbn2d():
     in_channels = 32
     out_channels = 64
     kernel_size = 3

+    class TestNet(Module):
+        def __init__(self, groups, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.conv_bn = ConvBn2d(
+                in_channels, out_channels, kernel_size, groups=groups, bias=bias,
+            )
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.conv_bn(out)
+            out = self.dequant(out)
+            return out
+
+    inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
     for groups, bias in product([1, 4], [True, False]):
-        module = ConvBn2d(
-            in_channels, out_channels, kernel_size, groups=groups, bias=bias
-        )
-        M.init.normal_(module.bn.weight)
-        M.init.normal_(module.bn.bias)
-        module.train()
-        qat_module = quantize_qat(module, inplace=False)
-        disable_fake_quant(qat_module)
-        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
-        normal_outputs = module(inputs)
-        qat_outputs = qat_module(inputs)
+        net = TestNet(groups, bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(
-            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-4,
         )
         np.testing.assert_allclose(
-            module.bn.running_mean.numpy(),
-            qat_module.bn.running_mean.numpy(),
+            net.conv_bn.bn.running_mean.numpy(),
+            qat_net.conv_bn.bn.running_mean.numpy(),
             atol=5e-8,
         )
         np.testing.assert_allclose(
-            module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7,
+            net.conv_bn.bn.running_var.numpy(),
+            qat_net.conv_bn.bn.running_var.numpy(),
+            atol=5e-7,
         )
-        module.eval()
-        normal_outputs = module(inputs)
-        qat_module.eval()
-        qat_outputs = qat_module(inputs)
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(
-            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-4,
         )




@@ -70,40 +84,44 @@ def test_qat_convtransposebn2d():
     in_channels = 32
     out_channels = 64
     kernel_size = 3

+    class TestNet(Module):
+        def __init__(self, groups, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.conv_transpose_bn = ConvTransposeBn2d(
+                in_channels, out_channels, kernel_size, groups=groups, bias=bias,
+            )
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.conv_transpose_bn(out)
+            out = self.dequant(out)
+            return out
+
     for groups, bias in product([1, 4], [True, False]):
-        module = ConvTransposeBn2d(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            output_padding=0,
-            groups=groups,
-            bias=bias,
-        )
-        M.init.normal_(module.bn.weight)
-        M.init.normal_(module.bn.bias)
-        module.train()
-        qat_module = quantize_qat(module, inplace=False)
-        disable_fake_quant(qat_module)
+        net = TestNet(groups, bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
         inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
-        normal_outputs = module(inputs)
-        qat_outputs = qat_module(inputs)
-        np.testing.assert_allclose(
-            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
-        )
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(
-            module.bn.running_mean.numpy(),
-            qat_module.bn.running_mean.numpy(),
-            atol=5e-8,
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-5,
         )
         np.testing.assert_allclose(
-            module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7,
+            net.conv_transpose_bn.bn.running_var.numpy(),
+            qat_net.conv_transpose_bn.bn.running_var.numpy(),
+            atol=5e-7,
         )
-        module.eval()
-        normal_outputs = module(inputs)
-        qat_module.eval()
-        qat_outputs = qat_module(inputs)
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(
-            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-5,
         )






imperative/python/test/unit/quantization/test_quantize.py (+92, -0)

@@ -3,6 +3,15 @@ import pytest

 from megengine import Parameter, Tensor
 from megengine import module as Float
+from megengine.functional import ones, zeros
+from megengine.module import (
+    BatchNorm2d,
+    Conv2d,
+    ConvBn2d,
+    ConvTranspose2d,
+    ConvTransposeBn2d,
+    ReLU,
+)
 from megengine.module import qat as QAT
 from megengine.module import quantized as Q
 from megengine.quantization import (
@@ -24,6 +33,7 @@ from megengine.quantization.quantize import (
     quantize_qat,
     reset_qconfig,
 )
+from megengine.utils.bn_fusion import fuse_conv_bn_relu_module


 class FloatNet(Float.Module):
@@ -291,3 +301,85 @@ def test_convert_with_custom_mapping():
     net = Net()
     qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample})
     assert isinstance(qat_net.example, QATExample)
+
+
+def test_ConvBn2d_fold_weight_bias():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+
+    conv = Conv2d(in_channels, out_channels, kernel_size)
+    bn = BatchNorm2d(out_channels)
+    relu = ReLU()
+
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
+    bn.eval()
+    fused_conv.eval()
+    inputs = Tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+    expected_result = relu(bn(conv(inputs)))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    conv.eval()
+    bn.eval()
+    relu.eval()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
+    fused_conv.eval()
+    expected_result = relu(conv(inputs))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    conv.train()
+    bn.train()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, None)
+    fused_conv.train()
+    expected_result = bn(conv(inputs))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+
+def test_ConvTransposeBn2d_fold_weight_bias():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+
+    conv = ConvTranspose2d(in_channels, out_channels, kernel_size)
+    bn = BatchNorm2d(out_channels)
+    relu = ReLU()
+
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
+    bn.eval()
+    fused_conv.eval()
+    inputs = Tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+    expected_result = relu(bn(conv(inputs)))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    conv.eval()
+    bn.eval()
+    relu.eval()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
+    fused_conv.eval()
+    expected_result = relu(conv(inputs))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    conv.train()
+    bn.train()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, None)
+    fused_conv.train()
+    expected_result = bn(conv(inputs))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
