import copy
from itertools import product

import numpy as np

from megengine import tensor
from megengine.module import ConvBn2d
from megengine.quantization.quantize import disable_fake_quant, quantize_qat
from megengine.test import assertTensorClose


def test_convbn2d():
    """Compare a float ``ConvBn2d`` with its ``quantize_qat`` counterpart.

    For every combination of ``groups`` in ``(1, 4)`` and ``bias`` in
    ``(True, False)``, a ``ConvBn2d`` is built and passed through
    ``quantize_qat(..., inplace=False)``; fake quantization is then
    disabled on the result.  Both modules are fed the same random input and
    the test asserts that:

    * the outputs in training mode agree within ``5e-6``;
    * the BatchNorm ``running_mean`` / ``running_var`` agree within
      ``5e-8`` / ``5e-7`` after that training-mode forward pass;
    * the outputs in eval mode agree within ``5e-6``.
    """
    in_channels, out_channels, kernel_size = 32, 64, 3
    input_shape = (4, in_channels, 32, 32)

    # BN running statistics to compare, paired with their allowed error.
    bn_stat_tolerances = (
        ("running_mean", 5e-8),
        ("running_var", 5e-7),
    )

    for groups, bias in product([1, 4], [True, False]):
        float_net = ConvBn2d(
            in_channels, out_channels, kernel_size, groups=groups, bias=bias
        )
        float_net.train()

        qat_net = quantize_qat(float_net, inplace=False)
        disable_fake_quant(qat_net)

        data = np.random.randn(*input_shape).astype(np.float32)
        inputs = tensor(data)

        # Training mode: the float module runs first, then the QAT module.
        expected = float_net(inputs)
        actual = qat_net(inputs)
        assertTensorClose(expected, actual, max_err=5e-6)

        for stat_name, tolerance in bn_stat_tolerances:
            assertTensorClose(
                getattr(float_net.bn, stat_name),
                getattr(qat_net.bn, stat_name),
                max_err=tolerance,
            )

        # Eval mode: each module is switched right before its forward pass.
        float_net.eval()
        expected = float_net(inputs)
        qat_net.eval()
        actual = qat_net(inputs)
        assertTensorClose(expected, actual, max_err=5e-6)