|
|
@@ -8,7 +8,11 @@ import megengine.module.qat as QAT |
|
|
|
import megengine.module.quantized as Q |
|
|
|
from megengine.core.tensor import dtype |
|
|
|
from megengine.quantization import min_max_fakequant_qconfig |
|
|
|
from megengine.quantization.quantize import disable_observer, propagate_qconfig |
|
|
|
from megengine.quantization.quantize import ( |
|
|
|
disable_fake_quant, |
|
|
|
disable_observer, |
|
|
|
propagate_qconfig, |
|
|
|
) |
|
|
|
|
|
|
|
""" |
|
|
|
Calculate testing scales based on ``min_max_fakequant_qconfig`` |
|
|
@@ -47,6 +51,12 @@ def init_qat_net(net): |
|
|
|
def test_quant_stub(): |
|
|
|
normal_net = Float.QuantStub() |
|
|
|
normal_net.eval() |
|
|
|
|
|
|
|
qat_from_float = QAT.QuantStub.from_float_module(normal_net) |
|
|
|
qat_from_float.eval() |
|
|
|
disable_observer(qat_from_float) |
|
|
|
disable_fake_quant(qat_from_float) |
|
|
|
|
|
|
|
qat_net = QAT.QuantStub() |
|
|
|
qat_net.eval() |
|
|
|
disable_observer(qat_net) |
|
|
@@ -59,16 +69,25 @@ def test_quant_stub(): |
|
|
|
|
|
|
|
x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) |
|
|
|
|
|
|
|
normal_out = fake_quant(normal_net(x), act_scale) |
|
|
|
qat_out = qat_net(x) |
|
|
|
q_out = q_net(x).numpy() * act_scale |
|
|
|
np.testing.assert_allclose(qat_out, normal_out) |
|
|
|
np.testing.assert_allclose(q_out, normal_out.numpy()) |
|
|
|
normal = normal_net(x) |
|
|
|
qat_without_fakequant = qat_from_float(x) |
|
|
|
fake_quant_normal = fake_quant(normal_net(x), act_scale) |
|
|
|
qat = qat_net(x) |
|
|
|
q = q_net(x).numpy() * act_scale |
|
|
|
np.testing.assert_allclose(qat_without_fakequant, normal) |
|
|
|
np.testing.assert_allclose(qat, fake_quant_normal) |
|
|
|
np.testing.assert_allclose(q, fake_quant_normal.numpy()) |
|
|
|
|
|
|
|
|
|
|
|
def test_dequant_stub(): |
|
|
|
normal_net = Float.DequantStub() |
|
|
|
normal_net.eval() |
|
|
|
|
|
|
|
qat_from_float = QAT.DequantStub.from_float_module(normal_net) |
|
|
|
qat_from_float.eval() |
|
|
|
disable_fake_quant(qat_from_float) |
|
|
|
disable_observer(qat_from_float) |
|
|
|
|
|
|
|
qat_net = QAT.DequantStub() |
|
|
|
qat_net.eval() |
|
|
|
disable_observer(qat_net) |
|
|
@@ -83,17 +102,26 @@ def test_dequant_stub(): |
|
|
|
x = fake_quant(x, inp_scale) |
|
|
|
x.q_dict["scale"] = inp_scale |
|
|
|
|
|
|
|
normal_out = normal_net(x) |
|
|
|
qat_out = qat_net(x) |
|
|
|
q_out = q_net(quant(x, inp_scale)).numpy() |
|
|
|
np.testing.assert_allclose(qat_out, normal_out) |
|
|
|
np.testing.assert_allclose(q_out, normal_out.numpy()) |
|
|
|
normal = normal_net(x) |
|
|
|
qat_without_fakequant = qat_from_float(x) |
|
|
|
fake_quant_normal = normal_net(x) |
|
|
|
qat = qat_net(x) |
|
|
|
q = q_net(quant(x, inp_scale)).numpy() |
|
|
|
np.testing.assert_allclose(qat_without_fakequant, normal) |
|
|
|
np.testing.assert_allclose(qat, fake_quant_normal) |
|
|
|
np.testing.assert_allclose(q, fake_quant_normal.numpy()) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("kind", ["COS", "RELU", "ADD", "MUL", "FUSE_ADD_RELU"]) |
|
|
|
def test_elemwise(kind): |
|
|
|
normal_net = Float.Elemwise(kind) |
|
|
|
normal_net.eval() |
|
|
|
|
|
|
|
qat_from_float = QAT.Elemwise.from_float_module(normal_net) |
|
|
|
qat_from_float.eval() |
|
|
|
disable_observer(qat_from_float) |
|
|
|
disable_fake_quant(qat_from_float) |
|
|
|
|
|
|
|
qat_net = QAT.Elemwise(kind) |
|
|
|
qat_net.eval() |
|
|
|
disable_observer(qat_net) |
|
|
@@ -117,16 +145,22 @@ def test_elemwise(kind): |
|
|
|
x1_int8 = quant(x1, x1_scale) |
|
|
|
x2_int8 = quant(x2, x2_scale) |
|
|
|
|
|
|
|
# test correctness of `Float`, `QAT` and `Quantized` |
|
|
|
if kind in ("ADD", "MUL", "FUSE_ADD_RELU"): |
|
|
|
normal_out = fake_quant(normal_net(x1, x2), act_scale) |
|
|
|
qat_out = qat_net(x1, x2) |
|
|
|
q_out = q_net(x1_int8, x2_int8).numpy() * act_scale |
|
|
|
normal = normal_net(x1, x2) |
|
|
|
qat_without_fakequant = qat_from_float(x1, x2) |
|
|
|
fake_quant_normal = fake_quant(normal_net(x1, x2), act_scale) |
|
|
|
qat = qat_net(x1, x2) |
|
|
|
q = q_net(x1_int8, x2_int8).numpy() * act_scale |
|
|
|
else: |
|
|
|
normal_out = fake_quant(normal_net(x1), act_scale) |
|
|
|
qat_out = qat_net(x1) |
|
|
|
q_out = q_net(x1_int8).numpy() * act_scale |
|
|
|
np.testing.assert_allclose(qat_out, normal_out) |
|
|
|
np.testing.assert_allclose(q_out, normal_out.numpy()) |
|
|
|
normal = normal_net(x1) |
|
|
|
qat_without_fakequant = qat_from_float(x1) |
|
|
|
fake_quant_normal = fake_quant(normal_net(x1), act_scale) |
|
|
|
qat = qat_net(x1) |
|
|
|
q = q_net(x1_int8).numpy() * act_scale |
|
|
|
np.testing.assert_allclose(qat_without_fakequant, normal) |
|
|
|
np.testing.assert_allclose(qat, fake_quant_normal) |
|
|
|
np.testing.assert_allclose(q, fake_quant_normal.numpy()) |
|
|
|
|
|
|
|
|
|
|
|
def test_linear(): |
|
|
@@ -153,20 +187,29 @@ def test_linear(): |
|
|
|
qat_net.weight.set_value(weight) |
|
|
|
qat_net.bias.set_value(bias) |
|
|
|
|
|
|
|
qat_from_float = QAT.Linear.from_float_module(normal_net) |
|
|
|
qat_from_float.eval() |
|
|
|
disable_fake_quant(qat_from_float) |
|
|
|
disable_observer(qat_from_float) |
|
|
|
|
|
|
|
q_net = Q.Linear.from_qat_module(qat_net) |
|
|
|
q_net.eval() |
|
|
|
|
|
|
|
normal_out = fake_quant(normal_net(x), act_scale) |
|
|
|
qat_out = qat_net(x) |
|
|
|
q_out = q_net(x_int8).numpy() * act_scale |
|
|
|
np.testing.assert_allclose(qat_out, normal_out) |
|
|
|
np.testing.assert_allclose(q_out, normal_out.numpy()) |
|
|
|
normal = normal_net(x) |
|
|
|
qat_without_fakequant = qat_from_float(x) |
|
|
|
fake_quant_normal = fake_quant(normal_net(x), act_scale) |
|
|
|
qat = qat_net(x) |
|
|
|
q = q_net(x_int8).numpy() * act_scale |
|
|
|
np.testing.assert_allclose(qat_without_fakequant, normal) |
|
|
|
np.testing.assert_allclose(qat, fake_quant_normal) |
|
|
|
np.testing.assert_allclose(q, fake_quant_normal.numpy()) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.parametrize("module", ["Conv2d", "ConvBn2d", "ConvBnRelu2d"]) |
|
|
|
def test_conv(module): |
|
|
|
normal_net = getattr(Float, module)(3, 3, 3, 1, 1, 1, bias=True) |
|
|
|
normal_net.eval() |
|
|
|
|
|
|
|
qat_net = getattr(QAT, module)(3, 3, 3, 1, 1, 1, bias=True) |
|
|
|
qat_net.eval() |
|
|
|
disable_observer(qat_net) |
|
|
@@ -193,11 +236,19 @@ def test_conv(module): |
|
|
|
qat_net.weight.set_value(weight) |
|
|
|
qat_net.bias.set_value(bias) |
|
|
|
|
|
|
|
qat_from_float = getattr(QAT, module).from_float_module(normal_net) |
|
|
|
qat_from_float.eval() |
|
|
|
disable_observer(qat_from_float) |
|
|
|
disable_fake_quant(qat_from_float) |
|
|
|
|
|
|
|
q_net = getattr(Q, module).from_qat_module(qat_net) |
|
|
|
q_net.eval() |
|
|
|
|
|
|
|
normal_out = fake_quant(normal_net(x), act_scale) |
|
|
|
qat_out = qat_net(x) |
|
|
|
q_out = q_net(x_int8).numpy() * act_scale |
|
|
|
np.testing.assert_allclose(qat_out, normal_out) |
|
|
|
np.testing.assert_allclose(q_out, normal_out.numpy()) |
|
|
|
normal = normal_net(x) |
|
|
|
qat_without_fakequant = qat_from_float(x) |
|
|
|
fake_quant_normal = fake_quant(normal_net(x), act_scale) |
|
|
|
qat = qat_net(x) |
|
|
|
q = q_net(x_int8).numpy() * act_scale |
|
|
|
np.testing.assert_allclose(qat_without_fakequant, normal, atol=1e-6) |
|
|
|
np.testing.assert_allclose(qat, fake_quant_normal) |
|
|
|
np.testing.assert_allclose(q, fake_quant_normal.numpy()) |