@@ -24,8 +24,7 @@ class _ConvTransposeBnActivation2d(Float._ConvTransposeBnActivation2d, QATModule
        # get fold bn conv_transpose2d param
        gamma = self.bn.weight
        if gamma is None:
            gamma = ones((self.bn.num_features), dtype="float32")
        gamma = gamma.reshape(1, -1, 1, 1)
            gamma = ones((1, self.bn.num_features, 1, 1), dtype="float32")
        beta = self.bn.bias
        if beta is None:
            beta = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
@@ -44,10 +43,10 @@ class _ConvTransposeBnActivation2d(Float._ConvTransposeBnActivation2d, QATModule
        bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
        scale_factor = gamma * bn_istd
        if self.conv_transpose2d.groups == 1:
            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(-1, 1, 1, 1)
            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(1, -1, 1, 1)
        else:
            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(
                self.conv_transpose2d.groups, -1, 1, 1, 1
                self.conv_transpose2d.groups, 1, -1, 1, 1
            )
        w_fold = self.apply_quant_weight(w_fold)
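Side note on the reshape change above: Conv2d stores its weight with output channels on axis 0 (axis 1 when grouped), while ConvTranspose2d stores them on axis 1 (axis 2 when grouped), so the per-channel scale gamma * bn_istd has to broadcast along a different axis before folding. A minimal numpy sketch of just the broadcasting, assuming the weight layouts implied by this diff (all shapes and names below are illustrative):

```python
import numpy as np

oc, ic, kh, kw, groups = 8, 4, 3, 3, 2
scale = np.random.rand(oc).astype("float32")  # gamma / sqrt(running_var + eps), one value per output channel

# Conv2d weight (oc, ic, kh, kw): scale broadcasts along axis 0
w_conv = np.random.rand(oc, ic, kh, kw).astype("float32")
w_conv_fold = w_conv * scale.reshape(-1, 1, 1, 1)

# ConvTranspose2d weight (ic, oc, kh, kw): scale broadcasts along axis 1
w_deconv = np.random.rand(ic, oc, kh, kw).astype("float32")
w_deconv_fold = w_deconv * scale.reshape(1, -1, 1, 1)

# grouped ConvTranspose2d weight (groups, ic // groups, oc // groups, kh, kw):
# the per-group slice of the scale broadcasts along axis 2
w_deconv_g = np.random.rand(groups, ic // groups, oc // groups, kh, kw).astype("float32")
w_deconv_g_fold = w_deconv_g * scale.reshape(groups, 1, -1, 1, 1)
```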
@@ -32,15 +32,21 @@ _MAP_TO_FUSED_MODULE = {
def fold_weight_bias(
    weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False
):
    shape = (1, -1, 1, 1)
    shape = (-1, 1, 1, 1)
    if transpose:
        shape = (-1, 1, 1, 1)
        shape = (1, -1, 1, 1)
    kernel_shape = weight.shape
    if len(kernel_shape) == 5:
        groups, num_features = kernel_shape[0], kernel_shape[1]
        if transpose:
            groups, num_features = kernel_shape[0], kernel_shape[2]
        else:
            groups, num_features = kernel_shape[0], kernel_shape[1]
    else:
        groups, num_features = 1, kernel_shape[0]
        if transpose:
            groups, num_features = 1, kernel_shape[1]
        else:
            groups, num_features = 1, kernel_shape[0]
    out_channels = groups * num_features
    if gamma is None:
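For reference, the identity that fold_weight_bias relies on is standard BN folding: per output channel, w_fold = w * gamma / sqrt(var + eps) and b_fold = beta + (b - mean) * gamma / sqrt(var + eps), so bn(conv(x, w) + b) equals conv(x, w_fold) + b_fold by linearity of the convolution. A small self-contained numpy check of that identity (names are illustrative, not taken from the module):

```python
import numpy as np

oc = 4
b = np.random.rand(oc)                        # conv bias, one value per output channel
gamma, beta = np.random.rand(oc), np.random.rand(oc)
mean, var, eps = np.random.rand(oc), np.random.rand(oc) + 1.0, 1e-5

istd = 1.0 / np.sqrt(var + eps)
b_fold = beta + (b - mean) * gamma * istd

# treat y as the per-channel output of conv(x, w) before the bias is added;
# conv(x, w_fold) is then simply y * gamma * istd because the conv is linear in w
y = np.random.rand(oc)
bn_out = gamma * ((y + b) - mean) * istd + beta   # BN applied to the unfused conv output
fold_out = y * gamma * istd + b_fold              # fused conv output with folded weight and bias
np.testing.assert_allclose(bn_out, fold_out, atol=1e-12)
```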
@@ -93,12 +99,37 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
        compute_mode=conv.compute_mode,
        name=conv.name,
    )
    new_conv = module if bn is None or not conv.training else module.conv
    if isinstance(conv, ConvTranspose2d):
        module.output_padding = conv.output_padding
        new_conv = (
            module if bn is None or not conv.training else module.conv_transpose2d
        )
    else:
        new_conv = module if bn is None or not conv.training else module.conv
    weight, bias = conv.weight, conv.bias
    if not conv.training and bn is not None:
        weight, bias = fold_weight_bias(
            weight, bias, bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps,
        )
        if isinstance(conv, ConvTranspose2d):
            weight, bias = fold_weight_bias(
                weight,
                bias,
                bn.weight,
                bn.bias,
                bn.running_mean,
                bn.running_var,
                bn.eps,
                transpose=True,
            )
        else:
            weight, bias = fold_weight_bias(
                weight,
                bias,
                bn.weight,
                bn.bias,
                bn.running_mean,
                bn.running_var,
                bn.eps,
            )
    new_conv.weight = Parameter(weight)
    if bias is not None:
        new_conv.bias = Parameter(bias)
@@ -106,55 +137,3 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
        module.bn = deepcopy(bn)
    new_conv.training = conv.training
    return module


def fuse_conv_transpose2d_bn_relu_module(
    conv_transpose2d: ConvTranspose2d, bn: BatchNorm2d, relu: ReLU
):
    module_key = tuple([type(m) for m in [conv_transpose2d, bn, relu] if m])
    if bn:
        assert (
            conv_transpose2d.training == bn.training
        ), "ConvTranspose2d and BN both must be in the same mode (train or eval)."
        assert (
            bn.num_features == conv_transpose2d.out_channels
        ), "Output channel of ConvTranspose2d must match num_features of BatchNorm2d"
        module_key = module_key + (conv_transpose2d.training,)
    module = _MAP_TO_FUSED_MODULE[module_key](
        in_channels=conv_transpose2d.in_channels,
        out_channels=conv_transpose2d.out_channels,
        kernel_size=conv_transpose2d.kernel_size,
        stride=conv_transpose2d.stride,
        padding=conv_transpose2d.padding,
        output_padding=conv_transpose2d.output_padding,
        dilation=conv_transpose2d.dilation,
        groups=conv_transpose2d.groups,
        bias=conv_transpose2d.bias is not None,
        conv_mode=conv_transpose2d.conv_mode,
        compute_mode=conv_transpose2d.compute_mode,
        name=conv_transpose2d.name,
    )
    new_conv_transpose2d = (
        module
        if bn is None or not conv_transpose2d.training
        else module.conv_transpose2d
    )
    weight, bias = conv_transpose2d.weight, conv_transpose2d.bias
    if not conv_transpose2d.training and bn is not None:
        weight, bias = fold_weight_bias(
            weight,
            bias,
            bn.weight,
            bn.bias,
            bn.running_mean,
            bn.running_var,
            bn.eps,
            transpose=False,
        )
    new_conv_transpose2d.weight = Parameter(weight)
    if bias is not None:
        new_conv_transpose2d.bias = Parameter(bias)
    if bn is not None and conv_transpose2d.training:
        module.bn = deepcopy(bn)
    new_conv_transpose2d.training = conv_transpose2d.training
    return module
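The standalone fuse_conv_transpose2d_bn_relu_module above is removed because fuse_conv_bn_relu_module now covers both cases: it checks isinstance(conv, ConvTranspose2d), copies output_padding, targets module.conv_transpose2d, and calls fold_weight_bias(..., transpose=True), whereas the old helper passed transpose=False and folded along the wrong weight axis. A hedged usage sketch of the unified entry point, mirroring the new tests added below (concrete shapes are illustrative):

```python
import numpy as np
from megengine import Tensor
from megengine.module import BatchNorm2d, ConvTranspose2d, ReLU
from megengine.utils.bn_fusion import fuse_conv_bn_relu_module

deconv = ConvTranspose2d(8, 16, 3)
bn = BatchNorm2d(16)
relu = ReLU()

# eval-mode fusion: BN statistics are folded straight into the deconv weights
deconv.eval()
bn.eval()
relu.eval()
fused = fuse_conv_bn_relu_module(deconv, bn, relu)
fused.eval()

x = Tensor(np.random.randn(2, 8, 10, 10).astype(np.float32))
np.testing.assert_allclose(
    relu(bn(deconv(x))).numpy(), fused(x).numpy(), atol=1e-4
)
```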
@@ -34,35 +34,49 @@ def test_qat_convbn2d():
    in_channels = 32
    out_channels = 64
    kernel_size = 3

    class TestNet(Module):
        def __init__(self, groups, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.conv_bn = ConvBn2d(
                in_channels, out_channels, kernel_size, groups=groups, bias=bias,
            )

        def forward(self, inp):
            out = self.quant(inp)
            out = self.conv_bn(out)
            out = self.dequant(out)
            return out

    inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
    for groups, bias in product([1, 4], [True, False]):
        module = ConvBn2d(
            in_channels, out_channels, kernel_size, groups=groups, bias=bias
        )
        M.init.normal_(module.bn.weight)
        M.init.normal_(module.bn.bias)
        module.train()
        qat_module = quantize_qat(module, inplace=False)
        disable_fake_quant(qat_module)
        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
        normal_outputs = module(inputs)
        qat_outputs = qat_module(inputs)
        net = TestNet(groups, bias)
        net.train()
        qat_net = quantize_qat(net, inplace=False)
        disable_fake_quant(qat_net)
        normal_outputs = net(inputs)
        qat_outputs = qat_net(inputs)
        np.testing.assert_allclose(
            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-4,
        )
        np.testing.assert_allclose(
            module.bn.running_mean.numpy(),
            qat_module.bn.running_mean.numpy(),
            net.conv_bn.bn.running_mean.numpy(),
            qat_net.conv_bn.bn.running_mean.numpy(),
            atol=5e-8,
        )
        np.testing.assert_allclose(
            module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7,
            net.conv_bn.bn.running_var.numpy(),
            qat_net.conv_bn.bn.running_var.numpy(),
            atol=5e-7,
        )
        module.eval()
        normal_outputs = module(inputs)
        qat_module.eval()
        qat_outputs = qat_module(inputs)
        net.eval()
        normal_outputs = net(inputs)
        qat_net.eval()
        qat_outputs = qat_net(inputs)
        np.testing.assert_allclose(
            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-4,
        )
@@ -70,40 +84,44 @@ def test_qat_convtransposebn2d():
    in_channels = 32
    out_channels = 64
    kernel_size = 3

    class TestNet(Module):
        def __init__(self, groups, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.conv_transpose_bn = ConvTransposeBn2d(
                in_channels, out_channels, kernel_size, groups=groups, bias=bias,
            )

        def forward(self, inp):
            out = self.quant(inp)
            out = self.conv_transpose_bn(out)
            out = self.dequant(out)
            return out

    for groups, bias in product([1, 4], [True, False]):
        module = ConvTransposeBn2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            output_padding=0,
            groups=groups,
            bias=bias,
        )
        M.init.normal_(module.bn.weight)
        M.init.normal_(module.bn.bias)
        module.train()
        qat_module = quantize_qat(module, inplace=False)
        disable_fake_quant(qat_module)
        net = TestNet(groups, bias)
        net.train()
        qat_net = quantize_qat(net, inplace=False)
        disable_fake_quant(qat_net)
        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
        normal_outputs = module(inputs)
        qat_outputs = qat_module(inputs)
        np.testing.assert_allclose(
            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
        )
        normal_outputs = net(inputs)
        qat_outputs = qat_net(inputs)
        np.testing.assert_allclose(
            module.bn.running_mean.numpy(),
            qat_module.bn.running_mean.numpy(),
            atol=5e-8,
            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-5,
        )
        np.testing.assert_allclose(
            module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7,
            net.conv_transpose_bn.bn.running_var.numpy(),
            qat_net.conv_transpose_bn.bn.running_var.numpy(),
            atol=5e-7,
        )
        module.eval()
        normal_outputs = module(inputs)
        qat_module.eval()
        qat_outputs = qat_module(inputs)
        net.eval()
        normal_outputs = net(inputs)
        qat_net.eval()
        qat_outputs = qat_net(inputs)
        np.testing.assert_allclose(
            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-5,
        )
@@ -3,6 +3,15 @@ import pytest
from megengine import Parameter, Tensor
from megengine import module as Float
from megengine.functional import ones, zeros
from megengine.module import (
    BatchNorm2d,
    Conv2d,
    ConvBn2d,
    ConvTranspose2d,
    ConvTransposeBn2d,
    ReLU,
)
from megengine.module import qat as QAT
from megengine.module import quantized as Q
from megengine.quantization import (
@@ -24,6 +33,7 @@ from megengine.quantization.quantize import (
    quantize_qat,
    reset_qconfig,
)
from megengine.utils.bn_fusion import fuse_conv_bn_relu_module


class FloatNet(Float.Module):
@@ -291,3 +301,85 @@ def test_convert_with_custom_mapping():
    net = Net()
    qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample})
    assert isinstance(qat_net.example, QATExample)


def test_ConvBn2d_fold_weight_bias():
    in_channels = 32
    out_channels = 64
    kernel_size = 3

    conv = Conv2d(in_channels, out_channels, kernel_size)
    bn = BatchNorm2d(out_channels)
    relu = ReLU()

    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
    bn.eval()
    fused_conv.eval()
    inputs = Tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
    expected_result = relu(bn(conv(inputs)))
    actual_result = fused_conv(inputs)
    np.testing.assert_allclose(
        expected_result.numpy(), actual_result.numpy(), atol=1e-4
    )

    conv.eval()
    bn.eval()
    relu.eval()
    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
    fused_conv.eval()
    expected_result = relu(conv(inputs))
    actual_result = fused_conv(inputs)
    np.testing.assert_allclose(
        expected_result.numpy(), actual_result.numpy(), atol=1e-4
    )

    conv.train()
    bn.train()
    fused_conv = fuse_conv_bn_relu_module(conv, bn, None)
    fused_conv.train()
    expected_result = bn(conv(inputs))
    actual_result = fused_conv(inputs)
    np.testing.assert_allclose(
        expected_result.numpy(), actual_result.numpy(), atol=1e-4
    )


def test_ConvTransposeBn2d_fold_weight_bias():
    in_channels = 32
    out_channels = 64
    kernel_size = 3

    conv = ConvTranspose2d(in_channels, out_channels, kernel_size)
    bn = BatchNorm2d(out_channels)
    relu = ReLU()

    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
    bn.eval()
    fused_conv.eval()
    inputs = Tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
    expected_result = relu(bn(conv(inputs)))
    actual_result = fused_conv(inputs)
    np.testing.assert_allclose(
        expected_result.numpy(), actual_result.numpy(), atol=1e-4
    )

    conv.eval()
    bn.eval()
    relu.eval()
    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
    fused_conv.eval()
    expected_result = relu(conv(inputs))
    actual_result = fused_conv(inputs)
    np.testing.assert_allclose(
        expected_result.numpy(), actual_result.numpy(), atol=1e-4
    )

    conv.train()
    bn.train()
    fused_conv = fuse_conv_bn_relu_module(conv, bn, None)
    fused_conv.train()
    expected_result = bn(conv(inputs))
    actual_result = fused_conv(inputs)
    np.testing.assert_allclose(
        expected_result.numpy(), actual_result.numpy(), atol=1e-4
    )