import copy
from itertools import product

import numpy as np

from megengine import tensor
from megengine.module import ConvBn2d
from megengine.quantization import quantize_qat
from megengine.quantization.quantize import disable_fake_quant
from megengine.test import assertTensorClose


def test_convbn2d():
    """Compare ``ConvBn2d.forward`` against ``forward_qat`` with fake quant off.

    For every combination of ``groups`` in (1, 4) and ``bias`` in
    (True, False), a fresh ``ConvBn2d`` is built, passed through
    ``quantize_qat`` and deep-copied. Fake quantization is disabled on the
    copy, then the test asserts that:

    * in train mode, ``forward`` on the original and ``forward_qat`` on the
      copy give close outputs and leave close BN running mean/variance;
    * in eval mode, the two outputs are close as well.
    """
    in_channels = 32
    out_channels = 64
    kernel_size = 3

    for groups, bias in product([1, 4], [True, False]):
        # Build the module per configuration so that ``groups`` and ``bias``
        # are actually exercised (both channel counts are divisible by 4).
        module = ConvBn2d(
            in_channels, out_channels, kernel_size, groups=groups, bias=bias
        )
        quantize_qat(module)

        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))

        module.train()
        # Deep copy so both modules start from identical parameters and
        # BN statistics before either one sees the input.
        qat_module = copy.deepcopy(module)
        # NOTE(review): with fake quant disabled, forward_qat is expected to
        # be numerically equivalent to the float forward pass.
        disable_fake_quant(qat_module)

        # Train mode: outputs and updated running statistics must agree.
        normal_outputs = module.forward(inputs)
        qat_outputs = qat_module.forward_qat(inputs)
        assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)
        assertTensorClose(
            module.bn.running_mean, qat_module.bn.running_mean, max_err=5e-8
        )
        assertTensorClose(
            module.bn.running_var, qat_module.bn.running_var, max_err=5e-7
        )

        # Eval mode: outputs must agree as well.
        module.eval()
        normal_outputs = module.forward(inputs)
        qat_module.eval()
        qat_outputs = qat_module.forward_qat(inputs)
        assertTensorClose(normal_outputs, qat_outputs, max_err=5e-6)