From 76f367962f78ea45df90e88c0b3992a96f146430 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 7 Sep 2020 19:07:23 +0800 Subject: [PATCH] fix(mge/trace): fix op order in symbolic GitOrigin-RevId: fbf081a1999dec7b9401d8898b938eec19021e98 --- imperative/python/megengine/jit/tracing.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index c660a5bb..987a96fb 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -124,7 +124,8 @@ class trace: self._graph = None self._need_reset_nodes = None self._lazy_eval_graph = None - self._lazy_eval_tensors = weakref.WeakSet() + self._lazy_eval_tensors = [] + self._lazy_eval_tensor_count = 0 self._active_tensors = weakref.WeakSet() self._tensor_remaps = None self._inputs_to_restore = None @@ -283,12 +284,18 @@ class trace: x._TraceMixin__restore() if self._symbolic: # eval lazy eval tensors - lazy_eval_tensors = tuple(self._lazy_eval_tensors) - if lazy_eval_tensors: - readers = [ - G.OutputNode(x._LazyEvalTensor__varnode).outputs[0] - for x in lazy_eval_tensors - ] + if self._lazy_eval_tensors: + lazy_eval_tensors = [] + visited = set() + readers = [] + for x in self._lazy_eval_tensors: + x = x() + if x is None or x in visited: + continue + reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0] + readers.append(reader) + lazy_eval_tensors.append(x) + visited.add(x) self._apply_graph_options(self._lazy_eval_graph) self._lazy_eval_graph.compile(*readers) self._lazy_eval_graph() @@ -844,7 +851,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): ] ovars = apply(op, *ivars) outputs = [LazyEvalTensor(v) for v in ovars] - active_trace._lazy_eval_tensors.update(outputs) + active_trace._lazy_eval_tensors.extend(weakref.ref(oup) for oup in outputs) return outputs @@ -855,7 +862,7 @@ apply.disable(apply_symbolic_mode) def apply_const_symbolic_mode(op: Const, *args: RawTensor): graph = active_trace._lazy_eval_graph ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device)) - active_trace._lazy_eval_tensors.add(ret) + active_trace._lazy_eval_tensors.append(weakref.ref(ret)) return (ret,)