@@ -1,12 +1,18 @@
+import collections.abc
 import contextlib
 import functools
+import itertools
 import typing
+import warnings
 import weakref
 
+import numpy as np
+
 from ..core.ops.special import Const
 from ..core.tensor import megbrain_graph as G
-from ..core.tensor.core import OpBase, apply
+from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
 from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor
+from ..core.tensor.tensor import Tensor
 from .sublinear_memory_config import SublinearMemoryConfig
@@ -83,7 +89,6 @@ class trace:
         self.__wrapped__ = function
         self._symbolic = symbolic
         self._capture_as_const = capture_as_const
-        self._capture_static_shape = False
         self._sublinear_memory_config = sublinear_memory_config
         self._untraced = True
@@ -95,6 +100,12 @@ class trace:
         self._lazy_eval_graph = None
         self._lazy_eval_tensors = weakref.WeakSet()
         self._active_tensors = weakref.WeakSet()
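+        # bookkeeping for capture_as_const: handle bindings for arguments
+        # and outputs, plus remaps from user tensors to compiled proxies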
+        self._tensor_remaps = None
+        self._inputs_to_restore = None
+        self._args_bindings = None
+        self._kwargs_bindings = None
+        self._output_bindings = None
+        self._output_names = None
 
     def _new_handle(self):
         handle = len(self._tinfo)
@@ -132,10 +143,13 @@ class trace:
                             "last time, got an internal tensor this time"
                         )
                     if x._handle != info.bound_data._handle:
-                        raise TraceMismatchError(
-                            "const capture violated: got "
-                            "a different tensor this time"
-                        )
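+                        # the same constant may come back under a different
+                        # handle; accept it if the values (NaN-aware) match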
+                        if not np.array_equal(
+                            x.numpy(), info.bound_data.numpy(), equal_nan=True
+                        ):
+                            raise TraceMismatchError(
+                                "const capture violated: got "
+                                "a different tensor this time"
+                            )
                 else:
                     if info.dtype != x.dtype:
                         raise TraceMismatchError(
@@ -148,10 +162,13 @@ class trace:
                     info.data_setter.set_value(x._dev_tensor())
             else:
                 if x.__class__ is not CompiledTensorProxy:
-                    raise TraceMismatchError(
-                        "unexpected capture: trying to use an external tensor as input, "
-                        "but that input was an internal tensor last time"
-                    )
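+                    # tensors recorded as traced arguments were remapped to
+                    # compiled proxies by _process_inputs; swap them in here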
+                    if x not in self._tensor_remaps:
+                        raise TraceMismatchError(
+                            "unexpected capture: trying to use an external tensor as "
+                            "input, but that input was an internal tensor last time"
+                        )
+                    else:
+                        x = self._tensor_remaps[x]
                 if x._CompiledTensorProxy__handle != h:
                     raise TraceMismatchError(
                         "mis-wiring: input edge to an data flow "
@@ -227,6 +244,9 @@ class trace:
                 info = self._tinfo[x._TraceMixin__handle]
                 info.data_read = True
                 x._TraceMixin__restore()
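+            # arguments injected by _process_inputs must also be restored
+            # to plain tensors when tracing ends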
+            if self._inputs_to_restore:
+                for x in self._inputs_to_restore:
+                    x._TraceMixin__restore()
             if self._symbolic:
                 # eval lazy eval tensors
                 lazy_eval_tensors = tuple(self._lazy_eval_tensors)
@@ -252,6 +272,7 @@ class trace:
             self._reset_exec_env()
             self._pc = 0
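+        # drop per-call remaps so stale proxies cannot leak into later calls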
+        self._tensor_remaps = None
         apply.disable(apply_with_tracing)
         apply.disable(apply_const_with_tracing)
         apply.disable(apply_symbolic_mode)
@@ -260,6 +281,10 @@ class trace:
         active_trace = None
 
     def _begin_excluded_region(self):
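+        # capture_as_const must record the complete computation, so
+        # excluded regions are rejected outright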
+        if self._capture_as_const:
+            raise RuntimeError(
+                "exclude_from_trace cannot be used with capture_as_const"
+            )
         if self._untraced:
             # conditionally reading a compiled tensor in excluded region
             # is permitted, so we have to assume every tensor might be read
@@ -292,6 +317,19 @@ class trace:
         need_reset_nodes = self._need_reset_nodes = []
         # links enforce ordering of I/O nodes
         links = ()
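+
+        # each bound argument gets an InputNode in the compiled graph; its
+        # data_setter is how later calls feed in fresh values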
+        if self._capture_as_const:
+            for h in itertools.chain(
+                self._args_bindings, self._kwargs_bindings.values()
+            ):
+                info = self._tinfo[h]
+                opnode = info.data_setter = G.InputNode(
+                    device=info.device, dtype=info.dtype, graph=graph
+                )
+                need_reset_nodes.append(opnode)
+                info.varnode = opnode.outputs[0]
+                links += opnode.outputs[1:]
+
         for op, ihandles, ohandles in self._seq:
             ivars = []
             readers = []
@@ -355,7 +393,193 @@ class trace:
 
     def __call__(self, *args, **kwargs):
         with self._setup():
-            return self.__wrapped__(*args, **kwargs)
+            if self._capture_as_const:
+                self._process_inputs(*args, **kwargs)
+            outputs = self.__wrapped__(*args, **kwargs)
+            if self._capture_as_const:
+                self._process_outputs(outputs)
+            return outputs
+
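+    # rebuild the recorded op sequence in a standalone G.Graph, with h2d
+    # input nodes for bound arguments, and serialize the result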
+    def dump(self, file, *, arg_names=None, output_names=None):
+        if not self._capture_as_const:
+            raise ValueError(
+                "you must specify capture_as_const=True at __init__ to use dump"
+            )
+        if self._untraced:
+            raise RuntimeError("should run at least once before calling dump")
+        if self._output_names and output_names:
+            raise TypeError(
+                "cannot specify output_names when output is already in dict format"
+            )
+        if output_names and (
+            isinstance(output_names, str)
+            or not isinstance(output_names, collections.abc.Sequence)
+        ):
+            output_names = (output_names,)
+        if output_names and len(output_names) != len(self._output_bindings):
+            raise ValueError("wrong number of output_names")
+        if arg_names and (
+            isinstance(arg_names, str)
+            or not isinstance(arg_names, collections.abc.Sequence)
+        ):
+            arg_names = (arg_names,)
+        if arg_names and len(arg_names) != len(self._args_bindings):
+            raise ValueError("wrong number of arg_names")
+        output_names = output_names or self._output_names
+
+        h2v = {}
+        graph = G.Graph()
+
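+        # positional and keyword inputs become h2d (host-to-device) graph
+        # inputs; keyword inputs take their keyword as the variable name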
+        for i, h in enumerate(self._args_bindings):
+            info = self._tinfo[h]
+            h2v[h] = graph.make_h2d(dtype=info.dtype, device=info.device)
+            if arg_names:
+                h2v[h].name = arg_names[i]
+        for k, h in self._kwargs_bindings.items():
+            info = self._tinfo[h]
+            h2v[h] = graph.make_h2d(dtype=info.dtype, device=info.device)
+            h2v[h].name = k
+
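+        # replay the recorded op sequence onto the new graph; external
+        # tensors bound as constants materialize as const nodes on first use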
+        for op, ihandles, ohandles in self._seq:
+            ivars = []
+            for h in ihandles:
+                info = self._tinfo[h]
+                if h not in h2v:
+                    assert info.external
+                    assert info.bound_data
+                    h2v[h] = graph.make_const(info.bound_data._dev_tensor())
+                ivars.append(h2v[h])
+            ovars = apply(op, *ivars)
+            assert len(ovars) == len(ohandles)
+            h2v.update(zip(ohandles, ovars))
+
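+        # collect (and optionally name) the variables to serialize as outputs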
+        dest_vars = []
+        for i, h in enumerate(self._output_bindings):
+            v = h2v[h]
+            if output_names:
+                v.name = output_names[i]
+            dest_vars.append(v)
+
+        if isinstance(file, str):
+            file = open(file, "wb")
+        file.write(G.dump(*dest_vars))
+
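+    # first call: record the argument structure and inject trace handles;
+    # later calls: validate dtype/device and feed values to the data setters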
+    def _process_inputs(self, *args, **kwargs):
+        if self._untraced:
+            self._inputs_to_restore = []
+
+            def record_input(x):
+                if x is None:
+                    return
+                h, info = self._new_handle()
+                info.external = False
+                info.device = x.device
+                info.dtype = x.dtype
+                TraceMixin._TraceMixin__inject(x, h)
+                self._inputs_to_restore.append(x)
+                return h
+
+            self._args_bindings = []
+            for i, x in enumerate(args):
+                x = find_raw_tensor(x)
+                if x is None:
+                    raise TypeError(
+                        "positional arguments should all be tensor "
+                        "but args[%d] cannot be recognized as one" % i
+                    )
+                self._args_bindings.append(record_input(x))
+
+            self._kwargs_bindings = {}
+            for k, x in kwargs.items():
+                x = find_raw_tensor(x)
+                if x is not None:
+                    self._kwargs_bindings[k] = record_input(x)
+        else:
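+            # replaying a compiled trace: the call must match the recorded
+            # signature, and each tensor is fed via its InputNode data_setter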
+            if len(args) != len(self._args_bindings):
+                raise TraceMismatchError("positional argument length mismatch")
+
+            self._tensor_remaps = {}
+
+            for i, (h, x) in enumerate(zip(self._args_bindings, args)):
+                x = find_raw_tensor(x)
+                if x is None:
+                    raise TypeError(
+                        "positional arguments should all be tensor "
+                        "but args[%d] cannot be recognized as one" % i
+                    )
+                info = self._tinfo[h]
+                if x.dtype != info.dtype:
+                    raise TypeError("args[%d].dtype different from last time" % i)
+                if x.device != info.device:
+                    raise TypeError("args[%d].device different from last time" % i)
+                info.data_setter.set_value(x._dev_tensor())
+                self._tensor_remaps[x] = CompiledTensorProxy(h)
+
+            kwargs_tensors = {}
+            for k, x in kwargs.items():
+                x = find_raw_tensor(x)
+                if x is not None:
+                    kwargs_tensors[k] = x
+            if set(kwargs_tensors) != set(self._kwargs_bindings):
+                too_many = set(kwargs_tensors) - set(self._kwargs_bindings)
+                too_few = set(self._kwargs_bindings) - set(kwargs_tensors)
+                if too_many:
+                    raise TraceMismatchError(
+                        "keyword arguments found to be tensor this time "
+                        "but were non-tensor previously: %s" % " ".join(too_many)
+                    )
+                if too_few:
+                    raise TraceMismatchError(
+                        "keyword arguments found to be non-tensor this time "
+                        "but were tensor previously: %s" % " ".join(too_few)
+                    )
+            for k, h in self._kwargs_bindings.items():
+                x = kwargs_tensors[k]
+                info = self._tinfo[h]
+                if x.dtype != info.dtype:
+                    raise TypeError("kwargs[%s].dtype different from last time" % k)
+                if x.device != info.device:
+                    raise TypeError("kwargs[%s].device different from last time" % k)
+                info.data_setter.set_value(x._dev_tensor())
+                self._tensor_remaps[x] = CompiledTensorProxy(h)
+
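+    # outputs may be a single tensor, a sequence, or a dict; dict outputs
+    # are sorted by key so the recorded order is deterministic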
+    def _process_outputs(self, outputs):
+        output_names = None
+        if isinstance(outputs, collections.abc.Mapping):
+            output_names, outputs = zip(*sorted(outputs.items()))
+        elif not isinstance(outputs, collections.abc.Sequence):
+            outputs = (outputs,)
+
+        if not self._untraced:
+            if output_names != self._output_names:
+                too_many = set(output_names) - set(self._output_names)
+                too_few = set(self._output_names) - set(output_names)
+                if too_many:
+                    raise TraceMismatchError(
+                        "output has more keys than last time: %s" % " ".join(too_many)
+                    )
+                if too_few:
+                    raise TraceMismatchError(
+                        "output has fewer keys than last time: %s" % " ".join(too_few)
+                    )
+            if len(outputs) != len(self._output_bindings):
+                raise TraceMismatchError("output size differs from last time")
+        else:
+            self._output_names = output_names
+            self._output_bindings = []
+
+        for i, x in enumerate(outputs):
+            x = find_raw_tensor(x)
+            if x is None:
+                raise TypeError("every item of return value should be tensor")
+            if self._untraced:
+                if not isinstance(x, TraceMixin):
+                    raise RuntimeError("output is not computed from inputs")
+                h = x._TraceMixin__handle
+                self._output_bindings.append(h)
+            else:
+                if not isinstance(x, CompiledTensorProxy):
+                    raise RuntimeError("output is not computed from inputs")
+                h = x._CompiledTensorProxy__handle
+                if h != self._output_bindings[i]:
+                    raise TraceMismatchError(
+                        "retval[%s] is a different tensor than last time"
+                        % (output_names and output_names[i] or i)
+                    )
+
 
 class CompiledTensorProxy(RawTensor):
@@ -514,6 +738,7 @@ apply.disable(apply_symbolic_mode)
 def apply_const_symbolic_mode(op: Const, *args: RawTensor):
     graph = active_trace._lazy_eval_graph
     ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
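+    # register the const so the lazy-eval pass at trace exit covers it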
+    active_trace._lazy_eval_tensors.add(ret)
     return (ret,)
@@ -561,3 +786,27 @@ class BrokenRawTensor(RawTensor):
 
     def __setattr__(self, *_):
         raise RuntimeError("broken due to misuse of tracing")
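+
+
+# unwrap Tensor/TensorWrapperBase layers recursively until the underlying
+# RawTensor is found; non-tensor values yield None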
+@functools.singledispatch
+def find_raw_tensor(x):
+    return None
+
+
+@find_raw_tensor.register(RawTensor)
+def _(x):
+    return x
+
+
+@find_raw_tensor.register(TensorWrapperBase)
+def _(x):
+    x = getattr(x, "__wrapped__", None)
+    if x is not None:
+        return find_raw_tensor(x)
+
+
+@find_raw_tensor.register(Tensor)
+def _(x):
+    x = getattr(x, "_data", None)
+    if x is not None:
+        return find_raw_tensor(x)