|
|
@@ -30,7 +30,10 @@ min_max_fakequant_qconfig = QConfig( |
|
|
|
act_fake_quant=partial(FakeQuantize, dtype="qint8"), |
|
|
|
) |
|
|
|
|
|
|
|
inp_scale = np.float32(np.random.rand() + 1) |
|
|
|
|
|
|
|
def gen_inp_scale(): |
|
|
|
return np.float32(np.random.rand() + 1) |
|
|
|
|
|
|
|
|
|
|
|
min_val = np.random.randint(-127, 0, size=(2,)).astype("float32") |
|
|
|
max_val = np.random.randint(1, 127, size=(2,)).astype("float32") |
|
|
@@ -116,6 +119,7 @@ def test_dequant_stub(): |
|
|
|
q_net.eval() |
|
|
|
|
|
|
|
x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) |
|
|
|
inp_scale = gen_inp_scale() |
|
|
|
x = fake_quant_act(x, inp_scale) |
|
|
|
x.qparams.scale = inp_scale |
|
|
|
|
|
|
@@ -192,6 +196,7 @@ def test_linear(): |
|
|
|
init_qat_net(qat_net) |
|
|
|
|
|
|
|
x = mge.tensor(np.random.normal(size=(3, 3)).astype("float32")) |
|
|
|
inp_scale = gen_inp_scale() |
|
|
|
x = fake_quant_act(x, inp_scale) |
|
|
|
x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale)) |
|
|
|
|
|
|
@@ -235,6 +240,7 @@ def test_conv(module): |
|
|
|
init_qat_net(qat_net) |
|
|
|
|
|
|
|
x = mge.tensor(np.random.normal(size=(1, 3, 3, 3)).astype("float32")) |
|
|
|
inp_scale = gen_inp_scale() |
|
|
|
x = fake_quant_act(x, inp_scale) |
|
|
|
x.qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale)) |
|
|
|
|
|
|
@@ -269,3 +275,41 @@ def test_conv(module): |
|
|
|
np.testing.assert_allclose(qat_without_fakequant, normal, atol=1e-5) |
|
|
|
np.testing.assert_allclose(qat, fake_quant_normal, atol=act_scale) |
|
|
|
np.testing.assert_allclose(q, fake_quant_normal.numpy(), atol=act_scale) |
|
|
|
|
|
|
|
|
|
|
|
def test_concat(): |
|
|
|
normal_net = Float.Concat() |
|
|
|
normal_net.eval() |
|
|
|
|
|
|
|
qat_net = QAT.Concat() |
|
|
|
qat_net.eval() |
|
|
|
disable_observer(qat_net) |
|
|
|
|
|
|
|
propagate_qconfig(qat_net, min_max_fakequant_qconfig) |
|
|
|
init_qat_net(qat_net) |
|
|
|
|
|
|
|
inps = [] |
|
|
|
inps_int8 = [] |
|
|
|
for i in range(3): |
|
|
|
inp_scale = gen_inp_scale() |
|
|
|
inps.append(mge.tensor(np.random.normal(size=(3, 3)).astype("float32"))) |
|
|
|
inps[i] = fake_quant_act(inps[i], inp_scale) |
|
|
|
inps[i].qparams.update(create_qparams(QuantMode.SYMMERTIC, "qint8", inp_scale)) |
|
|
|
inps_int8.append(quant(inps[i], inp_scale)) |
|
|
|
|
|
|
|
qat_from_float = QAT.Concat.from_float_module(normal_net) |
|
|
|
qat_from_float.eval() |
|
|
|
disable_fake_quant(qat_from_float) |
|
|
|
disable_observer(qat_from_float) |
|
|
|
|
|
|
|
q_net = Q.Concat.from_qat_module(qat_net) |
|
|
|
q_net.eval() |
|
|
|
|
|
|
|
normal = normal_net(inps) |
|
|
|
qat_without_fakequant = qat_from_float(inps) |
|
|
|
fake_quant_normal = fake_quant_act(normal_net(inps), act_scale) |
|
|
|
qat = qat_net(inps) |
|
|
|
q = q_net(inps_int8).numpy() * act_scale |
|
|
|
np.testing.assert_allclose(qat_without_fakequant, normal) |
|
|
|
np.testing.assert_allclose(qat, fake_quant_normal.numpy()) |
|
|
|
np.testing.assert_allclose(q, fake_quant_normal.numpy()) |