|
|
@@ -1,3 +1,5 @@ |
|
|
|
from functools import partial |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
import pytest |
|
|
|
|
|
|
@@ -6,17 +8,21 @@ import megengine.functional as F |
|
|
|
import megengine.module as Float |
|
|
|
import megengine.module.qat as QAT |
|
|
|
import megengine.module.quantized as Q |
|
|
|
from megengine import Parameter, Tensor |
|
|
|
from megengine.core.tensor import dtype |
|
|
|
from megengine.quantization import min_max_fakequant_qconfig |
|
|
|
from megengine.quantization import FakeQuantize, MinMaxObserver, QConfig |
|
|
|
from megengine.quantization.quantize import ( |
|
|
|
disable_fake_quant, |
|
|
|
disable_observer, |
|
|
|
propagate_qconfig, |
|
|
|
) |
|
|
|
|
|
|
|
""" |
|
|
|
Calculate testing scales based on ``min_max_fakequant_qconfig`` |
|
|
|
""" |
|
|
|
# Test-local qconfig: every observer / fake-quant factory targets "qint8".
# The weight-side factories are built with narrow_range=True, the
# activation-side factories with narrow_range=False.
min_max_fakequant_qconfig = QConfig(
    weight_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=True),
    act_observer=partial(MinMaxObserver, dtype="qint8", narrow_range=False),
    weight_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=True),
    act_fake_quant=partial(FakeQuantize, dtype="qint8", narrow_range=False),
)
|
|
|
|
|
|
|
inp_scale = np.float32(np.random.rand() + 1) |
|
|
|
|
|
|
@@ -31,21 +37,26 @@ def quant(x, scale): |
|
|
|
return x.astype(inp_dtype) |
|
|
|
|
|
|
|
|
|
|
|
def fake_quant(x, scale, qmin, qmax):
    """Simulate quantization of ``x`` while staying in floating point.

    The input is divided by ``scale``, rounded with ``F.round``, clipped to
    ``[qmin, qmax]`` with ``F.clip`` and multiplied by ``scale`` again.

    Args:
        x: values to fake-quantize (whatever ``F.round`` / ``F.clip`` accept).
        scale: quantization step size.
        qmin: lower clip bound, expressed on the integer grid.
        qmax: upper clip bound, expressed on the integer grid.

    Returns:
        The fake-quantized values, back on the original scale.
    """
    x = x / scale
    x = F.round(x)
    # Clip only to the caller-supplied range: a fixed [-128, 127] clip here
    # would silently clamp callers that ask for a wider range.
    x = F.clip(x, qmin, qmax)
    x = x * scale
    return x
|
|
|
|
|
|
|
|
|
|
|
# Reference fake-quantizers, one per tensor role, each pinning the clip
# bounds passed to ``fake_quant``.
fake_quant_act = partial(fake_quant, qmin=-128, qmax=127)
fake_quant_weight = partial(fake_quant, qmin=-127, qmax=127)
# Bias bounds: [-(2 ** 31), 2 ** 31 - 1], written with shifts.
fake_quant_bias = partial(fake_quant, qmin=-(1 << 31), qmax=(1 << 31) - 1)
|
|
|
|
|
|
|
|
|
|
|
def init_qat_net(net):
    """Seed the observers of ``net`` with the module-level test ranges.

    When ``net.with_weight`` is truthy, the weight observer's ``min_val`` and
    ``max_val`` are overwritten with ``min_val[0]`` / ``max_val[0]``; when
    ``net.with_act`` is truthy, the activation observer's bounds are
    overwritten with ``min_val[1]`` / ``max_val[1]``.

    Args:
        net: a module exposing ``with_weight`` / ``with_act`` flags and the
            matching ``weight_observer`` / ``act_observer`` attributes.
    """
    # Each bound is written exactly once, via in-place ``[...]`` assignment;
    # the earlier ``set_value`` calls wrote the same values a second time.
    if net.with_weight:
        net.weight_observer.min_val[...] = Tensor(min_val[0])
        net.weight_observer.max_val[...] = Tensor(max_val[0])
    if net.with_act:
        net.act_observer.min_val[...] = Tensor(min_val[1])
        net.act_observer.max_val[...] = Tensor(max_val[1])
|
|
|
|
|
|
|
|
|
|
|
def test_quant_stub(): |
|
|
@@ -71,7 +82,7 @@ def test_quant_stub(): |
|
|
|
|
|
|
|
normal = normal_net(x) |
|
|
|
qat_without_fakequant = qat_from_float(x) |
|
|
|
fake_quant_normal = fake_quant(normal_net(x), act_scale) |
|
|
|
fake_quant_normal = fake_quant_act(normal_net(x), act_scale) |
|
|
|
qat = qat_net(x) |
|
|
|
q = q_net(x).numpy() * act_scale |
|
|
|
np.testing.assert_allclose(qat_without_fakequant, normal) |
|
|
@@ -99,7 +110,7 @@ def test_dequant_stub(): |
|
|
|
q_net.eval() |
|
|
|
|
|
|
|
x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) |
|
|
|
x = fake_quant(x, inp_scale) |
|
|
|
x = fake_quant_act(x, inp_scale) |
|
|
|
x.q_dict["scale"] = inp_scale |
|
|
|
|
|
|
|
normal = normal_net(x) |
|
|
@@ -134,12 +145,12 @@ def test_elemwise(kind): |
|
|
|
|
|
|
|
x1_scale = np.float32(np.random.rand() + 1) |
|
|
|
x1 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) |
|
|
|
x1 = fake_quant(x1, x1_scale) |
|
|
|
x1 = fake_quant_act(x1, x1_scale) |
|
|
|
x1.q_dict["scale"] = x1_scale |
|
|
|
|
|
|
|
x2_scale = np.float32(np.random.rand() + 1) |
|
|
|
x2 = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) |
|
|
|
x2 = fake_quant(x2, x2_scale) |
|
|
|
x2 = fake_quant_act(x2, x2_scale) |
|
|
|
x2.q_dict["scale"] = x2_scale |
|
|
|
|
|
|
|
x1_int8 = quant(x1, x1_scale) |
|
|
@@ -149,13 +160,13 @@ def test_elemwise(kind): |
|
|
|
if kind in ("ADD", "MUL", "FUSE_ADD_RELU"): |
|
|
|
normal = normal_net(x1, x2) |
|
|
|
qat_without_fakequant = qat_from_float(x1, x2) |
|
|
|
fake_quant_normal = fake_quant(normal_net(x1, x2), act_scale) |
|
|
|
fake_quant_normal = fake_quant_act(normal_net(x1, x2), act_scale) |
|
|
|
qat = qat_net(x1, x2) |
|
|
|
q = q_net(x1_int8, x2_int8).numpy() * act_scale |
|
|
|
else: |
|
|
|
normal = normal_net(x1) |
|
|
|
qat_without_fakequant = qat_from_float(x1) |
|
|
|
fake_quant_normal = fake_quant(normal_net(x1), act_scale) |
|
|
|
fake_quant_normal = fake_quant_act(normal_net(x1), act_scale) |
|
|
|
qat = qat_net(x1) |
|
|
|
q = q_net(x1_int8).numpy() * act_scale |
|
|
|
np.testing.assert_allclose(qat_without_fakequant, normal) |
|
|
@@ -175,17 +186,17 @@ def test_linear(): |
|
|
|
init_qat_net(qat_net) |
|
|
|
|
|
|
|
x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) |
|
|
|
x = fake_quant(x, inp_scale) |
|
|
|
x = fake_quant_act(x, inp_scale) |
|
|
|
x.q_dict["scale"] = inp_scale |
|
|
|
|
|
|
|
x_int8 = quant(x, inp_scale) |
|
|
|
|
|
|
|
weight = np.random.normal(size=(3, 3)).astype("float32") |
|
|
|
bias = np.random.normal(size=(3,)).astype("float32") |
|
|
|
normal_net.weight.set_value(fake_quant(weight, weight_scale)) |
|
|
|
normal_net.bias.set_value(fake_quant(bias, inp_scale * weight_scale)) |
|
|
|
qat_net.weight.set_value(weight) |
|
|
|
qat_net.bias.set_value(bias) |
|
|
|
normal_net.weight[...] = fake_quant_weight(weight, weight_scale) |
|
|
|
normal_net.bias[...] = fake_quant_bias(bias, inp_scale * weight_scale) |
|
|
|
qat_net.weight[...] = Parameter(weight) |
|
|
|
qat_net.bias[...] = Parameter(bias) |
|
|
|
|
|
|
|
qat_from_float = QAT.Linear.from_float_module(normal_net) |
|
|
|
qat_from_float.eval() |
|
|
@@ -197,11 +208,11 @@ def test_linear(): |
|
|
|
|
|
|
|
normal = normal_net(x) |
|
|
|
qat_without_fakequant = qat_from_float(x) |
|
|
|
fake_quant_normal = fake_quant(normal_net(x), act_scale) |
|
|
|
fake_quant_normal = fake_quant_act(normal_net(x), act_scale) |
|
|
|
qat = qat_net(x) |
|
|
|
q = q_net(x_int8).numpy() * act_scale |
|
|
|
np.testing.assert_allclose(qat_without_fakequant, normal) |
|
|
|
np.testing.assert_allclose(qat, fake_quant_normal) |
|
|
|
np.testing.assert_allclose(qat, fake_quant_normal.numpy()) |
|
|
|
np.testing.assert_allclose(q, fake_quant_normal.numpy()) |
|
|
|
|
|
|
|
|
|
|
@@ -218,7 +229,7 @@ def test_conv(module): |
|
|
|
init_qat_net(qat_net) |
|
|
|
|
|
|
|
x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32")) |
|
|
|
x = fake_quant(x, inp_scale) |
|
|
|
x = fake_quant_act(x, inp_scale) |
|
|
|
x.q_dict["scale"] = inp_scale |
|
|
|
|
|
|
|
x_int8 = quant(x, inp_scale) |
|
|
@@ -226,15 +237,15 @@ def test_conv(module): |
|
|
|
weight = np.random.normal(size=(3, 3, 3, 3)).astype("float32") |
|
|
|
bias = np.random.normal(size=(1, 3, 1, 1)).astype("float32") |
|
|
|
if module in ("ConvBn2d", "ConvBnRelu2d"): |
|
|
|
normal_net.conv.weight.set_value(fake_quant(weight, weight_scale)) |
|
|
|
normal_net.conv.bias.set_value(fake_quant(bias, inp_scale * weight_scale)) |
|
|
|
qat_net.conv.weight.set_value(weight) |
|
|
|
qat_net.conv.bias.set_value(bias) |
|
|
|
normal_net.conv.weight[...] = fake_quant_weight(weight, weight_scale) |
|
|
|
normal_net.conv.bias[...] = fake_quant_bias(bias, inp_scale * weight_scale) |
|
|
|
qat_net.conv.weight[...] = Parameter(weight) |
|
|
|
qat_net.conv.bias[...] = Parameter(bias) |
|
|
|
else: |
|
|
|
normal_net.weight.set_value(fake_quant(weight, weight_scale)) |
|
|
|
normal_net.bias.set_value(fake_quant(bias, inp_scale * weight_scale)) |
|
|
|
qat_net.weight.set_value(weight) |
|
|
|
qat_net.bias.set_value(bias) |
|
|
|
normal_net.weight[...] = fake_quant_weight(weight, weight_scale) |
|
|
|
normal_net.bias[...] = fake_quant_bias(bias, inp_scale * weight_scale) |
|
|
|
qat_net.weight[...] = Parameter(weight) |
|
|
|
qat_net.bias[...] = Parameter(bias) |
|
|
|
|
|
|
|
qat_from_float = getattr(QAT, module).from_float_module(normal_net) |
|
|
|
qat_from_float.eval() |
|
|
@@ -246,9 +257,9 @@ def test_conv(module): |
|
|
|
|
|
|
|
normal = normal_net(x) |
|
|
|
qat_without_fakequant = qat_from_float(x) |
|
|
|
fake_quant_normal = fake_quant(normal_net(x), act_scale) |
|
|
|
fake_quant_normal = fake_quant_act(normal_net(x), act_scale) |
|
|
|
qat = qat_net(x) |
|
|
|
q = q_net(x_int8).numpy() * act_scale |
|
|
|
np.testing.assert_allclose(qat_without_fakequant, normal, atol=1e-6) |
|
|
|
np.testing.assert_allclose(qat, fake_quant_normal) |
|
|
|
np.testing.assert_allclose(q, fake_quant_normal.numpy()) |
|
|
|
np.testing.assert_allclose(qat_without_fakequant, normal, atol=1e-5) |
|
|
|
np.testing.assert_allclose(qat, fake_quant_normal, atol=act_scale) |
|
|
|
np.testing.assert_allclose(q, fake_quant_normal.numpy(), atol=act_scale) |