@@ -125,6 +125,9 @@ class trace:
         self._graph_opt_level = opt_level
         self._tensor_shape = tensor_shape
 
+        self._reset()
+
+    def _reset(self):
         self._untraced = True
         self._tinfo = []  # handle -> TensorInfo
         self._seq = []
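
Note: the first hunk extracts the mutable trace state from __init__ into a new _reset() helper; __init__ now delegates to it, and the error path at the end of the second hunk (if interrupted: self._reset()) reuses it to discard partial state. A minimal sketch of the idiom, assuming a stripped-down class (the Recorder name is illustrative; the field names follow the hunk):

    class Recorder:
        def __init__(self):
            # __init__ delegates to _reset() so an aborted run can
            # later restore the same pristine state
            self._reset()

        def _reset(self):
            self._untraced = True
            self._tinfo = []  # handle -> TensorInfo
            self._seq = []
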
@@ -257,77 +260,117 @@ class trace:
     def _record_const(self, op, outputs):
         pass
 
-    @contextlib.contextmanager
-    def _setup(self):
+    def _set_active(self, active: bool):
         global active_trace
-        if active_trace:
-            raise NotImplementedError("sorry, not implemented: nested trace")
-        active_trace = self
-
-        if self._untraced:
-            apply.enable(apply_with_tracing)
-            apply.enable(apply_const_with_tracing)
-            if self._symbolic:
-                apply.enable(apply_symbolic_mode)
-                apply.enable(apply_const_symbolic_mode)
-                self._lazy_eval_graph = G.Graph()
+        if active:
+            if active_trace:
+                raise NotImplementedError("sorry, not implemented: nested trace")
+            active_trace = self
         else:
-            apply.enable(apply_compiled_mode)
-            if self._graph is None:
-                self._compile()
-            self._graph.execute()
-
-        yield
-
+            assert active_trace is self
+            active_trace = None
+
+    def _init_trace(self, symbolic: bool):
+        apply.enable(apply_with_tracing)
+        apply.enable(apply_const_with_tracing)
+        if symbolic:
+            apply.enable(apply_symbolic_mode)
+            apply.enable(apply_const_symbolic_mode)
+            self._lazy_eval_graph = G.Graph()
+
+    def _take_escaped_tensors(self):
         escaped_tensors = tuple(self._active_tensors)
         self._active_tensors.clear()
+        return escaped_tensors
 
-        if self._untraced:
-            for x in escaped_tensors:
-                info = self._tinfo[x._TraceMixin__handle]
-                info.data_read = True
-                x._TraceMixin__restore()
-            if self._inputs_to_restore:
-                for x in self._inputs_to_restore:
-                    x._TraceMixin__restore()
-            if self._symbolic:
-                # eval lazy eval tensors
-                if self._lazy_eval_tensors:
-                    lazy_eval_tensors = []
-                    visited = set()
-                    readers = []
-                    for x in self._lazy_eval_tensors:
-                        x = x()
-                        if x is None or x in visited:
-                            continue
-                        reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
-                        readers.append(reader)
-                        lazy_eval_tensors.append(x)
-                        visited.add(x)
-                    self._apply_graph_options(self._lazy_eval_graph)
-                    self._lazy_eval_graph.compile(*readers)
-                    self._lazy_eval_graph()
-                    for r, x in zip(readers, lazy_eval_tensors):
-                        assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
-                self._lazy_eval_graph = None
-                self._lazy_eval_tensors = None
-            self._untraced = False
-        else:
-            if self._pc != len(self._seq):
-                raise TraceMismatchError("premature end")
-            for x in escaped_tensors:
-                assign_raw_tensor(x, as_raw_tensor(x._dev_tensor()))
-            self._graph.wait()
-            self._reset_exec_env()
-        self._pc = 0
-
-        self._tensor_remaps = None
-        apply.disable(apply_with_tracing)
-        apply.disable(apply_const_with_tracing)
-        apply.disable(apply_symbolic_mode)
-        apply.disable(apply_const_symbolic_mode)
-        apply.disable(apply_compiled_mode)
-        active_trace = None
+    def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors):
+        active_lazy_eval_tensors = []
+        visited = set()
+        readers = []
+        for x in lazy_eval_tensors:
+            x = x()
+            if x is None or x in visited:
+                continue
+            reader = G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
+            readers.append(reader)
+            active_lazy_eval_tensors.append(x)
+            visited.add(x)
+        self._apply_graph_options(lazy_eval_graph)
+        lazy_eval_graph.compile(*readers)
+        lazy_eval_graph()
+        for r, x in zip(readers, active_lazy_eval_tensors):
+            assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
+
+    @contextlib.contextmanager
+    def _setup(self):
+        interrupted = False
+
+        def do_enter():
+            self._set_active(True)
+            if self._untraced:
+                self._init_trace(self._symbolic)
+            else:
+                apply.enable(apply_compiled_mode)
+                if self._graph is None:
+                    self._compile()
+                self._graph.execute()
+
+        def do_finalize():
+            escaped_tensors = self._take_escaped_tensors()
+            if self._untraced:
+                for x in escaped_tensors:
+                    info = self._tinfo[x._TraceMixin__handle]
+                    info.data_read = True
+                    x._TraceMixin__restore()
+                if self._inputs_to_restore:
+                    for x in self._inputs_to_restore:
+                        x._TraceMixin__restore()
+                if self._symbolic and self._lazy_eval_tensors:
+                    # eval lazy eval tensors
+                    self._lazy_eval(self._lazy_eval_graph, self._lazy_eval_tensors)
+                    self._lazy_eval_graph = None
+                    self._lazy_eval_tensors = None
+                self._untraced = False
+            else:
+                # compiled_tensor leaks
+                if self._pc == len(self._seq):
+                    for x in escaped_tensors:
+                        try:
+                            assign_raw_tensor(x, as_raw_tensor(x._dev_tensor()))
+                        except TraceMismatchError:
+                            # TraceMismatchError thrown in do_exit
+                            pass
+                    self._graph.wait()
+                    self._reset_exec_env()
+
+            # reset status
+            self._pc = 0
+            self._tensor_remaps = None
+            apply.disable(apply_with_tracing)
+            apply.disable(apply_const_with_tracing)
+            apply.disable(apply_symbolic_mode)
+            apply.disable(apply_const_symbolic_mode)
+            apply.disable(apply_compiled_mode)
+            self._set_active(False)
+
+        def do_exit():
+            if not self._untraced and self._pc != len(self._seq):
+                raise TraceMismatchError("premature end")
+            if not self._symbolic or not self._untraced:
+                for x in self._active_tensors:
+                    x._dev_tensor()
+
+        try:
+            do_enter()
+            yield
+            do_exit()
+        except:
+            interrupted = True
+            raise
+        finally:
+            do_finalize()
+            if interrupted:
+                self._reset()
 
     def _begin_excluded_region(self):
         if self._capture_as_const:
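
Note: in _lazy_eval (and the inline block it replaces), x = x() dereferences each entry before use and skips None results, which suggests self._lazy_eval_tensors holds weakref-style callables; the visited set then dedupes live tensors before one G.OutputNode reader per tensor is compiled into the lazy-eval graph. A sketch of just that dereference-and-dedup walk, assuming plain weakref.ref inputs and hashable referents (collect_live is a hypothetical helper, not part of the diff):

    import weakref

    def collect_live(refs):
        # refs: iterable of weakref.ref; returns unique live referents
        live, visited = [], set()
        for ref in refs:
            obj = ref()  # None once the referent has been collected
            if obj is None or obj in visited:
                continue
            live.append(obj)
            visited.add(obj)
        return live

    class T:
        pass

    t = T()
    assert collect_live([weakref.ref(t), weakref.ref(t)]) == [t]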
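
Note: the rewritten _setup() splits the old monolithic generator into three local phases: do_enter() activates the trace, do_exit() performs success-path consistency checks, and do_finalize() runs from the finally clause so the apply.* handlers are always disabled and a raising traced body still triggers self._reset(). A minimal self-contained sketch of the same pattern, with nothing taken from this codebase (guarded, enter, exit_checks, finalize and the print calls are all illustrative):

    import contextlib

    @contextlib.contextmanager
    def guarded():
        interrupted = False

        def enter():
            print("enter: activate")

        def exit_checks():
            # reached only when the with-body finished without raising
            print("exit: verify invariants")

        def finalize():
            # always reached, after exit_checks or after an exception
            print("finalize: cleanup, interrupted =", interrupted)

        try:
            enter()
            yield
            exit_checks()
        except:
            interrupted = True
            raise
        finally:
            finalize()

    # an exception in the with-body still propagates, but finalize()
    # has already run with interrupted=True:
    with guarded():
        pass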