
feat(mge): add trace.dump

GitOrigin-RevId: ea4c9d33c8
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent commit a3b2232ba7
2 changed files with 279 additions and 11 deletions
  1. +260 -11  imperative/python/megengine/jit/tracing.py
  2. +19 -0  imperative/python/test/unit/test_tracing.py
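
A minimal usage sketch of the new trace.dump API, assembled from the test_dump test added below; the megengine.jit and megengine.core import paths are assumptions inferred from the modules this diff touches, not something the commit itself documents:

import io

from megengine.core.ops import builtin as ops
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.jit import trace

# capture_as_const=True is required; dump() raises ValueError without it
@trace(symbolic=True, capture_as_const=True)
def f(x):
    (y,) = apply(ops.Elemwise(mode="negate"), x)
    return y

x = as_raw_tensor([1]).numpy()
f(as_raw_tensor(x))  # run the traced function at least once before dumping

file = io.BytesIO()  # dump() also accepts a path string and opens it in "wb" mode
f.dump(file, arg_names=["x"], output_names=["y"])  # the new keyword-only naming options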

+260 -11  imperative/python/megengine/jit/tracing.py

@@ -1,12 +1,18 @@
import collections
import contextlib
import functools
import itertools
import typing
import warnings
import weakref

import numpy as np

from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, apply
from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply
from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor
from ..core.tensor.tensor import Tensor
from .sublinear_memory_config import SublinearMemoryConfig




@@ -83,7 +89,6 @@ class trace:
self.__wrapped__ = function
self._symbolic = symbolic
self._capture_as_const = capture_as_const
self._capture_static_shape = False
self._sublinear_memory_config = sublinear_memory_config

self._untraced = True
@@ -95,6 +100,12 @@ class trace:
self._lazy_eval_graph = None
self._lazy_eval_tensors = weakref.WeakSet()
self._active_tensors = weakref.WeakSet()
self._tensor_remaps = None
self._inputs_to_restore = None
self._args_bindings = None
self._kwargs_bindings = None
self._output_bindings = None
self._output_names = None


def _new_handle(self):
handle = len(self._tinfo)
@@ -132,10 +143,13 @@ class trace:
"last time, got an internal tensor this time" "last time, got an internal tensor this time"
) )
if x._handle != info.bound_data._handle: if x._handle != info.bound_data._handle:
raise TraceMismatchError(
"const capture violated: got "
"a different tensor this time"
)
if not np.array_equal(
x.numpy(), info.bound_data.numpy(), equal_nan=True
):
raise TraceMismatchError(
"const capture violated: got "
"a different tensor this time"
)
else:
if info.dtype != x.dtype:
raise TraceMismatchError(
@@ -148,10 +162,13 @@ class trace:
info.data_setter.set_value(x._dev_tensor())
else:
if x.__class__ is not CompiledTensorProxy:
raise TraceMismatchError(
"unexpected capture: trying to use an external tensor as input, "
"but that input was an internal tensor last time"
)
if x not in self._tensor_remaps:
raise TraceMismatchError(
"unexpected capture: trying to use an external tensor as "
"input, but that input was an internal tensor last time"
)
else:
x = self._tensor_remaps[x]
if x._CompiledTensorProxy__handle != h:
raise TraceMismatchError(
"mis-wiring: input edge to an data flow "
@@ -227,6 +244,9 @@ class trace:
info = self._tinfo[x._TraceMixin__handle]
info.data_read = True
x._TraceMixin__restore()
if self._inputs_to_restore:
for x in self._inputs_to_restore:
x._TraceMixin__restore()
if self._symbolic:
# eval lazy eval tensors
lazy_eval_tensors = tuple(self._lazy_eval_tensors)
@@ -252,6 +272,7 @@ class trace:
self._reset_exec_env()
self._pc = 0


self._tensor_remaps = None
apply.disable(apply_with_tracing)
apply.disable(apply_const_with_tracing)
apply.disable(apply_symbolic_mode)
@@ -260,6 +281,10 @@ class trace:
active_trace = None


def _begin_excluded_region(self):
if self._capture_as_const:
raise RuntimeError(
"exclude_from_trace cannot be used with capture_as_const"
)
if self._untraced:
# conditionally reading a compiled tensor in excluded region
# is permitted, so we have to assume every tensor might be read
@@ -292,6 +317,19 @@ class trace:
need_reset_nodes = self._need_reset_nodes = []
# links enforce ordering of I/O nodes
links = ()

if self._capture_as_const:
for h in itertools.chain(
self._args_bindings, self._kwargs_bindings.values()
):
info = self._tinfo[h]
opnode = info.data_setter = G.InputNode(
device=info.device, dtype=info.dtype, graph=graph
)
need_reset_nodes.append(opnode)
info.varnode = opnode.outputs[0]
links += opnode.outputs[1:]

for op, ihandles, ohandles in self._seq:
ivars = []
readers = []
@@ -355,7 +393,193 @@ class trace:


def __call__(self, *args, **kwargs):
with self._setup():
return self.__wrapped__(*args, **kwargs)
if self._capture_as_const:
self._process_inputs(*args, **kwargs)
outputs = self.__wrapped__(*args, **kwargs)
if self._capture_as_const:
self._process_outputs(outputs)
return outputs

def dump(self, file, *, arg_names=None, output_names=None):
if not self._capture_as_const:
raise ValueError(
"you must specify capture_as_const=True at __init__ to use dump"
)
if self._untraced:
raise RuntimeError("should run at least once before calling dump")
if self._output_names and output_names:
raise TypeError(
"cannot specify output_names when output is already in dict format"
)
if output_names and not isinstance(output_names, collections.Sequence):
output_names = (output_names,)
if output_names and len(output_names) != len(self._output_bindings):
raise ValueError("wrong number of output_names")
if arg_names and not isinstance(arg_names, collections.Sequence):
arg_names = (arg_names,)
if arg_names and len(arg_names) != len(self._args_bindings):
raise ValueError("wrong number of arg_names")
output_names = output_names or self._output_names

h2v = {}
graph = G.Graph()

for i, h in enumerate(self._args_bindings):
info = self._tinfo[h]
h2v[h] = graph.make_h2d(dtype=info.dtype, device=info.device)
if arg_names:
h2v[h].name = arg_names[i]
for k, h in self._kwargs_bindings.items():
info = self._tinfo[h]
h2v[h] = graph.make_h2d(dtype=info.dtype, device=info.device)
h2v[h].name = k

for op, ihandles, ohandles in self._seq:
ivars = []
for h in ihandles:
info = self._tinfo[h]
if h not in h2v:
assert info.external
assert info.bound_data
h2v[h] = graph.make_const(info.bound_data._dev_tensor())
ivars.append(h2v[h])
ovars = apply(op, *ivars)
assert len(ovars) == len(ohandles)
h2v.update(zip(ohandles, ovars))

dest_vars = []
for i, h in enumerate(self._output_bindings):
v = h2v[h]
if output_names:
v.name = output_names[i]
dest_vars.append(v)

if isinstance(file, str):
file = open(file, "wb")
file.write(G.dump(*dest_vars))

def _process_inputs(self, *args, **kwargs):
if self._untraced:
self._inputs_to_restore = []

def record_input(x):
if x is None:
return
h, info = self._new_handle()
info.external = False
info.device = x.device
info.dtype = x.dtype
TraceMixin._TraceMixin__inject(x, h)
self._inputs_to_restore.append(x)
return h

self._args_bindings = []
for i, x in enumerate(args):
x = find_raw_tensor(x)
if x is None:
raise TypeError(
"positional arguments should all be tensor "
"but args[%d] cannot be recognized as one" % i
)
self._args_bindings.append(record_input(x))

self._kwargs_bindings = {}
for k, x in kwargs.items():
x = find_raw_tensor(x)
if x is not None:
self._kwargs_bindings[k] = record_input(x)
else:
if len(args) != len(self._args_bindings):
raise TraceMismatchError("positional argument length mismatch")

self._tensor_remaps = {}

for i, (h, x) in enumerate(zip(self._args_bindings, args)):
x = find_raw_tensor(x)
if x is None:
raise TypeError(
"positional arguments should all be tensor "
"but args[%d] cannot be recognized as one" % i
)
info = self._tinfo[h]
if x.dtype != info.dtype:
raise TypeError("args[%d].dtype different from last time" % i)
if x.device != info.device:
raise TypeError("args[%d].device different from last time" % i)
info.data_setter.set_value(x._dev_tensor())
self._tensor_remaps[x] = CompiledTensorProxy(h)

kwargs_tensors = {}
for k, x in kwargs.items():
x = find_raw_tensor(x)
if x is not None:
kwargs_tensors[k] = x
if set(kwargs_tensors) != set(self._kwargs_bindings):
too_many = set(kwargs_tensors) - set(self._kwargs_bindings)
too_few = set(self._kwargs_bindings) - set(kwargs_tensors)
if too_many:
raise TraceMismatchError(
"keyword arguments found to be tensor this time "
"but were non-tensor previously: %s" % " ".join(too_many)
)
if too_few:
raise TraceMismatchError(
"keyword arguments found to be non-tensor this time "
"but were tensor previously: %s" % " ".join(too_few)
)
for k, h in self._kwargs_bindings.items():
x = kwargs_tensors[k]
info = self._tinfo[h]
if x.dtype != info.dtype:
raise TypeError("kwargs[%s].dtype different from last time" % k)
if x.device != info.device:
raise TypeError("kwargs[%s].device different from last time" % k)
info.data_setter.set_value(x._dev_tensor())
self._tensor_remaps[x] = CompiledTensorProxy(h)

def _process_outputs(self, outputs):
output_names = None
if isinstance(outputs, collections.Mapping):
output_names, outputs = zip(*sorted(outputs.items()))
elif not isinstance(outputs, collections.Sequence):
outputs = (outputs,)

if not self._untraced:
if output_names != self._output_names:
too_many = set(output_names) - set(self._output_names)
too_few = set(self._output_names) - set(output_names)
if too_many:
raise TraceMismatchError(
"output has more keys than last time: %s" % " ".join(too_many)
)
if too_few:
raise TraceMismatchError(
"output has less keys than last time: %s" % " ".join(too_few)
)
if len(outputs) != len(self._output_bindings):
raise TraceMismatchError("output size differs from last time")
else:
self._output_names = output_names
self._output_bindings = []

for i, x in enumerate(outputs):
x = find_raw_tensor(x)
if x is None:
raise TypeError("every item of return value should be tensor")
if self._untraced:
if not isinstance(x, TraceMixin):
raise RuntimeError("output is not computed from inputs")
h = x._TraceMixin__handle
self._output_bindings.append(h)
else:
if not isinstance(x, CompiledTensorProxy):
raise RuntimeError("output is not computed from inputs")
h = x._CompiledTensorProxy__handle
if h != self._output_bindings[i]:
raise TraceMismatchError(
"retval[%s] is a different tensor than last time"
% (output_names and output_names[i] or i)
)




class CompiledTensorProxy(RawTensor):
@@ -514,6 +738,7 @@ apply.disable(apply_symbolic_mode)
def apply_const_symbolic_mode(op: Const, *args: RawTensor):
graph = active_trace._lazy_eval_graph
ret = LazyEvalTensor(graph.make_const(op.value, dtype=op.dtype, device=op.device))
active_trace._lazy_eval_tensors.add(ret)
return (ret,)




@@ -561,3 +786,27 @@ class BrokenRawTensor(RawTensor):


def __setattr__(self, *_):
raise RuntimeError("broken due to misuse of tracing")


@functools.singledispatch
def find_raw_tensor(x):
return None


@find_raw_tensor.register(RawTensor)
def _(x):
return x


@find_raw_tensor.register(TensorWrapperBase)
def _(x):
x = getattr(x, "__wrapped__", None)
if x is not None:
return find_raw_tensor(x)


@find_raw_tensor.register(Tensor)
def _(x):
x = getattr(x, "_data", None)
if x is not None:
return find_raw_tensor(x)

+19 -0  imperative/python/test/unit/test_tracing.py

@@ -1,3 +1,5 @@
import io

import numpy as np


from megengine.core.ops import builtin as ops
@@ -63,3 +65,20 @@ def test_print_in_trace():
buf = None
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
np.testing.assert_equal(z, buf)


def test_dump():
@trace(symbolic=True, capture_as_const=True)
def f(x):
op = ops.Elemwise(mode="negate")
(y,) = apply(op, x)
return y

x = as_raw_tensor([1]).numpy()
y = f.__wrapped__(as_raw_tensor(x)).numpy()

for i in range(3):
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)

file = io.BytesIO()
f.dump(file)
