|
|
@@ -186,6 +186,9 @@ class trace: |
|
|
|
self._seq.append((op, tuple(ihandles), tuple(ohandles))) |
|
|
|
self._active_tensors.update(outputs) |
|
|
|
|
|
|
|
def _record_const(self, op, outputs): |
|
|
|
pass |
|
|
|
|
|
|
|
@contextlib.contextmanager |
|
|
|
def _setup(self): |
|
|
|
global active_trace |
|
|
@@ -195,8 +198,10 @@ class trace: |
|
|
|
|
|
|
|
if self._untraced: |
|
|
|
apply.enable(apply_with_tracing) |
|
|
|
apply.enable(apply_const_with_tracing) |
|
|
|
if self._symbolic: |
|
|
|
apply.enable(apply_symbolic_mode) |
|
|
|
apply.enable(apply_const_symbolic_mode) |
|
|
|
self._lazy_eval_graph = G.Graph() |
|
|
|
else: |
|
|
|
apply.enable(apply_compiled_mode) |
|
|
@@ -239,7 +244,9 @@ class trace: |
|
|
|
self._pc = 0 |
|
|
|
|
|
|
|
apply.disable(apply_with_tracing) |
|
|
|
apply.disable(apply_const_with_tracing) |
|
|
|
apply.disable(apply_symbolic_mode) |
|
|
|
apply.disable(apply_const_symbolic_mode) |
|
|
|
apply.disable(apply_compiled_mode) |
|
|
|
active_trace = None |
|
|
|
|
|
|
@@ -478,6 +485,16 @@ apply.disable(apply_symbolic_mode) |
|
|
|
|
|
|
|
|
|
|
|
@apply.register() |
|
|
|
def apply_const_symbolic_mode(op: Const, *args: RawTensor): |
|
|
|
graph = active_trace._lazy_eval_graph |
|
|
|
ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device)) |
|
|
|
return (ret,) |
|
|
|
|
|
|
|
|
|
|
|
apply.disable(apply_const_symbolic_mode) |
|
|
|
|
|
|
|
|
|
|
|
@apply.register() |
|
|
|
def apply_compiled_mode(op: OpDef, *args: RawTensor): |
|
|
|
if skip_tracing: |
|
|
|
args = [ |
|
|
@@ -502,9 +519,14 @@ def apply_with_tracing(op: OpDef, *args: RawTensor): |
|
|
|
apply.disable(apply_with_tracing) |
|
|
|
|
|
|
|
|
|
|
|
# @apply.register() |
|
|
|
# def _(op: Const, *args: RawTensor): |
|
|
|
# return active_trace._apply_const(op, args) |
|
|
|
@apply.register() |
|
|
|
def apply_const_with_tracing(op: Const, *args: RawTensor): |
|
|
|
outputs = apply.super(op, *args) |
|
|
|
active_trace._record_const(op, outputs) |
|
|
|
return outputs |
|
|
|
|
|
|
|
|
|
|
|
apply.disable(apply_const_with_tracing) |
|
|
|
|
|
|
|
|
|
|
|
class BrokenRawTensor(RawTensor): |
|
|
|