@@ -24,8 +24,7 @@ class _ConvTransposeBnActivation2d(Float._ConvTransposeBnActivation2d, QATModule
         # get fold bn conv_transpose2d param
         gamma = self.bn.weight
         if gamma is None:
-            gamma = ones((self.bn.num_features), dtype="float32")
-        gamma = gamma.reshape(1, -1, 1, 1)
+            gamma = ones((1, self.bn.num_features, 1, 1), dtype="float32")
         beta = self.bn.bias
         if beta is None:
             beta = zeros((1, self.bn.num_features, 1, 1), dtype="float32")
@@ -44,10 +43,10 @@ class _ConvTransposeBnActivation2d(Float._ConvTransposeBnActivation2d, QATModule
         bn_istd = 1.0 / sqrt(bn_var + self.bn.eps)
         scale_factor = gamma * bn_istd
         if self.conv_transpose2d.groups == 1:
-            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(-1, 1, 1, 1)
+            w_fold = self.conv_transpose2d.weight * scale_factor.reshape(1, -1, 1, 1)
         else:
             w_fold = self.conv_transpose2d.weight * scale_factor.reshape(
-                self.conv_transpose2d.groups, -1, 1, 1, 1
+                self.conv_transpose2d.groups, 1, -1, 1, 1
             )
         w_fold = self.apply_quant_weight(w_fold)
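Note on the axis change above: a ConvTranspose2d kernel is laid out as (in_channels, out_channels, kh, kw), so the per-output-channel BN scale must broadcast along axis 1 rather than axis 0 as it does for Conv2d. Below is a minimal NumPy sketch of the folding identity being relied on; it emulates a stride-1, 1x1 transposed conv as a plain channel mix, and all names and shapes are illustrative, not MegEngine API.

```python
import numpy as np

rng = np.random.default_rng(0)
n, ci, co, hw = 2, 3, 4, 5
x = rng.standard_normal((n, ci, hw, hw)).astype(np.float32)
weight = rng.standard_normal((ci, co, 1, 1)).astype(np.float32)  # (in, out, kh, kw)
bias = rng.standard_normal((1, co, 1, 1)).astype(np.float32)
gamma = rng.standard_normal((1, co, 1, 1)).astype(np.float32)
beta = rng.standard_normal((1, co, 1, 1)).astype(np.float32)
mean = rng.standard_normal((1, co, 1, 1)).astype(np.float32)
var = (rng.random((1, co, 1, 1)) + 0.5).astype(np.float32)
eps = 1e-5


def deconv1x1(inp, w, b):
    # a stride-1, 1x1 transposed conv is just a channel mix: (n,i,h,w) x (i,o) -> (n,o,h,w)
    return np.einsum("nihw,io->nohw", inp, w[:, :, 0, 0]) + b


scale = gamma / np.sqrt(var + eps)            # per-output-channel BN scale, shape (1, co, 1, 1)
w_fold = weight * scale.reshape(1, -1, 1, 1)  # scale axis 1, the out-channel axis of the kernel
b_fold = beta + (bias - mean) * scale

y_ref = (deconv1x1(x, weight, bias) - mean) * scale + beta  # conv_transpose2d followed by eval-mode BN
y_fold = deconv1x1(x, w_fold, b_fold)                       # folded weight/bias only
np.testing.assert_allclose(y_ref, y_fold, rtol=1e-5, atol=1e-5)
```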
@@ -32,15 +32,21 @@ _MAP_TO_FUSED_MODULE = {
 def fold_weight_bias(
     weight, bias, gamma, beta, bn_mean, bn_var, eps=1e-5, transpose=False
 ):
-    shape = (1, -1, 1, 1)
+    shape = (-1, 1, 1, 1)
     if transpose:
-        shape = (-1, 1, 1, 1)
+        shape = (1, -1, 1, 1)
     kernel_shape = weight.shape
     if len(kernel_shape) == 5:
-        groups, num_features = kernel_shape[0], kernel_shape[1]
+        if transpose:
+            groups, num_features = kernel_shape[0], kernel_shape[2]
+        else:
+            groups, num_features = kernel_shape[0], kernel_shape[1]
     else:
-        groups, num_features = 1, kernel_shape[0]
+        if transpose:
+            groups, num_features = 1, kernel_shape[1]
+        else:
+            groups, num_features = 1, kernel_shape[0]
     out_channels = groups * num_features
     if gamma is None:
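The 5-D branch relies on how grouped kernels are stored. A quick shape check, assuming MegEngine's grouped layouts of (groups, out/groups, in/groups, kh, kw) for Conv2d and (groups, in/groups, out/groups, kh, kw) for ConvTranspose2d; this is why the transposed case reads the per-group output-channel count from kernel_shape[2]:

```python
from megengine.module import Conv2d, ConvTranspose2d

conv = Conv2d(32, 64, 3, groups=4)
deconv = ConvTranspose2d(32, 64, 3, groups=4)

# Expected under the layout assumption above:
#   conv.weight.shape   == (4, 16, 8, 3, 3)  -> out_channels = 4 * 16 = kernel_shape[0] * kernel_shape[1]
#   deconv.weight.shape == (4, 8, 16, 3, 3)  -> out_channels = 4 * 16 = kernel_shape[0] * kernel_shape[2]
print(conv.weight.shape)
print(deconv.weight.shape)
```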
@@ -93,12 +99,37 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
         compute_mode=conv.compute_mode,
         name=conv.name,
     )
-    new_conv = module if bn is None or not conv.training else module.conv
+    if isinstance(conv, ConvTranspose2d):
+        module.output_padding = conv.output_padding
+        new_conv = (
+            module if bn is None or not conv.training else module.conv_transpose2d
+        )
+    else:
+        new_conv = module if bn is None or not conv.training else module.conv
     weight, bias = conv.weight, conv.bias
     if not conv.training and bn is not None:
-        weight, bias = fold_weight_bias(
-            weight, bias, bn.weight, bn.bias, bn.running_mean, bn.running_var, bn.eps,
-        )
+        if isinstance(conv, ConvTranspose2d):
+            weight, bias = fold_weight_bias(
+                weight,
+                bias,
+                bn.weight,
+                bn.bias,
+                bn.running_mean,
+                bn.running_var,
+                bn.eps,
+                transpose=True,
+            )
+        else:
+            weight, bias = fold_weight_bias(
+                weight,
+                bias,
+                bn.weight,
+                bn.bias,
+                bn.running_mean,
+                bn.running_var,
+                bn.eps,
+            )
     new_conv.weight = Parameter(weight)
     if bias is not None:
         new_conv.bias = Parameter(bias)
@@ -106,55 +137,3 @@ def fuse_conv_bn_relu_module(conv: Conv2d, bn: BatchNorm2d, relu: ReLU):
         module.bn = deepcopy(bn)
     new_conv.training = conv.training
     return module
-
-
-def fuse_conv_transpose2d_bn_relu_module(
-    conv_transpose2d: ConvTranspose2d, bn: BatchNorm2d, relu: ReLU
-):
-    module_key = tuple([type(m) for m in [conv_transpose2d, bn, relu] if m])
-    if bn:
-        assert (
-            conv_transpose2d.training == bn.training
-        ), "ConvTranspose2d and BN both must be in the same mode (train or eval)."
-        assert (
-            bn.num_features == conv_transpose2d.out_channels
-        ), "Output channel of ConvTranspose2d must match num_features of BatchNorm2d"
-    module_key = module_key + (conv_transpose2d.training,)
-    module = _MAP_TO_FUSED_MODULE[module_key](
-        in_channels=conv_transpose2d.in_channels,
-        out_channels=conv_transpose2d.out_channels,
-        kernel_size=conv_transpose2d.kernel_size,
-        stride=conv_transpose2d.stride,
-        padding=conv_transpose2d.padding,
-        output_padding=conv_transpose2d.output_padding,
-        dilation=conv_transpose2d.dilation,
-        groups=conv_transpose2d.groups,
-        bias=conv_transpose2d.bias is not None,
-        conv_mode=conv_transpose2d.conv_mode,
-        compute_mode=conv_transpose2d.compute_mode,
-        name=conv_transpose2d.name,
-    )
-    new_conv_transpose2d = (
-        module
-        if bn is None or not conv_transpose2d.training
-        else module.conv_transpose2d
-    )
-    weight, bias = conv_transpose2d.weight, conv_transpose2d.bias
-    if not conv_transpose2d.training and bn is not None:
-        weight, bias = fold_weight_bias(
-            weight,
-            bias,
-            bn.weight,
-            bn.bias,
-            bn.running_mean,
-            bn.running_var,
-            bn.eps,
-            transpose=False,
-        )
-    new_conv_transpose2d.weight = Parameter(weight)
-    if bias is not None:
-        new_conv_transpose2d.bias = Parameter(bias)
-    if bn is not None and conv_transpose2d.training:
-        module.bn = deepcopy(bn)
-    new_conv_transpose2d.training = conv_transpose2d.training
-    return module
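With the dedicated fuse_conv_transpose2d_bn_relu_module removed, fuse_conv_bn_relu_module becomes the single entry point for both module types. A hedged usage sketch, mirroring the new tests further down rather than code taken from this patch:

```python
from megengine.module import BatchNorm2d, Conv2d, ConvTranspose2d, ReLU
from megengine.utils.bn_fusion import fuse_conv_bn_relu_module

# Both calls go through the same code path; the ConvTranspose2d case now
# folds BN with transpose=True internally.
fused_conv = fuse_conv_bn_relu_module(Conv2d(32, 64, 3), BatchNorm2d(64), ReLU())
fused_deconv = fuse_conv_bn_relu_module(ConvTranspose2d(32, 64, 3), BatchNorm2d(64), ReLU())
```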
@@ -34,35 +34,49 @@ def test_qat_convbn2d():
     in_channels = 32
     out_channels = 64
     kernel_size = 3
 
+    class TestNet(Module):
+        def __init__(self, groups, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.conv_bn = ConvBn2d(
+                in_channels, out_channels, kernel_size, groups=groups, bias=bias,
+            )
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.conv_bn(out)
+            out = self.dequant(out)
+            return out
+
+    inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
     for groups, bias in product([1, 4], [True, False]):
-        module = ConvBn2d(
-            in_channels, out_channels, kernel_size, groups=groups, bias=bias
-        )
-        M.init.normal_(module.bn.weight)
-        M.init.normal_(module.bn.bias)
-        module.train()
-        qat_module = quantize_qat(module, inplace=False)
-        disable_fake_quant(qat_module)
-        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
-        normal_outputs = module(inputs)
-        qat_outputs = qat_module(inputs)
+        net = TestNet(groups, bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(
-            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-4,
         )
         np.testing.assert_allclose(
-            module.bn.running_mean.numpy(),
-            qat_module.bn.running_mean.numpy(),
+            net.conv_bn.bn.running_mean.numpy(),
+            qat_net.conv_bn.bn.running_mean.numpy(),
             atol=5e-8,
         )
         np.testing.assert_allclose(
-            module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7,
+            net.conv_bn.bn.running_var.numpy(),
+            qat_net.conv_bn.bn.running_var.numpy(),
+            atol=5e-7,
         )
-        module.eval()
-        normal_outputs = module(inputs)
-        qat_module.eval()
-        qat_outputs = qat_module(inputs)
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(
-            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-4,
         )
@@ -70,40 +84,44 @@ def test_qat_convtransposebn2d():
     in_channels = 32
     out_channels = 64
     kernel_size = 3
 
+    class TestNet(Module):
+        def __init__(self, groups, bias):
+            super().__init__()
+            self.quant = QuantStub()
+            self.dequant = DequantStub()
+            self.conv_transpose_bn = ConvTransposeBn2d(
+                in_channels, out_channels, kernel_size, groups=groups, bias=bias,
+            )
+
+        def forward(self, inp):
+            out = self.quant(inp)
+            out = self.conv_transpose_bn(out)
+            out = self.dequant(out)
+            return out
+
     for groups, bias in product([1, 4], [True, False]):
-        module = ConvTransposeBn2d(
-            in_channels=in_channels,
-            out_channels=out_channels,
-            kernel_size=kernel_size,
-            output_padding=0,
-            groups=groups,
-            bias=bias,
-        )
-        M.init.normal_(module.bn.weight)
-        M.init.normal_(module.bn.bias)
-        module.train()
-        qat_module = quantize_qat(module, inplace=False)
-        disable_fake_quant(qat_module)
+        net = TestNet(groups, bias)
+        net.train()
+        qat_net = quantize_qat(net, inplace=False)
+        disable_fake_quant(qat_net)
         inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
-        normal_outputs = module(inputs)
-        qat_outputs = qat_module(inputs)
-        np.testing.assert_allclose(
-            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
-        )
+        normal_outputs = net(inputs)
+        qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(
-            module.bn.running_mean.numpy(),
-            qat_module.bn.running_mean.numpy(),
-            atol=5e-8,
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-5,
         )
         np.testing.assert_allclose(
-            module.bn.running_var.numpy(), qat_module.bn.running_var.numpy(), atol=5e-7,
+            net.conv_transpose_bn.bn.running_var.numpy(),
+            qat_net.conv_transpose_bn.bn.running_var.numpy(),
+            atol=5e-7,
         )
-        module.eval()
-        normal_outputs = module(inputs)
-        qat_module.eval()
-        qat_outputs = qat_module(inputs)
+        net.eval()
+        normal_outputs = net(inputs)
+        qat_net.eval()
+        qat_outputs = qat_net(inputs)
         np.testing.assert_allclose(
-            normal_outputs.numpy(), qat_outputs.numpy(), atol=5e-6
+            normal_outputs.numpy(), qat_outputs.numpy(), atol=1e-5,
         )
@@ -3,6 +3,15 @@ import pytest
 from megengine import Parameter, Tensor
 from megengine import module as Float
+from megengine.functional import ones, zeros
+from megengine.module import (
+    BatchNorm2d,
+    Conv2d,
+    ConvBn2d,
+    ConvTranspose2d,
+    ConvTransposeBn2d,
+    ReLU,
+)
 from megengine.module import qat as QAT
 from megengine.module import quantized as Q
 from megengine.quantization import (
@@ -24,6 +33,7 @@ from megengine.quantization.quantize import (
     quantize_qat,
     reset_qconfig,
 )
+from megengine.utils.bn_fusion import fuse_conv_bn_relu_module
 
 
 class FloatNet(Float.Module):
@@ -291,3 +301,85 @@ def test_convert_with_custom_mapping():
     net = Net()
     qat_net = quantize_qat(net, inplace=False, mapping={FloatExample: QATExample})
     assert isinstance(qat_net.example, QATExample)
+
+
+def test_ConvBn2d_fold_weight_bias():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+
+    conv = Conv2d(in_channels, out_channels, kernel_size)
+    bn = BatchNorm2d(out_channels)
+    relu = ReLU()
+
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
+    bn.eval()
+    fused_conv.eval()
+    inputs = Tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+    expected_result = relu(bn(conv(inputs)))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    conv.eval()
+    bn.eval()
+    relu.eval()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
+    fused_conv.eval()
+    expected_result = relu(conv(inputs))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    conv.train()
+    bn.train()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, None)
+    fused_conv.train()
+    expected_result = bn(conv(inputs))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+
+def test_ConvTransposeBn2d_fold_weight_bias():
+    in_channels = 32
+    out_channels = 64
+    kernel_size = 3
+
+    conv = ConvTranspose2d(in_channels, out_channels, kernel_size)
+    bn = BatchNorm2d(out_channels)
+    relu = ReLU()
+
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
+    bn.eval()
+    fused_conv.eval()
+    inputs = Tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))
+    expected_result = relu(bn(conv(inputs)))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    conv.eval()
+    bn.eval()
+    relu.eval()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, relu)
+    fused_conv.eval()
+    expected_result = relu(conv(inputs))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )
+
+    conv.train()
+    bn.train()
+    fused_conv = fuse_conv_bn_relu_module(conv, bn, None)
+    fused_conv.train()
+    expected_result = bn(conv(inputs))
+    actual_result = fused_conv(inputs)
+    np.testing.assert_allclose(
+        expected_result.numpy(), actual_result.numpy(), atol=1e-4
+    )