From 829f0907e7ef54840ab1c3de894212348643a2f8 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Tue, 9 Nov 2021 20:07:25 +0800 Subject: [PATCH] fix(mge/traced_module): fix insert qat module GitOrigin-RevId: 35849bc1a26b10fbbba4a6ef72593e82c10a2b6d --- .../megengine/traced_module/traced_module.py | 7 +++--- .../test/unit/traced_module/test_modification.py | 26 ++++++++++++++++++++++ .../test/unit/traced_module/test_qat_module.py | 5 ++--- 3 files changed, 32 insertions(+), 6 deletions(-) diff --git a/imperative/python/megengine/traced_module/traced_module.py b/imperative/python/megengine/traced_module/traced_module.py index 8f5b3ce8..4f06de27 100644 --- a/imperative/python/megengine/traced_module/traced_module.py +++ b/imperative/python/megengine/traced_module/traced_module.py @@ -281,11 +281,8 @@ class _InsertExprs: def __exit__(self, ty, va, tr): if va is not None: return False - set_symbolic_shape(self.use_sym_shape) active_module_tracer().patcher.__exit__(ty, va, tr) _set_convert_node_flag(False) - set_active_module_tracer(None) - unset_module_tracing() while self._tensor_method_patch: pf = self._tensor_method_patch.pop() @@ -298,6 +295,10 @@ class _InsertExprs: v = v.build() setattr(module, k, v) + set_symbolic_shape(self.use_sym_shape) + set_active_module_tracer(None) + unset_module_tracing() + extra_inp_nodes = set(self.global_scope.inputs) max_inp_expr_idx = -1 for node in extra_inp_nodes: diff --git a/imperative/python/test/unit/traced_module/test_modification.py b/imperative/python/test/unit/traced_module/test_modification.py index 0fed2217..8a7c80a3 100644 --- a/imperative/python/test/unit/traced_module/test_modification.py +++ b/imperative/python/test/unit/traced_module/test_modification.py @@ -13,6 +13,7 @@ import numpy as np import megengine.functional as F import megengine.module as M +import megengine.module.qat as qat from megengine.module.identity import Identity from megengine.traced_module import trace_module from megengine.traced_module.expr import CallFunction, CallMethod, Expr, GetAttr, Input @@ -199,6 +200,31 @@ def test_insert_module(): assert n.value is None +def test_insert_qat_module(): + class concat(qat.Concat): + pass + + traced_module, x, expect = _init_block() + graph = traced_module.graph + self = graph.inputs[0] + out = graph.outputs[0] + setattr(traced_module, "cat_0", qat.Concat()) + setattr(traced_module, "cat_1", concat()) + + with graph.insert_exprs(): + x_0 = self.cat_0([out, out]) + x_1 = self.cat_1([out, x_0]) + graph.replace_node({out: x_1}) + graph.compile() + + x = F.copy(x) + np.testing.assert_allclose( + F.concat([expect, expect, expect]), traced_module(x), atol=1e-6 + ) + assert not hasattr(traced_module.cat_0, "graph") + assert traced_module.cat_1.graph is not None + + def test_add_input_and_output(): traced_module, x, y = _init_module() diff --git a/imperative/python/test/unit/traced_module/test_qat_module.py b/imperative/python/test/unit/traced_module/test_qat_module.py index 721fdcd4..1bcb74d6 100644 --- a/imperative/python/test/unit/traced_module/test_qat_module.py +++ b/imperative/python/test/unit/traced_module/test_qat_module.py @@ -108,9 +108,8 @@ def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams): def build_observered_net(net: M.Module, observer_cls): qat_net = Q.quantize_qat(net, qconfig=get_observer_config(observer_cls)) Q.enable_observer(qat_net) - for _ in range(5): - inp = Tensor(np.random.random(size=(5, 3, 32, 32))) - qat_net(inp) + inp = Tensor(np.random.random(size=(5, 3, 32, 32))) + qat_net(inp) Q.disable_observer(qat_net) return qat_net