Browse Source

test(mge/quantization): classmethod `from_float_module` of qat module

GitOrigin-RevId: 95c3d45f83
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
0b4918b25e
1 changed files with 80 additions and 29 deletions
  1. +80
    -29
      imperative/python/test/unit/quantization/test_module.py

+ 80
- 29
imperative/python/test/unit/quantization/test_module.py View File

@@ -8,7 +8,11 @@ import megengine.module.qat as QAT
import megengine.module.quantized as Q
from megengine.core.tensor import dtype
from megengine.quantization import min_max_fakequant_qconfig
from megengine.quantization.quantize import disable_observer, propagate_qconfig
from megengine.quantization.quantize import (
disable_fake_quant,
disable_observer,
propagate_qconfig,
)

"""
Calculate testing scales based on ``min_max_fakequant_qconfig``
@@ -47,6 +51,12 @@ def init_qat_net(net):
def test_quant_stub():
normal_net = Float.QuantStub()
normal_net.eval()

qat_from_float = QAT.QuantStub.from_float_module(normal_net)
qat_from_float.eval()
disable_observer(qat_from_float)
disable_fake_quant(qat_from_float)

qat_net = QAT.QuantStub()
qat_net.eval()
disable_observer(qat_net)
@@ -59,16 +69,25 @@ def test_quant_stub():

x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))

normal_out = fake_quant(normal_net(x), act_scale)
qat_out = qat_net(x)
q_out = q_net(x).numpy() * act_scale
np.testing.assert_allclose(qat_out, normal_out)
np.testing.assert_allclose(q_out, normal_out.numpy())
normal = normal_net(x)
qat_without_fakequant = qat_from_float(x)
fake_quant_normal = fake_quant(normal_net(x), act_scale)
qat = qat_net(x)
q = q_net(x).numpy() * act_scale
np.testing.assert_allclose(qat_without_fakequant, normal)
np.testing.assert_allclose(qat, fake_quant_normal)
np.testing.assert_allclose(q, fake_quant_normal.numpy())


def test_dequant_stub():
normal_net = Float.DequantStub()
normal_net.eval()

qat_from_float = QAT.DequantStub.from_float_module(normal_net)
qat_from_float.eval()
disable_fake_quant(qat_from_float)
disable_observer(qat_from_float)

qat_net = QAT.DequantStub()
qat_net.eval()
disable_observer(qat_net)
@@ -83,17 +102,26 @@ def test_dequant_stub():
x = fake_quant(x, inp_scale)
x.q_dict["scale"] = inp_scale

normal_out = normal_net(x)
qat_out = qat_net(x)
q_out = q_net(quant(x, inp_scale)).numpy()
np.testing.assert_allclose(qat_out, normal_out)
np.testing.assert_allclose(q_out, normal_out.numpy())
normal = normal_net(x)
qat_without_fakequant = qat_from_float(x)
fake_quant_normal = normal_net(x)
qat = qat_net(x)
q = q_net(quant(x, inp_scale)).numpy()
np.testing.assert_allclose(qat_without_fakequant, normal)
np.testing.assert_allclose(qat, fake_quant_normal)
np.testing.assert_allclose(q, fake_quant_normal.numpy())


@pytest.mark.parametrize("kind", ["COS", "RELU", "ADD", "MUL", "FUSE_ADD_RELU"])
def test_elemwise(kind):
normal_net = Float.Elemwise(kind)
normal_net.eval()

qat_from_float = QAT.Elemwise.from_float_module(normal_net)
qat_from_float.eval()
disable_observer(qat_from_float)
disable_fake_quant(qat_from_float)

qat_net = QAT.Elemwise(kind)
qat_net.eval()
disable_observer(qat_net)
@@ -117,16 +145,22 @@ def test_elemwise(kind):
x1_int8 = quant(x1, x1_scale)
x2_int8 = quant(x2, x2_scale)

# test correctness of `Float`, `QAT` and `Quantized`
if kind in ("ADD", "MUL", "FUSE_ADD_RELU"):
normal_out = fake_quant(normal_net(x1, x2), act_scale)
qat_out = qat_net(x1, x2)
q_out = q_net(x1_int8, x2_int8).numpy() * act_scale
normal = normal_net(x1, x2)
qat_without_fakequant = qat_from_float(x1, x2)
fake_quant_normal = fake_quant(normal_net(x1, x2), act_scale)
qat = qat_net(x1, x2)
q = q_net(x1_int8, x2_int8).numpy() * act_scale
else:
normal_out = fake_quant(normal_net(x1), act_scale)
qat_out = qat_net(x1)
q_out = q_net(x1_int8).numpy() * act_scale
np.testing.assert_allclose(qat_out, normal_out)
np.testing.assert_allclose(q_out, normal_out.numpy())
normal = normal_net(x1)
qat_without_fakequant = qat_from_float(x1)
fake_quant_normal = fake_quant(normal_net(x1), act_scale)
qat = qat_net(x1)
q = q_net(x1_int8).numpy() * act_scale
np.testing.assert_allclose(qat_without_fakequant, normal)
np.testing.assert_allclose(qat, fake_quant_normal)
np.testing.assert_allclose(q, fake_quant_normal.numpy())


def test_linear():
@@ -153,20 +187,29 @@ def test_linear():
qat_net.weight.set_value(weight)
qat_net.bias.set_value(bias)

qat_from_float = QAT.Linear.from_float_module(normal_net)
qat_from_float.eval()
disable_fake_quant(qat_from_float)
disable_observer(qat_from_float)

q_net = Q.Linear.from_qat_module(qat_net)
q_net.eval()

normal_out = fake_quant(normal_net(x), act_scale)
qat_out = qat_net(x)
q_out = q_net(x_int8).numpy() * act_scale
np.testing.assert_allclose(qat_out, normal_out)
np.testing.assert_allclose(q_out, normal_out.numpy())
normal = normal_net(x)
qat_without_fakequant = qat_from_float(x)
fake_quant_normal = fake_quant(normal_net(x), act_scale)
qat = qat_net(x)
q = q_net(x_int8).numpy() * act_scale
np.testing.assert_allclose(qat_without_fakequant, normal)
np.testing.assert_allclose(qat, fake_quant_normal)
np.testing.assert_allclose(q, fake_quant_normal.numpy())


@pytest.mark.parametrize("module", ["Conv2d", "ConvBn2d", "ConvBnRelu2d"])
def test_conv(module):
normal_net = getattr(Float, module)(3, 3, 3, 1, 1, 1, bias=True)
normal_net.eval()

qat_net = getattr(QAT, module)(3, 3, 3, 1, 1, 1, bias=True)
qat_net.eval()
disable_observer(qat_net)
@@ -193,11 +236,19 @@ def test_conv(module):
qat_net.weight.set_value(weight)
qat_net.bias.set_value(bias)

qat_from_float = getattr(QAT, module).from_float_module(normal_net)
qat_from_float.eval()
disable_observer(qat_from_float)
disable_fake_quant(qat_from_float)

q_net = getattr(Q, module).from_qat_module(qat_net)
q_net.eval()

normal_out = fake_quant(normal_net(x), act_scale)
qat_out = qat_net(x)
q_out = q_net(x_int8).numpy() * act_scale
np.testing.assert_allclose(qat_out, normal_out)
np.testing.assert_allclose(q_out, normal_out.numpy())
normal = normal_net(x)
qat_without_fakequant = qat_from_float(x)
fake_quant_normal = fake_quant(normal_net(x), act_scale)
qat = qat_net(x)
q = q_net(x_int8).numpy() * act_scale
np.testing.assert_allclose(qat_without_fakequant, normal, atol=1e-6)
np.testing.assert_allclose(qat, fake_quant_normal)
np.testing.assert_allclose(q, fake_quant_normal.numpy())

Loading…
Cancel
Save