|
|
@@ -18,7 +18,13 @@ import weakref
 import numpy as np
 
 from ..core._imperative_rt import GraphProfiler
-from ..core._imperative_rt.ops import OprAttr
+from ..core._imperative_rt.ops import (
+    CollectiveComm,
+    OprAttr,
+    RemoteRecv,
+    RemoteSend,
+    VirtualDep,
+)
 from ..core._trace_option import set_symbolic_shape
 from ..core._wrap import device as as_device
 from ..core.ops.special import Const
@@ -92,6 +98,9 @@ class TensorInfo:
         self.data_reader = None
 
 
+_io_op_types = {CollectiveComm, RemoteSend, RemoteRecv}
+
+
 class trace:
     """
     Wraps a callable and provide:
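
The new `_io_op_types` set drives a `type(op) in _io_op_types` check in the hunks below. That test matches the concrete class only, never subclasses; a runnable sketch with stand-in classes (the real set holds the imported CollectiveComm, RemoteSend and RemoteRecv):

```python
class RemoteSend: ...              # stand-in for the real op class
class FusedSend(RemoteSend): ...   # hypothetical subclass

_io_op_types = {RemoteSend}

assert type(RemoteSend()) in _io_op_types      # exact type: matches
assert type(FusedSend()) not in _io_op_types   # subclass: does not match
assert isinstance(FusedSend(), RemoteSend)     # isinstance would match both
```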
|
|
@@ -143,8 +152,8 @@ class trace:
         self._graph = None
         self._need_reset_nodes = None
         self._lazy_eval_graph = None
-        self._lazy_eval_tensors = []
-        self._lazy_eval_tensor_count = 0
+        self._lazy_eval_tensors = weakref.WeakSet()
+        self._lazy_eval_links = None
         self._active_tensors = weakref.WeakSet()
         self._tensor_remaps = None
         self._inputs_to_restore = None
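
Swapping the list-plus-counter pair for a `weakref.WeakSet` is what lets the rewritten `_lazy_eval` below shed its manual bookkeeping: the old list held `weakref.ref` objects that had to be dereferenced (`x = x()`), checked for None, and deduplicated through a `visited` set. A WeakSet provides deduplication and automatic removal of dead entries for free; a minimal sketch (relying on CPython's immediate refcount collection):

```python
import weakref

class LazyTensor:
    """Stand-in for LazyEvalTensor; any weak-referenceable object works."""

live = weakref.WeakSet()
a, b = LazyTensor(), LazyTensor()
live.update([a, b, a])   # duplicates collapse on insert
assert len(live) == 2
del b                    # dead entries vanish without any visited-set dance
assert len(live) == 1
```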
|
|
@@ -286,27 +295,22 @@ class trace:
         apply.enable(apply_const_symbolic_mode)
         self._lazy_eval_graph = G.Graph()
         self._apply_graph_options(self._lazy_eval_graph)
+        self._lazy_eval_links = ()
 
     def _take_escaped_tensors(self):
         escaped_tensors = tuple(self._active_tensors)
         self._active_tensors.clear()
         return escaped_tensors
 
-    def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors):
-        active_lazy_eval_tensors = []
-        visited = set()
-        readers = []
-        for x in lazy_eval_tensors:
-            x = x()
-            if x is None or x in visited:
-                continue
-            reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
-            readers.append(reader)
-            active_lazy_eval_tensors.append(x)
-            visited.add(x)
-        lazy_eval_graph.compile(*readers)
+    def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
+        readers = [
+            G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
+            for x in lazy_eval_tensors
+        ]
+        self._apply_graph_options(lazy_eval_graph)
+        lazy_eval_graph.compile(*lazy_eval_links, *readers)
         lazy_eval_graph()
-        for r, x in zip(readers, active_lazy_eval_tensors):
+        for r, x in zip(readers, lazy_eval_tensors):
             assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
 
     @contextlib.contextmanager
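
Passing `*lazy_eval_links` to `compile` alongside the readers is load-bearing: compilation keeps only nodes the requested outputs depend on, so a RemoteSend whose result is never read would otherwise be pruned away together with its side effect. A toy reachability check (hypothetical node names) illustrates the failure mode:

```python
# Pretend dependency edges: "loss" needs "fwd"; nothing reaches "send".
deps = {"loss": ["fwd"], "recv_ack": ["send"]}

def reachable(targets):
    seen, stack = set(), list(targets)
    while stack:
        n = stack.pop()
        if n not in seen:
            seen.add(n)
            stack.extend(deps.get(n, []))
    return seen

assert "send" not in reachable(["loss"])      # pruned: the send never runs
assert "send" in reachable(["loss", "send"])  # kept once listed as a target
```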
|
|
@@ -333,11 +337,18 @@ class trace:
             if self._inputs_to_restore:
                 for x in self._inputs_to_restore:
                     x._TraceMixin__restore()
-            if self._symbolic and self._lazy_eval_tensors:
+            if self._symbolic and (
+                self._lazy_eval_tensors or self._lazy_eval_links
+            ):
                 # eval lazy eval tensors
-                self._lazy_eval(self._lazy_eval_graph, self._lazy_eval_tensors)
+                self._lazy_eval(
+                    self._lazy_eval_graph,
+                    tuple(self._lazy_eval_tensors),
+                    self._lazy_eval_links,
+                )
                 self._lazy_eval_graph = None
                 self._lazy_eval_tensors = None
+                self._lazy_eval_links = None
             self._untraced = False
         else:
             # compiled_tensor leaks
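
Widening the condition to `self._lazy_eval_tensors or self._lazy_eval_links` covers traces whose only observable effect is communication: such a function yields no lazy-eval tensors, so the old check skipped `_lazy_eval` entirely and the pending sends were dropped. Roughly this shape of function is the motivating case (a sketch only; import paths assumed from current MegEngine, and it needs a multi-rank launch to actually run):

```python
from megengine.distributed.functional import remote_send
from megengine.jit import trace

@trace(symbolic=True)
def push(x):
    # side effect only: nothing escapes the trace, so _lazy_eval_tensors
    # stays empty while _lazy_eval_links carries the pending send
    remote_send(x, 1)
```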
|
|
@@ -438,8 +449,10 @@ class trace:
             links += opnode.outputs[1:]
 
         for op, ihandles, ohandles in self._seq:
+            require_links = type(op) in _io_op_types
+
             ivars = []
-            for h in ihandles:
+            for i, h in enumerate(ihandles):
                 info = self._tinfo[h]
                 if not hasattr(info, "varnode"):
                     assert info.external
@@ -455,9 +468,14 @@ class trace:
                         )
                         need_reset_nodes.append(opnode)
                         info.varnode, *links = opnode.outputs
+                        if require_links and i == 0 and len(links) > 0:
+                            info.varnode = apply(VirtualDep(), info.varnode, *links)[0]
+                            links = (info.varnode,)
 
                 ivars.append(info.varnode)
             ovars = apply(op, *ivars)
+            if require_links and len(ovars) > 0:
+                links = (ovars[0],)
             assert len(ovars) == len(ohandles)
             for h, v in zip(ohandles, ovars):
                 info = self._tinfo[h]
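
`VirtualDep` adds a control edge without changing any value: each IO op's first input is rebased onto the tail of the link chain, and the op's first output becomes the new tail, so communication ops execute in program order even when no data flows between them (a consistent issue order across ranks is what prevents deadlock). A pure-Python analogue of the token threading (in the real graph it is the scheduler, not Python sequencing, that the token constrains):

```python
log = []

def io_op(name, token=None):
    # `token` is an artificial dependency: an executor may not start
    # this op before the op that produced the token has finished
    log.append(name)
    return name            # plays the role of ovars[0], the next link

t = io_op("send_a")
t = io_op("recv_b", t)     # chained: always issued after send_a
assert log == ["send_a", "recv_b"]
```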
|
|
@@ -502,6 +520,8 @@ class trace:
             info.data_read = True
 
     def __call__(self, *args, **kwargs):
+        if is_tracing():
+            return self.__wrapped__(*args, **kwargs)
         with self._setup():
             if self._capture_as_const:
                 self._process_inputs(*args, **kwargs)
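
The `is_tracing()` guard makes traced functions composable: calling one traced function from inside another now falls through to the plain Python body via `self.__wrapped__` instead of starting a second, nested trace. A sketch of the pattern this permits (hypothetical functions, using the decorator form of trace):

```python
from megengine.jit import trace

@trace
def inner(x):
    return x * 2

@trace
def outer(x):
    # is_tracing() is true here, so this call records inner's body as
    # part of outer's trace rather than re-entering the tracing machinery
    return inner(x) + 1
```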
|
|
@@ -938,9 +958,21 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
         or graph.make_const(x._dev_tensor())
         for x in args
     ]
+
+    require_links = type(op) in _io_op_types
+
+    if require_links and active_trace._lazy_eval_links:
+        assert len(ivars) > 0, "op should have at least one input"
+        ivars[0] = apply(VirtualDep(), ivars[0], *active_trace._lazy_eval_links)[0]
+        active_trace._lazy_eval_links = (ivars[0],)
+
     ovars = apply(op, *ivars)
+
+    if require_links:
+        active_trace._lazy_eval_links = (ovars[0],)
+
     outputs = [LazyEvalTensor(v) for v in ovars]
-    active_trace._lazy_eval_tensors.extend(weakref.ref(oup) for oup in outputs)
+    active_trace._lazy_eval_tensors.update(outputs)
     return outputs
 
 
@@ -951,7 +983,7 @@ apply.disable(apply_symbolic_mode)
 def apply_const_symbolic_mode(op: Const, *args: RawTensor):
     graph = active_trace._lazy_eval_graph
     ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
-    active_trace._lazy_eval_tensors.append(weakref.ref(ret))
+    active_trace._lazy_eval_tensors.add(ret)
     return (ret,)