Browse Source

fix(mge/trace): fix op order in symbolic

GitOrigin-RevId: fbf081a199
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
76f367962f
1 changed files with 16 additions and 9 deletions
  1. +16
    -9
      imperative/python/megengine/jit/tracing.py

+ 16
- 9
imperative/python/megengine/jit/tracing.py View File

@@ -124,7 +124,8 @@ class trace:
self._graph = None self._graph = None
self._need_reset_nodes = None self._need_reset_nodes = None
self._lazy_eval_graph = None self._lazy_eval_graph = None
self._lazy_eval_tensors = weakref.WeakSet()
self._lazy_eval_tensors = []
self._lazy_eval_tensor_count = 0
self._active_tensors = weakref.WeakSet() self._active_tensors = weakref.WeakSet()
self._tensor_remaps = None self._tensor_remaps = None
self._inputs_to_restore = None self._inputs_to_restore = None
@@ -283,12 +284,18 @@ class trace:
x._TraceMixin__restore() x._TraceMixin__restore()
if self._symbolic: if self._symbolic:
# eval lazy eval tensors # eval lazy eval tensors
lazy_eval_tensors = tuple(self._lazy_eval_tensors)
if lazy_eval_tensors:
readers = [
G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
for x in lazy_eval_tensors
]
if self._lazy_eval_tensors:
lazy_eval_tensors = []
visited = set()
readers = []
for x in self._lazy_eval_tensors:
x = x()
if x is None or x in visited:
continue
reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
readers.append(reader)
lazy_eval_tensors.append(x)
visited.add(x)
self._apply_graph_options(self._lazy_eval_graph) self._apply_graph_options(self._lazy_eval_graph)
self._lazy_eval_graph.compile(*readers) self._lazy_eval_graph.compile(*readers)
self._lazy_eval_graph() self._lazy_eval_graph()
@@ -844,7 +851,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
] ]
ovars = apply(op, *ivars) ovars = apply(op, *ivars)
outputs = [LazyEvalTensor(v) for v in ovars] outputs = [LazyEvalTensor(v) for v in ovars]
active_trace._lazy_eval_tensors.update(outputs)
active_trace._lazy_eval_tensors.extend(weakref.ref(oup) for oup in outputs)
return outputs return outputs




@@ -855,7 +862,7 @@ apply.disable(apply_symbolic_mode)
def apply_const_symbolic_mode(op: Const, *args: RawTensor): def apply_const_symbolic_mode(op: Const, *args: RawTensor):
graph = active_trace._lazy_eval_graph graph = active_trace._lazy_eval_graph
ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device)) ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
active_trace._lazy_eval_tensors.add(ret)
active_trace._lazy_eval_tensors.append(weakref.ref(ret))
return (ret,) return (ret,)






Loading…
Cancel
Save