|
- import io
- from itertools import product
-
- import numpy as np
- import pytest
-
- import megengine.utils.comp_graph_tools as cgtools
- from megengine import jit, tensor
- from megengine.device import get_device_count
- from megengine.functional import expand_dims
- from megengine.module import (
- BatchMatMulActivation,
- Conv2d,
- ConvBn2d,
- ConvRelu2d,
- ConvTranspose2d,
- DequantStub,
- Module,
- QuantStub,
- )
- from megengine.quantization.quantize import (
- disable_fake_quant,
- enable_fake_quant,
- quantize,
- quantize_qat,
- )
-
-
def test_qat_convbn2d():
    """Compare a float ``ConvBn2d`` with its QAT conversion, fake quant disabled.

    For each combination of ``groups`` in (1, 4) and ``bias`` in (True, False)
    the test asserts that both modules give matching outputs in training mode
    and in evaluation mode, and that the batch-norm running statistics agree
    after the training-mode forward pass.
    """
    in_channels, out_channels, kernel_size = 32, 64, 3

    for groups, bias in product([1, 4], [True, False]):
        float_module = ConvBn2d(
            in_channels, out_channels, kernel_size, groups=groups, bias=bias
        )
        float_module.train()
        # inplace=False so the float module is kept around as the reference.
        qat_module = quantize_qat(float_module, inplace=False)
        disable_fake_quant(qat_module)

        inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))

        # Training mode: outputs must agree.
        float_out = float_module(inputs)
        qat_out = qat_module(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy(), atol=5e-6)

        # The batch-norm running statistics must agree too, each with its own
        # absolute tolerance.
        for stat_name, tolerance in (("running_mean", 5e-8), ("running_var", 5e-7)):
            np.testing.assert_allclose(
                getattr(float_module.bn, stat_name).numpy(),
                getattr(qat_module.bn, stat_name).numpy(),
                atol=tolerance,
            )

        # Evaluation mode: outputs must still agree.
        float_module.eval()
        float_out = float_module(inputs)
        qat_module.eval()
        qat_out = qat_module(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy(), atol=5e-6)
-
-
def test_qat_conv():
    """Compare a Conv2d + ConvRelu2d network with its QAT conversion.

    The network is wrapped in ``QuantStub`` / ``DequantStub``. With fake quant
    disabled, the float network and the QAT network must produce matching
    outputs in both training and evaluation mode, for every combination of
    ``groups`` in (1, 4) and ``bias`` in (True, False).
    """
    in_channels, out_channels, kernel_size = 32, 64, 3

    class TestNet(Module):
        """quant -> conv -> conv_relu -> dequant."""

        def __init__(self, groups, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            conv_kwargs = dict(groups=groups, bias=bias)
            self.conv = Conv2d(in_channels, out_channels, kernel_size, **conv_kwargs)
            # The second conv maps back from out_channels to in_channels.
            self.conv_relu = ConvRelu2d(
                out_channels, in_channels, kernel_size, **conv_kwargs
            )

        def forward(self, inp):
            return self.dequant(self.conv_relu(self.conv(self.quant(inp))))

    # One random input is shared by every configuration.
    inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))

    for groups, bias in product([1, 4], [True, False]):
        float_net = TestNet(groups, bias)
        float_net.train()
        qat_net = quantize_qat(float_net, inplace=False)
        disable_fake_quant(qat_net)

        # Training mode.
        float_out = float_net(inputs)
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())

        # Evaluation mode.
        float_net.eval()
        float_out = float_net(inputs)
        qat_net.eval()
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())
-
-
@pytest.mark.skipif(get_device_count("gpu") > 0, reason="no int8 algorithm on cuda")
def test_qat_batchmatmul_activation():
    """Compare a ``BatchMatMulActivation`` network with its QAT conversion.

    With fake quant disabled, the float network and the QAT network must
    produce matching outputs in training and evaluation mode, both with and
    without bias. The test is skipped when a GPU device is reported.
    """
    batch, in_features, out_features = 4, 8, 4

    class TestNet(Module):
        """quant -> batch_mm -> dequant."""

        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.batch_mm = BatchMatMulActivation(
                batch, in_features, out_features, bias=bias
            )

        def forward(self, inp):
            return self.dequant(self.batch_mm(self.quant(inp)))

    # One random input is shared by both bias settings.
    data = np.random.randn(batch, in_features, out_features).astype(np.float32)
    inputs = tensor(data)

    for bias in (True, False):
        float_net = TestNet(bias)
        float_net.train()
        qat_net = quantize_qat(float_net, inplace=False)
        disable_fake_quant(qat_net)

        # Training mode.
        float_out = float_net(inputs)
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())

        # Evaluation mode.
        float_net.eval()
        float_out = float_net(inputs)
        qat_net.eval()
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())
-
-
@pytest.mark.skip(reason="FIXME: abnormal exit")
def test_quantize_batchmatmul_activation():
    """Check float, QAT, quantized and dumped ``BatchMatMulActivation`` networks.

    For each bias setting the test runs these comparisons in order:

    1. float network vs. QAT network with fake quant disabled, in training
       and evaluation mode;
    2. QAT network with fake quant re-enabled vs. the network returned by
       ``quantize``;
    3. the quantized network vs. the graph obtained by tracing it, dumping it
       to an in-memory buffer and running it through ``GraphInference``.

    Currently skipped unconditionally (see the skip reason).
    """
    batch, in_features, out_features = 4, 8, 4

    class TestNet(Module):
        """quant -> batch_mm -> expand_dims(-1) -> dequant."""

        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.batch_mm = BatchMatMulActivation(
                batch, in_features, out_features, bias=bias
            )

        def forward(self, inp):
            hidden = self.batch_mm(self.quant(inp))
            hidden = expand_dims(hidden, -1)
            return self.dequant(hidden)

    # One random input is shared by both bias settings.
    data = np.random.randn(batch, in_features, out_features).astype(np.float32)
    inputs = tensor(data)

    for bias in (True, False):
        float_net = TestNet(bias)
        float_net.train()
        qat_net = quantize_qat(float_net, inplace=False)
        disable_fake_quant(qat_net)

        # Step 1a: float vs. QAT, training mode.
        float_out = float_net(inputs)
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())

        # Step 1b: float vs. QAT, evaluation mode.
        float_net.eval()
        float_out = float_net(inputs)
        qat_net.eval()
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())

        # Step 2: QAT with fake quant enabled vs. the quantized network.
        enable_fake_quant(qat_net)
        qat_out = qat_net(inputs)
        qnet = quantize(qat_net, inplace=False)
        qnet.eval()
        quantized_out = qnet(inputs)
        np.testing.assert_allclose(qat_out.numpy(), quantized_out.numpy(), atol=1e-6)

        # Step 3: trace the quantized network, dump it, reload it for inference
        # and compare against the eager quantized output.
        @jit.trace(capture_as_const=True)
        def f(x):
            qnet.eval()
            return qnet(x)

        f(inputs)
        buffer = io.BytesIO()
        f.dump(buffer, enable_nchw4=True)
        buffer.seek(0)  # rewind so GraphInference reads from the start
        infer_cg = cgtools.GraphInference(buffer)[0]
        run_results = infer_cg.run(inputs.numpy())
        dumped_out = list(run_results.values())[0]
        np.testing.assert_allclose(quantized_out.numpy(), dumped_out, atol=1e-6)
-
-
def test_qat_conv_transpose2d():
    """Compare a ``ConvTranspose2d`` network with its QAT conversion.

    With fake quant disabled, the float network and the QAT network must
    produce matching outputs in training and evaluation mode, both with and
    without bias.
    """
    in_channels, out_channels, kernel_size = 32, 64, 3

    class TestNet(Module):
        """quant -> conv (transposed) -> dequant."""

        def __init__(self, bias):
            super().__init__()
            self.quant = QuantStub()
            self.dequant = DequantStub()
            self.conv = ConvTranspose2d(
                in_channels, out_channels, kernel_size, bias=bias
            )

        def forward(self, inp):
            return self.dequant(self.conv(self.quant(inp)))

    # One random input is shared by both bias settings.
    inputs = tensor(np.random.randn(4, in_channels, 32, 32).astype(np.float32))

    for bias in [True, False]:
        float_net = TestNet(bias)
        float_net.train()
        qat_net = quantize_qat(float_net, inplace=False)
        disable_fake_quant(qat_net)

        # Training mode.
        float_out = float_net(inputs)
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())

        # Evaluation mode.
        float_net.eval()
        float_out = float_net(inputs)
        qat_net.eval()
        qat_out = qat_net(inputs)
        np.testing.assert_allclose(float_out.numpy(), qat_out.numpy())
|