|
|
@@ -124,7 +124,8 @@ class trace: |
|
|
|
self._graph = None |
|
|
|
self._need_reset_nodes = None |
|
|
|
self._lazy_eval_graph = None |
|
|
|
self._lazy_eval_tensors = weakref.WeakSet() |
|
|
|
self._lazy_eval_tensors = [] |
|
|
|
self._lazy_eval_tensor_count = 0 |
|
|
|
self._active_tensors = weakref.WeakSet() |
|
|
|
self._tensor_remaps = None |
|
|
|
self._inputs_to_restore = None |
|
|
@@ -283,12 +284,18 @@ class trace: |
|
|
|
x._TraceMixin__restore() |
|
|
|
if self._symbolic: |
|
|
|
# eval lazy eval tensors |
|
|
|
lazy_eval_tensors = tuple(self._lazy_eval_tensors) |
|
|
|
if lazy_eval_tensors: |
|
|
|
readers = [ |
|
|
|
G.OutputNode(x._LazyEvalTensor__varnode).outputs[0] |
|
|
|
for x in lazy_eval_tensors |
|
|
|
] |
|
|
|
if self._lazy_eval_tensors: |
|
|
|
lazy_eval_tensors = [] |
|
|
|
visited = set() |
|
|
|
readers = [] |
|
|
|
for x in self._lazy_eval_tensors: |
|
|
|
x = x() |
|
|
|
if x is None or x in visited: |
|
|
|
continue |
|
|
|
reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0] |
|
|
|
readers.append(reader) |
|
|
|
lazy_eval_tensors.append(x) |
|
|
|
visited.add(x) |
|
|
|
self._apply_graph_options(self._lazy_eval_graph) |
|
|
|
self._lazy_eval_graph.compile(*readers) |
|
|
|
self._lazy_eval_graph() |
|
|
@@ -844,7 +851,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): |
|
|
|
] |
|
|
|
ovars = apply(op, *ivars) |
|
|
|
outputs = [LazyEvalTensor(v) for v in ovars] |
|
|
|
active_trace._lazy_eval_tensors.update(outputs) |
|
|
|
active_trace._lazy_eval_tensors.extend(weakref.ref(oup) for oup in outputs) |
|
|
|
return outputs |
|
|
|
|
|
|
|
|
|
|
@@ -855,7 +862,7 @@ apply.disable(apply_symbolic_mode) |
|
|
|
def apply_const_symbolic_mode(op: Const, *args: RawTensor): |
|
|
|
graph = active_trace._lazy_eval_graph |
|
|
|
ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device)) |
|
|
|
active_trace._lazy_eval_tensors.add(ret) |
|
|
|
active_trace._lazy_eval_tensors.append(weakref.ref(ret)) |
|
|
|
return (ret,) |
|
|
|
|
|
|
|
|
|
|
|