From 9c90ce8c330306c0b66a36c8f5bc03a4d03caec8 Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Tue, 16 Mar 2021 15:15:51 +0800
Subject: [PATCH] fix(mge/quantization): fix `quantize` and `quantize_qat`'s
 `set_expand_structure` arguments

GitOrigin-RevId: c61633095d7371728be14fd260c1a4de7f3bbd92
---
 .../python/megengine/quantization/quantize.py      |  4 +-
 .../python/test/unit/quantization/test_quantize.py | 91 +++++++++++++-------
 2 files changed, 64 insertions(+), 31 deletions(-)

diff --git a/imperative/python/megengine/quantization/quantize.py b/imperative/python/megengine/quantization/quantize.py
index 7d39c1d2..1011b867 100644
--- a/imperative/python/megengine/quantization/quantize.py
+++ b/imperative/python/megengine/quantization/quantize.py
@@ -80,7 +80,7 @@ def quantize(module: Module, inplace: bool = True, mapping: dict = None):
         module._flatten(with_key=True, with_parent=True, predicate=is_qat)
     ):
         new_mod = convert_dict[type(submodule)].from_qat_module(submodule)
-        set_expand_structure(parent, key, new_mod)
+        set_expand_structure(module, key, new_mod)
 
     return module
 
@@ -123,7 +123,7 @@ def quantize_qat(
             continue
 
         new_mod = convert_dict[type(submodule)].from_float_module(submodule)
-        set_expand_structure(parent, key, new_mod)
+        set_expand_structure(module, key, new_mod)
 
     propagate_qconfig(module, qconfig)
     return module
diff --git a/imperative/python/test/unit/quantization/test_quantize.py b/imperative/python/test/unit/quantization/test_quantize.py
index c77ba782..13abc43b 100644
--- a/imperative/python/test/unit/quantization/test_quantize.py
+++ b/imperative/python/test/unit/quantization/test_quantize.py
@@ -37,9 +37,10 @@ class FloatNet(Float.Module):
     def __init__(self):
         super().__init__()
         self.quant = Float.QuantStub()
-        self.linear = Float.Linear(3, 3)
+        self.linear = Float.Sequential(Float.Linear(3, 3), Float.Linear(3, 3))
         self.dequant = Float.DequantStub()
-        self.linear.bias[...] = Parameter(np.random.rand(3))
+        self.linear[0].bias[...] = Parameter(np.random.rand(3))
+        self.linear[1].bias[...] = Parameter(np.random.rand(3))
 
     def forward(self, x):
         x = self.quant(x)
@@ -52,9 +53,10 @@ class QATNet(Float.Module):
     def __init__(self):
         super().__init__()
         self.quant = QAT.QuantStub()
-        self.linear = QAT.Linear(3, 3)
+        self.linear = Float.Sequential(QAT.Linear(3, 3), QAT.Linear(3, 3))
         self.dequant = QAT.DequantStub()
-        self.linear.bias[...] = Parameter(np.random.rand(3))
+        self.linear[0].bias[...] = Parameter(np.random.rand(3))
+        self.linear[1].bias[...] = Parameter(np.random.rand(3))
 
     def forward(self, x):
         x = self.quant(x)
@@ -72,10 +74,14 @@ def test_propagate_qconfig():
             net.quant.weight_fake_quant is None,
             isinstance(net.quant.act_observer, MinMaxObserver),
             isinstance(net.quant.act_fake_quant, FakeQuantize),
-            isinstance(net.linear.weight_observer, MinMaxObserver),
-            isinstance(net.linear.weight_fake_quant, FakeQuantize),
-            isinstance(net.linear.act_observer, MinMaxObserver),
-            isinstance(net.linear.act_fake_quant, FakeQuantize),
+            isinstance(net.linear[0].weight_observer, MinMaxObserver),
+            isinstance(net.linear[0].weight_fake_quant, FakeQuantize),
+            isinstance(net.linear[0].act_observer, MinMaxObserver),
+            isinstance(net.linear[0].act_fake_quant, FakeQuantize),
+            isinstance(net.linear[1].weight_observer, MinMaxObserver),
+            isinstance(net.linear[1].weight_fake_quant, FakeQuantize),
+            isinstance(net.linear[1].act_observer, MinMaxObserver),
+            isinstance(net.linear[1].act_fake_quant, FakeQuantize),
             net.dequant.weight_observer is None,
             net.dequant.weight_fake_quant is None,
             net.dequant.act_observer is None,
@@ -91,10 +97,14 @@ def init_qat_net():
     max_val = np.random.randint(1, 127, size=(3,))
     net.quant.act_observer.min_val[...] = Parameter(min_val[0])
     net.quant.act_observer.max_val[...] = Parameter(max_val[0])
-    net.linear.weight_observer.min_val[...] = Parameter(min_val[1])
-    net.linear.weight_observer.max_val[...] = Parameter(max_val[1])
-    net.linear.act_observer.min_val[...] = Parameter(min_val[2])
-    net.linear.act_observer.max_val[...] = Parameter(max_val[2])
+    net.linear[0].weight_observer.min_val[...] = Parameter(min_val[1])
+    net.linear[0].weight_observer.max_val[...] = Parameter(max_val[1])
+    net.linear[0].act_observer.min_val[...] = Parameter(min_val[2])
+    net.linear[0].act_observer.max_val[...] = Parameter(max_val[2])
+    net.linear[1].weight_observer.min_val[...] = Parameter(min_val[1])
+    net.linear[1].weight_observer.max_val[...] = Parameter(max_val[1])
+    net.linear[1].act_observer.min_val[...] = Parameter(min_val[2])
+    net.linear[1].act_observer.max_val[...] = Parameter(max_val[2])
 
     return net
 
@@ -102,11 +112,20 @@ def test_reset_qconfig():
     qat_net = init_qat_net()
     new_qat_net = reset_qconfig(qat_net, passive_qconfig)
     assert (
-        new_qat_net.linear.get_weight_qparams() == qat_net.linear.get_weight_qparams()
+        new_qat_net.linear[0].get_weight_qparams()
+        == qat_net.linear[0].get_weight_qparams()
     )
     assert (
-        new_qat_net.linear.get_activation_qparams()
-        == qat_net.linear.get_activation_qparams()
+        new_qat_net.linear[0].get_activation_qparams()
+        == qat_net.linear[0].get_activation_qparams()
+    )
+    assert (
+        new_qat_net.linear[1].get_weight_qparams()
+        == qat_net.linear[1].get_weight_qparams()
+    )
+    assert (
+        new_qat_net.linear[1].get_activation_qparams()
+        == qat_net.linear[1].get_activation_qparams()
     )
 
 
@@ -114,24 +133,32 @@ def test_enable_and_disable_observer():
     net = init_qat_net()
     enable_observer(net)
     assert net.quant.act_observer.enabled is True
-    assert net.linear.weight_observer.enabled is True
-    assert net.linear.act_observer.enabled is True
+    assert net.linear[0].weight_observer.enabled is True
+    assert net.linear[0].act_observer.enabled is True
+    assert net.linear[1].weight_observer.enabled is True
+    assert net.linear[1].act_observer.enabled is True
     disable_observer(net)
     assert net.quant.act_observer.enabled is False
-    assert net.linear.weight_observer.enabled is False
-    assert net.linear.act_observer.enabled is False
+    assert net.linear[0].weight_observer.enabled is False
+    assert net.linear[0].act_observer.enabled is False
+    assert net.linear[1].weight_observer.enabled is False
+    assert net.linear[1].act_observer.enabled is False
 
 
 def test_enable_and_disable_fake_quant():
     net = init_qat_net()
     disable_fake_quant(net)
     assert net.quant.act_fake_quant.enabled is False
-    assert net.linear.weight_fake_quant.enabled is False
-    assert net.linear.act_fake_quant.enabled is False
+    assert net.linear[0].weight_fake_quant.enabled is False
+    assert net.linear[0].act_fake_quant.enabled is False
+    assert net.linear[1].weight_fake_quant.enabled is False
+    assert net.linear[1].act_fake_quant.enabled is False
     enable_fake_quant(net)
     assert net.quant.act_fake_quant.enabled is True
-    assert net.linear.weight_fake_quant.enabled is True
-    assert net.linear.act_fake_quant.enabled is True
+    assert net.linear[0].weight_fake_quant.enabled is True
+    assert net.linear[0].act_fake_quant.enabled is True
+    assert net.linear[1].weight_fake_quant.enabled is True
+    assert net.linear[1].act_fake_quant.enabled is True
 
 
 def init_observer(module, data):
@@ -165,7 +192,8 @@ def test_quantize_qat():
     net = FloatNet()
     qat_net = quantize_qat(net, inplace=False, qconfig=min_max_fakequant_qconfig)
     assert isinstance(qat_net.quant, QAT.QuantStub)
-    assert isinstance(qat_net.linear, QAT.Linear)
+    assert isinstance(qat_net.linear[0], QAT.Linear)
+    assert isinstance(qat_net.linear[1], QAT.Linear)
     assert isinstance(qat_net.dequant, QAT.DequantStub)
 
 
@@ -173,7 +201,8 @@ def test_quantize():
    qat_net = init_qat_net()
     q_net = quantize(qat_net, inplace=False)
     assert isinstance(q_net.quant, Q.QuantStub)
-    assert isinstance(q_net.linear, Q.Linear)
+    assert isinstance(q_net.linear[0], Q.Linear)
+    assert isinstance(q_net.linear[1], Q.Linear)
     assert isinstance(q_net.dequant, Q.DequantStub)
 
 
@@ -183,8 +212,10 @@ def test_apply_easy_quant():
     eq_net = reset_qconfig(qat_net, passive_qconfig, inplace=False)
     apply_easy_quant(eq_net, data, 0.9, 1.1, 10)
     assert isinstance(eq_net.quant.act_observer, PassiveObserver)
-    assert isinstance(eq_net.linear.weight_observer, PassiveObserver)
-    assert isinstance(eq_net.linear.act_observer, PassiveObserver)
+    assert isinstance(eq_net.linear[0].weight_observer, PassiveObserver)
+    assert isinstance(eq_net.linear[0].act_observer, PassiveObserver)
+    assert isinstance(eq_net.linear[1].weight_observer, PassiveObserver)
+    assert isinstance(eq_net.linear[1].act_observer, PassiveObserver)
     assert eq_net.dequant.act_observer is None
 
 
@@ -192,8 +223,10 @@ def test_apply_tqt():
     qat_net = init_qat_net()
     tqt_net = reset_qconfig(qat_net, tqt_qconfig, inplace=False)
     assert isinstance(tqt_net.quant.act_fake_quant, TQT)
-    assert isinstance(tqt_net.linear.weight_fake_quant, TQT)
-    assert isinstance(tqt_net.linear.act_fake_quant, TQT)
+    assert isinstance(tqt_net.linear[0].weight_fake_quant, TQT)
+    assert isinstance(tqt_net.linear[0].act_fake_quant, TQT)
+    assert isinstance(tqt_net.linear[1].weight_fake_quant, TQT)
+    assert isinstance(tqt_net.linear[1].act_fake_quant, TQT)
    assert tqt_net.dequant.act_fake_quant is None
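
Note on the root cause: `module._flatten(with_key=True, with_parent=True)`
yields `(key, submodule, parent)` triples whose keys appear to be dotted paths
relative to the *root* module (e.g. "linear.0" for the first Linear inside the
Sequential), which is why the fix passes `module` rather than `parent` to
`set_expand_structure`. The sketch below illustrates the failure mode in plain
Python; `Seq`, `Net`, and `set_by_path` are hypothetical stand-ins for
illustration, not MegEngine APIs.

class Seq:
    # Minimal stand-in for a Sequential container addressed by integer keys.
    def __init__(self, *mods):
        self.mods = list(mods)


class Net:
    # Stand-in for FloatNet: a container nested one level below the root.
    def __init__(self):
        self.linear = Seq("float0", "float1")


def set_by_path(root, key, value):
    # Resolve a root-relative dotted key such as "linear.0" and replace the
    # leaf it names -- roughly what set_expand_structure(module, key, new_mod)
    # is expected to do after this patch.
    *parts, leaf = key.split(".")
    target = root
    for part in parts:
        target = (
            target.mods[int(part)] if isinstance(target, Seq) else getattr(target, part)
        )
    if isinstance(target, Seq):
        target.mods[int(leaf)] = value
    else:
        setattr(target, leaf, value)


net = Net()
set_by_path(net, "linear.0", "quantized0")  # resolve from the root: correct
assert net.linear.mods[0] == "quantized0"

# The old code effectively called set_by_path(net.linear, "linear.0", ...):
# the root-relative key is re-resolved under the parent, which goes wrong as
# soon as the key has more than one component.

This is also why the tests now build `self.linear` as a `Sequential` of two
`Linear` modules: a flat `net.linear` attribute never produces a nested key,
so the previous single-`Linear` tests could not exercise the difference
between resolving from the root and from the parent.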