@@ -281,11 +281,8 @@ class _InsertExprs: | |||||
def __exit__(self, ty, va, tr): | def __exit__(self, ty, va, tr): | ||||
if va is not None: | if va is not None: | ||||
return False | return False | ||||
set_symbolic_shape(self.use_sym_shape) | |||||
active_module_tracer().patcher.__exit__(ty, va, tr) | active_module_tracer().patcher.__exit__(ty, va, tr) | ||||
_set_convert_node_flag(False) | _set_convert_node_flag(False) | ||||
set_active_module_tracer(None) | |||||
unset_module_tracing() | |||||
while self._tensor_method_patch: | while self._tensor_method_patch: | ||||
pf = self._tensor_method_patch.pop() | pf = self._tensor_method_patch.pop() | ||||
@@ -298,6 +295,10 @@ class _InsertExprs: | |||||
v = v.build() | v = v.build() | ||||
setattr(module, k, v) | setattr(module, k, v) | ||||
set_symbolic_shape(self.use_sym_shape) | |||||
set_active_module_tracer(None) | |||||
unset_module_tracing() | |||||
extra_inp_nodes = set(self.global_scope.inputs) | extra_inp_nodes = set(self.global_scope.inputs) | ||||
max_inp_expr_idx = -1 | max_inp_expr_idx = -1 | ||||
for node in extra_inp_nodes: | for node in extra_inp_nodes: | ||||
@@ -13,6 +13,7 @@ import numpy as np | |||||
import megengine.functional as F | import megengine.functional as F | ||||
import megengine.module as M | import megengine.module as M | ||||
import megengine.module.qat as qat | |||||
from megengine.module.identity import Identity | from megengine.module.identity import Identity | ||||
from megengine.traced_module import trace_module | from megengine.traced_module import trace_module | ||||
from megengine.traced_module.expr import CallFunction, CallMethod, Expr, GetAttr, Input | from megengine.traced_module.expr import CallFunction, CallMethod, Expr, GetAttr, Input | ||||
@@ -199,6 +200,31 @@ def test_insert_module(): | |||||
assert n.value is None | assert n.value is None | ||||
def test_insert_qat_module(): | |||||
class concat(qat.Concat): | |||||
pass | |||||
traced_module, x, expect = _init_block() | |||||
graph = traced_module.graph | |||||
self = graph.inputs[0] | |||||
out = graph.outputs[0] | |||||
setattr(traced_module, "cat_0", qat.Concat()) | |||||
setattr(traced_module, "cat_1", concat()) | |||||
with graph.insert_exprs(): | |||||
x_0 = self.cat_0([out, out]) | |||||
x_1 = self.cat_1([out, x_0]) | |||||
graph.replace_node({out: x_1}) | |||||
graph.compile() | |||||
x = F.copy(x) | |||||
np.testing.assert_allclose( | |||||
F.concat([expect, expect, expect]), traced_module(x), atol=1e-6 | |||||
) | |||||
assert not hasattr(traced_module.cat_0, "graph") | |||||
assert traced_module.cat_1.graph is not None | |||||
def test_add_input_and_output(): | def test_add_input_and_output(): | ||||
traced_module, x, y = _init_module() | traced_module, x, y = _init_module() | ||||
@@ -108,9 +108,8 @@ def check_qparams(qparmsa: Q.QParams, qparmsb: Q.QParams): | |||||
def build_observered_net(net: M.Module, observer_cls): | def build_observered_net(net: M.Module, observer_cls): | ||||
qat_net = Q.quantize_qat(net, qconfig=get_observer_config(observer_cls)) | qat_net = Q.quantize_qat(net, qconfig=get_observer_config(observer_cls)) | ||||
Q.enable_observer(qat_net) | Q.enable_observer(qat_net) | ||||
for _ in range(5): | |||||
inp = Tensor(np.random.random(size=(5, 3, 32, 32))) | |||||
qat_net(inp) | |||||
inp = Tensor(np.random.random(size=(5, 3, 32, 32))) | |||||
qat_net(inp) | |||||
Q.disable_observer(qat_net) | Q.disable_observer(qat_net) | ||||
return qat_net | return qat_net | ||||