|
@@ -37,9 +37,10 @@ class FloatNet(Float.Module): |
|
|
def __init__(self): |
|
|
def __init__(self): |
|
|
super().__init__() |
|
|
super().__init__() |
|
|
self.quant = Float.QuantStub() |
|
|
self.quant = Float.QuantStub() |
|
|
self.linear = Float.Linear(3, 3) |
|
|
|
|
|
|
|
|
self.linear = Float.Sequential(Float.Linear(3, 3), Float.Linear(3, 3)) |
|
|
self.dequant = Float.DequantStub() |
|
|
self.dequant = Float.DequantStub() |
|
|
self.linear.bias[...] = Parameter(np.random.rand(3)) |
|
|
|
|
|
|
|
|
self.linear[0].bias[...] = Parameter(np.random.rand(3)) |
|
|
|
|
|
self.linear[1].bias[...] = Parameter(np.random.rand(3)) |
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
def forward(self, x): |
|
|
x = self.quant(x) |
|
|
x = self.quant(x) |
|
@@ -52,9 +53,10 @@ class QATNet(Float.Module): |
|
|
def __init__(self): |
|
|
def __init__(self): |
|
|
super().__init__() |
|
|
super().__init__() |
|
|
self.quant = QAT.QuantStub() |
|
|
self.quant = QAT.QuantStub() |
|
|
self.linear = QAT.Linear(3, 3) |
|
|
|
|
|
|
|
|
self.linear = Float.Sequential(QAT.Linear(3, 3), QAT.Linear(3, 3)) |
|
|
self.dequant = QAT.DequantStub() |
|
|
self.dequant = QAT.DequantStub() |
|
|
self.linear.bias[...] = Parameter(np.random.rand(3)) |
|
|
|
|
|
|
|
|
self.linear[0].bias[...] = Parameter(np.random.rand(3)) |
|
|
|
|
|
self.linear[1].bias[...] = Parameter(np.random.rand(3)) |
|
|
|
|
|
|
|
|
def forward(self, x): |
|
|
def forward(self, x): |
|
|
x = self.quant(x) |
|
|
x = self.quant(x) |
|
@@ -72,10 +74,14 @@ def test_propagate_qconfig(): |
|
|
net.quant.weight_fake_quant is None, |
|
|
net.quant.weight_fake_quant is None, |
|
|
isinstance(net.quant.act_observer, MinMaxObserver), |
|
|
isinstance(net.quant.act_observer, MinMaxObserver), |
|
|
isinstance(net.quant.act_fake_quant, FakeQuantize), |
|
|
isinstance(net.quant.act_fake_quant, FakeQuantize), |
|
|
isinstance(net.linear.weight_observer, MinMaxObserver), |
|
|
|
|
|
isinstance(net.linear.weight_fake_quant, FakeQuantize), |
|
|
|
|
|
isinstance(net.linear.act_observer, MinMaxObserver), |
|
|
|
|
|
isinstance(net.linear.act_fake_quant, FakeQuantize), |
|
|
|
|
|
|
|
|
isinstance(net.linear[0].weight_observer, MinMaxObserver), |
|
|
|
|
|
isinstance(net.linear[0].weight_fake_quant, FakeQuantize), |
|
|
|
|
|
isinstance(net.linear[0].act_observer, MinMaxObserver), |
|
|
|
|
|
isinstance(net.linear[0].act_fake_quant, FakeQuantize), |
|
|
|
|
|
isinstance(net.linear[1].weight_observer, MinMaxObserver), |
|
|
|
|
|
isinstance(net.linear[1].weight_fake_quant, FakeQuantize), |
|
|
|
|
|
isinstance(net.linear[1].act_observer, MinMaxObserver), |
|
|
|
|
|
isinstance(net.linear[1].act_fake_quant, FakeQuantize), |
|
|
net.dequant.weight_observer is None, |
|
|
net.dequant.weight_observer is None, |
|
|
net.dequant.weight_fake_quant is None, |
|
|
net.dequant.weight_fake_quant is None, |
|
|
net.dequant.act_observer is None, |
|
|
net.dequant.act_observer is None, |
|
@@ -91,10 +97,14 @@ def init_qat_net(): |
|
|
max_val = np.random.randint(1, 127, size=(3,)) |
|
|
max_val = np.random.randint(1, 127, size=(3,)) |
|
|
net.quant.act_observer.min_val[...] = Parameter(min_val[0]) |
|
|
net.quant.act_observer.min_val[...] = Parameter(min_val[0]) |
|
|
net.quant.act_observer.max_val[...] = Parameter(max_val[0]) |
|
|
net.quant.act_observer.max_val[...] = Parameter(max_val[0]) |
|
|
net.linear.weight_observer.min_val[...] = Parameter(min_val[1]) |
|
|
|
|
|
net.linear.weight_observer.max_val[...] = Parameter(max_val[1]) |
|
|
|
|
|
net.linear.act_observer.min_val[...] = Parameter(min_val[2]) |
|
|
|
|
|
net.linear.act_observer.max_val[...] = Parameter(max_val[2]) |
|
|
|
|
|
|
|
|
net.linear[0].weight_observer.min_val[...] = Parameter(min_val[1]) |
|
|
|
|
|
net.linear[0].weight_observer.max_val[...] = Parameter(max_val[1]) |
|
|
|
|
|
net.linear[0].act_observer.min_val[...] = Parameter(min_val[2]) |
|
|
|
|
|
net.linear[0].act_observer.max_val[...] = Parameter(max_val[2]) |
|
|
|
|
|
net.linear[1].weight_observer.min_val[...] = Parameter(min_val[1]) |
|
|
|
|
|
net.linear[1].weight_observer.max_val[...] = Parameter(max_val[1]) |
|
|
|
|
|
net.linear[1].act_observer.min_val[...] = Parameter(min_val[2]) |
|
|
|
|
|
net.linear[1].act_observer.max_val[...] = Parameter(max_val[2]) |
|
|
return net |
|
|
return net |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -102,11 +112,20 @@ def test_reset_qconfig(): |
|
|
qat_net = init_qat_net() |
|
|
qat_net = init_qat_net() |
|
|
new_qat_net = reset_qconfig(qat_net, passive_qconfig) |
|
|
new_qat_net = reset_qconfig(qat_net, passive_qconfig) |
|
|
assert ( |
|
|
assert ( |
|
|
new_qat_net.linear.get_weight_qparams() == qat_net.linear.get_weight_qparams() |
|
|
|
|
|
|
|
|
new_qat_net.linear[0].get_weight_qparams() |
|
|
|
|
|
== qat_net.linear[0].get_weight_qparams() |
|
|
) |
|
|
) |
|
|
assert ( |
|
|
assert ( |
|
|
new_qat_net.linear.get_activation_qparams() |
|
|
|
|
|
== qat_net.linear.get_activation_qparams() |
|
|
|
|
|
|
|
|
new_qat_net.linear[0].get_activation_qparams() |
|
|
|
|
|
== qat_net.linear[0].get_activation_qparams() |
|
|
|
|
|
) |
|
|
|
|
|
assert ( |
|
|
|
|
|
new_qat_net.linear[1].get_weight_qparams() |
|
|
|
|
|
== qat_net.linear[1].get_weight_qparams() |
|
|
|
|
|
) |
|
|
|
|
|
assert ( |
|
|
|
|
|
new_qat_net.linear[1].get_activation_qparams() |
|
|
|
|
|
== qat_net.linear[1].get_activation_qparams() |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -114,24 +133,32 @@ def test_enable_and_disable_observer(): |
|
|
net = init_qat_net() |
|
|
net = init_qat_net() |
|
|
enable_observer(net) |
|
|
enable_observer(net) |
|
|
assert net.quant.act_observer.enabled is True |
|
|
assert net.quant.act_observer.enabled is True |
|
|
assert net.linear.weight_observer.enabled is True |
|
|
|
|
|
assert net.linear.act_observer.enabled is True |
|
|
|
|
|
|
|
|
assert net.linear[0].weight_observer.enabled is True |
|
|
|
|
|
assert net.linear[0].act_observer.enabled is True |
|
|
|
|
|
assert net.linear[1].weight_observer.enabled is True |
|
|
|
|
|
assert net.linear[1].act_observer.enabled is True |
|
|
disable_observer(net) |
|
|
disable_observer(net) |
|
|
assert net.quant.act_observer.enabled is False |
|
|
assert net.quant.act_observer.enabled is False |
|
|
assert net.linear.weight_observer.enabled is False |
|
|
|
|
|
assert net.linear.act_observer.enabled is False |
|
|
|
|
|
|
|
|
assert net.linear[0].weight_observer.enabled is False |
|
|
|
|
|
assert net.linear[0].weight_observer.enabled is False |
|
|
|
|
|
assert net.linear[1].act_observer.enabled is False |
|
|
|
|
|
assert net.linear[1].act_observer.enabled is False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_enable_and_disable_fake_quant(): |
|
|
def test_enable_and_disable_fake_quant(): |
|
|
net = init_qat_net() |
|
|
net = init_qat_net() |
|
|
disable_fake_quant(net) |
|
|
disable_fake_quant(net) |
|
|
assert net.quant.act_fake_quant.enabled is False |
|
|
assert net.quant.act_fake_quant.enabled is False |
|
|
assert net.linear.weight_fake_quant.enabled is False |
|
|
|
|
|
assert net.linear.act_fake_quant.enabled is False |
|
|
|
|
|
|
|
|
assert net.linear[0].weight_fake_quant.enabled is False |
|
|
|
|
|
assert net.linear[0].act_fake_quant.enabled is False |
|
|
|
|
|
assert net.linear[1].weight_fake_quant.enabled is False |
|
|
|
|
|
assert net.linear[1].act_fake_quant.enabled is False |
|
|
enable_fake_quant(net) |
|
|
enable_fake_quant(net) |
|
|
assert net.quant.act_fake_quant.enabled is True |
|
|
assert net.quant.act_fake_quant.enabled is True |
|
|
assert net.linear.weight_fake_quant.enabled is True |
|
|
|
|
|
assert net.linear.act_fake_quant.enabled is True |
|
|
|
|
|
|
|
|
assert net.linear[0].weight_fake_quant.enabled is True |
|
|
|
|
|
assert net.linear[0].act_fake_quant.enabled is True |
|
|
|
|
|
assert net.linear[1].weight_fake_quant.enabled is True |
|
|
|
|
|
assert net.linear[1].act_fake_quant.enabled is True |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_observer(module, data): |
|
|
def init_observer(module, data): |
|
@@ -165,7 +192,8 @@ def test_quantize_qat(): |
|
|
net = FloatNet() |
|
|
net = FloatNet() |
|
|
qat_net = quantize_qat(net, inplace=False, qconfig=min_max_fakequant_qconfig) |
|
|
qat_net = quantize_qat(net, inplace=False, qconfig=min_max_fakequant_qconfig) |
|
|
assert isinstance(qat_net.quant, QAT.QuantStub) |
|
|
assert isinstance(qat_net.quant, QAT.QuantStub) |
|
|
assert isinstance(qat_net.linear, QAT.Linear) |
|
|
|
|
|
|
|
|
assert isinstance(qat_net.linear[0], QAT.Linear) |
|
|
|
|
|
assert isinstance(qat_net.linear[1], QAT.Linear) |
|
|
assert isinstance(qat_net.dequant, QAT.DequantStub) |
|
|
assert isinstance(qat_net.dequant, QAT.DequantStub) |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -173,7 +201,8 @@ def test_quantize(): |
|
|
qat_net = init_qat_net() |
|
|
qat_net = init_qat_net() |
|
|
q_net = quantize(qat_net, inplace=False) |
|
|
q_net = quantize(qat_net, inplace=False) |
|
|
assert isinstance(q_net.quant, Q.QuantStub) |
|
|
assert isinstance(q_net.quant, Q.QuantStub) |
|
|
assert isinstance(q_net.linear, Q.Linear) |
|
|
|
|
|
|
|
|
assert isinstance(q_net.linear[0], Q.Linear) |
|
|
|
|
|
assert isinstance(q_net.linear[1], Q.Linear) |
|
|
assert isinstance(q_net.dequant, Q.DequantStub) |
|
|
assert isinstance(q_net.dequant, Q.DequantStub) |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -183,8 +212,10 @@ def test_apply_easy_quant(): |
|
|
eq_net = reset_qconfig(qat_net, passive_qconfig, inplace=False) |
|
|
eq_net = reset_qconfig(qat_net, passive_qconfig, inplace=False) |
|
|
apply_easy_quant(eq_net, data, 0.9, 1.1, 10) |
|
|
apply_easy_quant(eq_net, data, 0.9, 1.1, 10) |
|
|
assert isinstance(eq_net.quant.act_observer, PassiveObserver) |
|
|
assert isinstance(eq_net.quant.act_observer, PassiveObserver) |
|
|
assert isinstance(eq_net.linear.weight_observer, PassiveObserver) |
|
|
|
|
|
assert isinstance(eq_net.linear.act_observer, PassiveObserver) |
|
|
|
|
|
|
|
|
assert isinstance(eq_net.linear[0].weight_observer, PassiveObserver) |
|
|
|
|
|
assert isinstance(eq_net.linear[0].act_observer, PassiveObserver) |
|
|
|
|
|
assert isinstance(eq_net.linear[1].weight_observer, PassiveObserver) |
|
|
|
|
|
assert isinstance(eq_net.linear[1].act_observer, PassiveObserver) |
|
|
assert eq_net.dequant.act_observer is None |
|
|
assert eq_net.dequant.act_observer is None |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -192,8 +223,10 @@ def test_apply_tqt(): |
|
|
qat_net = init_qat_net() |
|
|
qat_net = init_qat_net() |
|
|
tqt_net = reset_qconfig(qat_net, tqt_qconfig, inplace=False) |
|
|
tqt_net = reset_qconfig(qat_net, tqt_qconfig, inplace=False) |
|
|
assert isinstance(tqt_net.quant.act_fake_quant, TQT) |
|
|
assert isinstance(tqt_net.quant.act_fake_quant, TQT) |
|
|
assert isinstance(tqt_net.linear.weight_fake_quant, TQT) |
|
|
|
|
|
assert isinstance(tqt_net.linear.act_fake_quant, TQT) |
|
|
|
|
|
|
|
|
assert isinstance(tqt_net.linear[0].weight_fake_quant, TQT) |
|
|
|
|
|
assert isinstance(tqt_net.linear[0].act_fake_quant, TQT) |
|
|
|
|
|
assert isinstance(tqt_net.linear[1].weight_fake_quant, TQT) |
|
|
|
|
|
assert isinstance(tqt_net.linear[1].act_fake_quant, TQT) |
|
|
assert tqt_net.dequant.act_fake_quant is None |
|
|
assert tqt_net.dequant.act_fake_quant is None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|