diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py
index aef6f386..3ac8fbbd 100644
--- a/imperative/python/megengine/jit/tracing.py
+++ b/imperative/python/megengine/jit/tracing.py
@@ -1,12 +1,18 @@
+import collections
 import contextlib
 import functools
+import itertools
 import typing
+import warnings
 import weakref

+import numpy as np
+
 from ..core.ops.special import Const
 from ..core.tensor import megbrain_graph as G
-from ..core.tensor.core import OpBase, apply
+from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
 from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor
+from ..core.tensor.tensor import Tensor
 from .sublinear_memory_config import SublinearMemoryConfig
@@ -83,7 +89,6 @@ class trace:
         self.__wrapped__ = function
         self._symbolic = symbolic
         self._capture_as_const = capture_as_const
-        self._capture_static_shape = False
         self._sublinear_memory_config = sublinear_memory_config

         self._untraced = True
@@ -95,6 +100,12 @@ class trace:
         self._lazy_eval_graph = None
         self._lazy_eval_tensors = weakref.WeakSet()
         self._active_tensors = weakref.WeakSet()
+        self._tensor_remaps = None
+        self._inputs_to_restore = None
+        self._args_bindings = None
+        self._kwargs_bindings = None
+        self._output_bindings = None
+        self._output_names = None

     def _new_handle(self):
         handle = len(self._tinfo)
@@ -132,10 +143,13 @@ class trace:
                             "last time, got an internal tensor this time"
                         )
                     if x._handle != info.bound_data._handle:
-                        raise TraceMismatchError(
-                            "const capture violated: got "
-                            "a different tensor this time"
-                        )
+                        if not np.array_equal(
+                            x.numpy(), info.bound_data.numpy(), equal_nan=True
+                        ):
+                            raise TraceMismatchError(
+                                "const capture violated: got "
+                                "a different tensor this time"
+                            )
                 else:
                     if info.dtype != x.dtype:
                         raise TraceMismatchError(
@@ -148,10 +162,13 @@ class trace:
                     info.data_setter.set_value(x._dev_tensor())
             else:
                 if x.__class__ is not CompiledTensorProxy:
-                    raise TraceMismatchError(
-                        "unexpected capture: trying to use an external tensor as input, "
-                        "but that input was an internal tensor last time"
-                    )
+                    if x not in self._tensor_remaps:
+                        raise TraceMismatchError(
+                            "unexpected capture: trying to use an external tensor as "
+                            "input, but that input was an internal tensor last time"
+                        )
+                    else:
+                        x = self._tensor_remaps[x]
                 if x._CompiledTensorProxy__handle != h:
                     raise TraceMismatchError(
                         "mis-wiring: input edge to an data flow "
@@ -227,6 +244,9 @@ class trace:
                 info = self._tinfo[x._TraceMixin__handle]
                 info.data_read = True
                 x._TraceMixin__restore()
+            if self._inputs_to_restore:
+                for x in self._inputs_to_restore:
+                    x._TraceMixin__restore()
             if self._symbolic:
                 # eval lazy eval tensors
                 lazy_eval_tensors = tuple(self._lazy_eval_tensors)
@@ -252,6 +272,7 @@
             self._reset_exec_env()
         self._pc = 0
+        self._tensor_remaps = None
         apply.disable(apply_with_tracing)
         apply.disable(apply_const_with_tracing)
         apply.disable(apply_symbolic_mode)
@@ -260,6 +281,10 @@
         active_trace = None

     def _begin_excluded_region(self):
+        if self._capture_as_const:
+            raise RuntimeError(
+                "exclude_from_trace cannot be used with capture_as_const"
+            )
         if self._untraced:
             # conditionally reading a compiled tensor in excluded region
             # is permitted, so we have to assume every tensor might be read
@@ -292,6 +317,19 @@
         need_reset_nodes = self._need_reset_nodes = []
         # links enforce ordering of I/O nodes
         links = ()
+
+        if self._capture_as_const:
+            for h in itertools.chain(
+                self._args_bindings, self._kwargs_bindings.values()
+            ):
+                info = self._tinfo[h]
+                opnode = info.data_setter = G.InputNode(
+                    device=info.device, dtype=info.dtype, graph=graph
+                )
+                need_reset_nodes.append(opnode)
+                info.varnode = opnode.outputs[0]
+                links += opnode.outputs[1:]
+
         for op, ihandles, ohandles in self._seq:
             ivars = []
             readers = []
@@ -355,7 +393,193 @@

     def __call__(self, *args, **kwargs):
         with self._setup():
-            return self.__wrapped__(*args, **kwargs)
+            if self._capture_as_const:
+                self._process_inputs(*args, **kwargs)
+            outputs = self.__wrapped__(*args, **kwargs)
+            if self._capture_as_const:
+                self._process_outputs(outputs)
+            return outputs
+
+    def dump(self, file, *, arg_names=None, output_names=None):
+        if not self._capture_as_const:
+            raise ValueError(
+                "you must specify capture_as_const=True at __init__ to use dump"
+            )
+        if self._untraced:
+            raise RuntimeError("should run at least once before calling dump")
+        if self._output_names and output_names:
+            raise TypeError(
+                "cannot specify output_names when output is already in dict format"
+            )
+        if output_names and not isinstance(output_names, collections.Sequence):
+            output_names = (output_names,)
+        if output_names and len(output_names) != len(self._output_bindings):
+            raise ValueError("wrong number of output_names")
+        if arg_names and not isinstance(arg_names, collections.Sequence):
+            arg_names = (arg_names,)
+        if arg_names and len(arg_names) != len(self._args_bindings):
+            raise ValueError("wrong number of arg_names")
+        output_names = output_names or self._output_names
+
+        h2v = {}
+        graph = G.Graph()
+
+        for i, h in enumerate(self._args_bindings):
+            info = self._tinfo[h]
+            h2v[h] = graph.make_h2d(dtype=info.dtype, device=info.device)
+            if arg_names:
+                h2v[h].name = arg_names[i]
+        for k, h in self._kwargs_bindings.items():
+            info = self._tinfo[h]
+            h2v[h] = graph.make_h2d(dtype=info.dtype, device=info.device)
+            h2v[h].name = k
+
+        for op, ihandles, ohandles in self._seq:
+            ivars = []
+            for h in ihandles:
+                info = self._tinfo[h]
+                if h not in h2v:
+                    assert info.external
+                    assert info.bound_data
+                    h2v[h] = graph.make_const(info.bound_data._dev_tensor())
+                ivars.append(h2v[h])
+            ovars = apply(op, *ivars)
+            assert len(ovars) == len(ohandles)
+            h2v.update(zip(ohandles, ovars))
+
+        dest_vars = []
+        for i, h in enumerate(self._output_bindings):
+            v = h2v[h]
+            if output_names:
+                v.name = output_names[i]
+            dest_vars.append(v)
+
+        if isinstance(file, str):
+            file = open(file, "wb")
+        file.write(G.dump(*dest_vars))
+
+    def _process_inputs(self, *args, **kwargs):
+        if self._untraced:
+            self._inputs_to_restore = []
+
+            def record_input(x):
+                if x is None:
+                    return
+                h, info = self._new_handle()
+                info.external = False
+                info.device = x.device
+                info.dtype = x.dtype
+                TraceMixin._TraceMixin__inject(x, h)
+                self._inputs_to_restore.append(x)
+                return h
+
+            self._args_bindings = []
+            for i, x in enumerate(args):
+                x = find_raw_tensor(x)
+                if x is None:
+                    raise TypeError(
+                        "positional arguments should all be tensor "
+                        "but args[%d] cannot be recognized as one" % i
+                    )
+                self._args_bindings.append(record_input(x))
+
+            self._kwargs_bindings = {}
+            for k, x in kwargs.items():
+                x = find_raw_tensor(x)
+                if x is not None:
+                    self._kwargs_bindings[k] = record_input(x)
+        else:
+            if len(args) != len(self._args_bindings):
+                raise TraceMismatchError("positional argument length mismatch")
+
+            self._tensor_remaps = {}
+
+            for i, (h, x) in enumerate(zip(self._args_bindings, args)):
+                x = find_raw_tensor(x)
+                if x is None:
+                    raise TypeError(
+                        "positional arguments should all be tensor "
+                        "but args[%d] cannot be recognized as one" % i
+                    )
+                info = self._tinfo[h]
+                if x.dtype != info.dtype:
+                    raise TypeError("args[%d].dtype different from last time" % i)
+                if x.device != info.device:
+                    raise TypeError("args[%d].device different from last time" % i)
+                info.data_setter.set_value(x._dev_tensor())
+                self._tensor_remaps[x] = CompiledTensorProxy(h)
+
+            kwargs_tensors = {}
+            for k, x in kwargs.items():
+                x = find_raw_tensor(x)
+                if x is not None:
+                    kwargs_tensors[k] = x
+            if set(kwargs_tensors) != set(self._kwargs_bindings):
+                too_many = set(kwargs_tensors) - set(self._kwargs_bindings)
+                too_few = set(self._kwargs_bindings) - set(kwargs_tensors)
+                if too_many:
+                    raise TraceMismatchError(
+                        "keyword arguments found to be tensor this time "
+                        "but were non-tensor previously: %s" % " ".join(too_many)
+                    )
+                if too_few:
+                    raise TraceMismatchError(
+                        "keyword arguments found to be non-tensor this time "
+                        "but were tensor previously: %s" % " ".join(too_few)
+                    )
+            for k, h in self._kwargs_bindings.items():
+                x = kwargs_tensors[k]
+                info = self._tinfo[h]
+                if x.dtype != info.dtype:
+                    raise TypeError("kwargs[%s].dtype different from last time" % k)
+                if x.device != info.device:
+                    raise TypeError("kwargs[%s].device different from last time" % k)
+                info.data_setter.set_value(x._dev_tensor())
+                self._tensor_remaps[x] = CompiledTensorProxy(h)
+
+    def _process_outputs(self, outputs):
+        output_names = None
+        if isinstance(outputs, collections.Mapping):
+            output_names, outputs = zip(*sorted(outputs.items()))
+        elif not isinstance(outputs, collections.Sequence):
+            outputs = (outputs,)
+
+        if not self._untraced:
+            if output_names != self._output_names:
+                too_many = set(output_names) - set(self._output_names)
+                too_few = set(self._output_names) - set(output_names)
+                if too_many:
+                    raise TraceMismatchError(
+                        "output has more keys than last time: %s" % " ".join(too_many)
+                    )
+                if too_few:
+                    raise TraceMismatchError(
+                        "output has less keys than last time: %s" % " ".join(too_few)
+                    )
+            if len(outputs) != len(self._output_bindings):
+                raise TraceMismatchError("output size differs from last time")
+        else:
+            self._output_names = output_names
+            self._output_bindings = []
+
+        for i, x in enumerate(outputs):
+            x = find_raw_tensor(x)
+            if x is None:
+                raise TypeError("every item of return value should be tensor")
+            if self._untraced:
+                if not isinstance(x, TraceMixin):
+                    raise RuntimeError("output is not computed from inputs")
+                h = x._TraceMixin__handle
+                self._output_bindings.append(h)
+            else:
+                if not isinstance(x, CompiledTensorProxy):
+                    raise RuntimeError("output is not computed from inputs")
+                h = x._CompiledTensorProxy__handle
+                if h != self._output_bindings[i]:
+                    raise TraceMismatchError(
+                        "retval[%s] is a different tensor than last time"
+                        % (output_names and output_names[i] or i)
+                    )


 class CompiledTensorProxy(RawTensor):
@@ -514,6 +738,7 @@ apply.disable(apply_symbolic_mode)
 def apply_const_symbolic_mode(op: Const, *args: RawTensor):
     graph = active_trace._lazy_eval_graph
     ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
+    active_trace._lazy_eval_tensors.add(ret)
     return (ret,)
@@ -561,3 +786,27 @@ class BrokenRawTensor(RawTensor):

     def __setattr__(self, *_):
         raise RuntimeError("broken due to misuse of tracing")
+
+
+@functools.singledispatch
+def find_raw_tensor(x):
+    return None
+
+
+@find_raw_tensor.register(RawTensor)
+def _(x):
+    return x
+
+
+@find_raw_tensor.register(TensorWrapperBase)
+def _(x):
+    x = getattr(x, "__wrapped__", None)
+    if x is not None:
+        return find_raw_tensor(x)
+
+
+@find_raw_tensor.register(Tensor)
+def _(x):
+    x = getattr(x, "_data", None)
+    if x is not None:
+        return find_raw_tensor(x)
diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py
index 83b61fec..b7a6dacd 100644
--- a/imperative/python/test/unit/test_tracing.py
+++ b/imperative/python/test/unit/test_tracing.py
@@ -1,3 +1,5 @@
+import io
+
 import numpy as np

 from megengine.core.ops import builtin as ops
@@ -63,3 +65,20 @@ def test_print_in_trace():
     buf = None
     np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
     np.testing.assert_equal(z, buf)
+
+
+def test_dump():
+    @trace(symbolic=True, capture_as_const=True)
+    def f(x):
+        op = ops.Elemwise(mode="negate")
+        (y,) = apply(op, x)
+        return y
+
+    x = as_raw_tensor([1]).numpy()
+    y = f.__wrapped__(as_raw_tensor(x)).numpy()
+
+    for i in range(3):
+        np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
+
+    file = io.BytesIO()
+    f.dump(file)
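Usage sketch (not part of the patch): how the new capture_as_const option and trace.dump are expected to be used together. It assumes the same names imported by test_tracing.py above (trace, apply, ops, as_raw_tensor); the arg_names/output_names keywords are the optional ones added to dump in this diff.

import io

import numpy as np

@trace(symbolic=True, capture_as_const=True)
def f(x):
    op = ops.Elemwise(mode="negate")
    (y,) = apply(op, x)
    return y

x = np.array([1.0], dtype="float32")
f(as_raw_tensor(x))  # run at least once so the op sequence is recorded

# dump accepts a path or a file-like object; the names are optional
buf = io.BytesIO()
f.dump(buf, arg_names=["x"], output_names=["y"])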