import contextlib
import functools
import typing
import weakref

from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, apply
from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor


class TraceMismatchError(RuntimeError):
    pass


# module-level tracing state: the currently active trace object, and a flag
# suppressing record/replay inside `exclude_from_trace` regions
active_trace = None
skip_tracing = False


@contextlib.contextmanager
def exclude_from_trace():
    """Temporarily suspend op recording/replaying on the active trace."""
    global skip_tracing
    if skip_tracing:
        yield
        return
    try:
        skip_tracing = True
        if active_trace is not None:
            active_trace._begin_excluded_region()
        yield
    finally:
        skip_tracing = False
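
# A minimal usage sketch (hypothetical tensor `x`): ops applied inside the
# block run eagerly and are not recorded into the active trace; traced
# tensors consumed inside the block are marked as data_read instead.
#
#     with exclude_from_trace():
#         x = x + 1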


class TensorInfo:
    __slots__ = (
        # collected attributes
        "external",
        "exported",
        "data_read",
        "shape_read",
        "value_read",
        "device",
        "dtype",
        "bound_data",
        # resources for execution
        "varnode",
        "data_setter",
        "shape_reader",
        "value_reader",
        "data_reader",
    )

    def __init__(self):
        self.exported = None
        self.data_read = None
        self.shape_read = None
        self.value_read = None
        self.bound_data = None

        # `varnode` is deliberately left unset: `trace._compile` relies on
        # `hasattr(info, "varnode")` to detect variables it has not built yet
        self.data_setter = None
        self.shape_reader = None
        self.value_reader = None
        self.data_reader = None


class trace:
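    """Trace a function and replay it as a compiled graph.

    The first call records the ops applied by ``function``; subsequent calls
    replay them through a graph built by ``_compile``. A minimal usage sketch
    (hypothetical ``f`` and ``x``)::

        @trace(symbolic=True)
        def f(x):
            return x * 2

        y = f(x)  # first call traces; later calls replay
    """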
    def __new__(cls, *args, **kwargs):
        if not args:
            # invoked as @trace(**kwargs): defer construction until the
            # decorated function is supplied
            return functools.partial(cls, **kwargs)
        # __init__ is invoked automatically on the returned instance, so
        # calling it here explicitly would initialize the object twice
        return super().__new__(cls)

    def __init__(self, function, symbolic=False, capture_as_const=False):
        self.__wrapped__ = function
        self._symbolic = symbolic
        self._capture_as_const = capture_as_const
        self._capture_static_shape = False

        self._untraced = True
        self._tinfo = []  # handle -> TensorInfo
        self._seq = []  # recorded (op, input_handles, output_handles) triples
        self._pc = 0  # replay cursor into self._seq
        self._graph = None
        self._need_reset_nodes = None
        self._lazy_eval_graph = None
        self._lazy_eval_tensors = weakref.WeakSet()
        self._active_tensors = weakref.WeakSet()

    def _new_handle(self):
        handle = len(self._tinfo)
        info = TensorInfo()
        self._tinfo.append(info)
        return handle, info
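
    # Handles are append-only indices into _tinfo, so a handle recorded on
    # the first call names the same TensorInfo on every replay.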

    def _apply_op(self, op, args):
        assert not self._untraced
        # check the op being applied against the recorded sequence
        if self._pc >= len(self._seq):
            raise TraceMismatchError("trace should end here, but more ops were observed")
        record = self._seq[self._pc]
        op_, ihandles, ohandles = record
        if op != op_:
            raise TraceMismatchError("op different from last time")
        if len(ihandles) != len(args):
            raise TraceMismatchError("op input size different from last time")

        for h, x in zip(ihandles, args):
            info = self._tinfo[h]
            if info.external:
                if (
                    x.__class__ is CompiledTensorProxy
                    and not self._tinfo[x._CompiledTensorProxy__handle].exported
                ):
                    raise TraceMismatchError(
                        "failed to capture: input was an external tensor "
                        "last time, got an internal tensor this time"
                    )
                if info.bound_data is not None:
                    if x.__class__ is CompiledTensorProxy:
                        raise TraceMismatchError(
                            "const capture violated: input was an external tensor "
                            "last time, got an internal tensor this time"
                        )
                    if x._handle != info.bound_data._handle:
                        raise TraceMismatchError(
                            "const capture violated: got "
                            "a different tensor this time"
                        )
                else:
                    if info.dtype != x.dtype:
                        raise TraceMismatchError(
                            "failed to capture: different dtype from last time"
                        )
                    if info.device != x.device:
                        raise TraceMismatchError(
                            "failed to capture: different device from last time"
                        )
                    info.data_setter.set_value(x._dev_tensor())
            else:
                if x.__class__ is not CompiledTensorProxy:
                    raise TraceMismatchError(
                        "unexpected capture: trying to use an external tensor as input, "
                        "but that input was an internal tensor last time"
                    )
                if x._CompiledTensorProxy__handle != h:
                    raise TraceMismatchError(
                        "mis-wiring: input edge to a data flow "
                        "graph node differs from last time"
                    )

        self._pc += 1
        outputs = tuple(CompiledTensorProxy(h) for h in ohandles)
        self._active_tensors.update(outputs)
        return outputs

    def _record_op(self, op, inputs, outputs):
        if skip_tracing:
            # inside exclude_from_trace: don't record the op, but note that
            # the data of any traced input is consumed eagerly
            for x in inputs:
                h = getattr(x, "_TraceMixin__handle", None)
                if h is not None:
                    self._tinfo[h].data_read = True
            return

        ihandles = []
        for x in inputs:
            h = getattr(x, "_TraceMixin__handle", None)
            if h is None or (not self._capture_as_const and self._tinfo[h].exported):
                # untraced (or already exported) input: allocate a fresh
                # external handle for it
                h, info = self._new_handle()
                info.external = True
                info.device = x.device
                info.dtype = x.dtype
                if self._capture_as_const:
                    info.bound_data = x

            ihandles.append(h)

        ohandles = []
        for x in outputs:
            h, info = self._new_handle()
            ohandles.append(h)
            info.external = False
            TraceMixin._TraceMixin__inject(x, h)

        self._seq.append((op, tuple(ihandles), tuple(ohandles)))
        self._active_tensors.update(outputs)
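
    # A sketch of the recorded form (hypothetical ops and handles): tracing
    # `d = mul(add(a, b), c)` appends
    #     (add, (0, 1), (2,)) and (mul, (2, 3), (4,))
    # to _seq, with handles 0, 1 and 3 external and 2 and 4 internal.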

    @contextlib.contextmanager
    def _setup(self):
        global active_trace
        if active_trace:
            raise NotImplementedError("sorry, not implemented: nested trace")
        active_trace = self

        if self._untraced:
            apply.enable(apply_with_tracing)
            if self._symbolic:
                apply.enable(apply_symbolic_mode)
                self._lazy_eval_graph = G.Graph()
        else:
            apply.enable(apply_compiled_mode)
            if self._graph is None:
                self._compile()
            self._graph.execute()

        try:
            yield

            escaped_tensors = tuple(self._active_tensors)
            self._active_tensors.clear()

            if self._untraced:
                for x in escaped_tensors:
                    info = self._tinfo[x._TraceMixin__handle]
                    info.data_read = True
                    x._TraceMixin__restore()
                if self._symbolic:
                    # evaluate all tensors that are still lazy
                    lazy_eval_tensors = tuple(self._lazy_eval_tensors)
                    if lazy_eval_tensors:
                        readers = [
                            G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
                            for x in lazy_eval_tensors
                        ]
                        self._lazy_eval_graph.compile(*readers)
                        self._lazy_eval_graph()
                        for r, x in zip(readers, lazy_eval_tensors):
                            assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
                    self._lazy_eval_graph = None
                    self._lazy_eval_tensors = None
                self._untraced = False
            else:
                if self._pc != len(self._seq):
                    raise TraceMismatchError("premature end")
                for x in escaped_tensors:
                    assign_raw_tensor(x, as_raw_tensor(x._dev_tensor()))
                self._graph.wait()
                self._reset_exec_env()
                self._pc = 0
        finally:
            # always restore global dispatch state, even if the traced
            # function raised
            apply.disable(apply_with_tracing)
            apply.disable(apply_symbolic_mode)
            apply.disable(apply_compiled_mode)
            active_trace = None

    def _begin_excluded_region(self):
        if self._untraced:
            # conditionally reading a compiled tensor in an excluded region
            # is permitted, so we have to assume every tensor might be read
            for x in self._active_tensors:
                info = self._tinfo[x._TraceMixin__handle]
                info.exported = True
                info.data_read = True

    def _compile(self):
        graph = self._graph = G.Graph()
        # graph.options.graph_opt_level = 0
        need_reset_nodes = self._need_reset_nodes = []
        # links enforce ordering of I/O nodes
        links = ()
        # readers must accumulate across the whole sequence; resetting it per
        # op would compile only the last op's output nodes into the graph
        readers = []
        for op, ihandles, ohandles in self._seq:
            ivars = []
            for h in ihandles:
                info = self._tinfo[h]
                if not hasattr(info, "varnode"):
                    assert info.external
                    if info.bound_data is not None:
                        info.varnode = graph.make_const(info.bound_data._dev_tensor())
                    else:
                        opnode = info.data_setter = G.InputNode(
                            *links, device=info.device, dtype=info.dtype, graph=graph
                        )
                        need_reset_nodes.append(opnode)
                        info.varnode, *links = opnode.outputs

                ivars.append(info.varnode)
            ovars = apply(op, *ivars)
            assert len(ovars) == len(ohandles)
            for h, v in zip(ohandles, ovars):
                info = self._tinfo[h]
                info.varnode = v

                def add_reader(opnode):
                    nonlocal links
                    need_reset_nodes.append(opnode)
                    readers.append(opnode.outputs[0])
                    links = opnode.outputs

                if info.data_read:
                    # Shape can be obtained from data so doesn't need its own
                    # output node. On the other hand, value is read separately
                    # to leverage the eager device-to-host copy
                    info.shape_read = False
                    opnode = info.data_reader = G.OutputNode(v, *links)
                    add_reader(opnode)
                if info.value_read:
                    opnode = info.value_reader = G.ValueOutputNode(v, *links)
                    add_reader(opnode)
                if info.shape_read:
                    opnode = info.shape_reader = G.AttrOutputNode(v, *links)
                    add_reader(opnode)

        graph.compile(*readers)
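
    # A sketch of the wiring built above for one external input feeding one
    # data-read output (hypothetical):
    #
    #     InputNode --var--> op --var--> OutputNode
    #
    # with `links` chaining each I/O node to the next so that setters and
    # readers fire in the recorded order.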

    def _reset_exec_env(self):
        for opnode in self._need_reset_nodes:
            opnode.reset()

    def _require_shape(self, handle):
        info = self._tinfo[handle]
        info.shape_read = True

    def _require_value(self, handle):
        info = self._tinfo[handle]
        info.value_read = True

    def _require_data(self, handle):
        info = self._tinfo[handle]
        info.data_read = True

    def __call__(self, *args, **kwargs):
        with self._setup():
            return self.__wrapped__(*args, **kwargs)
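
# Execution model in brief: the first call runs under apply_with_tracing
# (plus apply_symbolic_mode when symbolic=True) and records into _seq; later
# calls run under apply_compiled_mode, which replays _seq against the graph
# built by _compile, with CompiledTensorProxy standing in for traced outputs.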


class CompiledTensorProxy(RawTensor):
    """
    Duck-typed RawTensor serving the outputs of a compiled graph
    """

    def __init__(self, handle):
        self.__handle = handle
        self.__info = active_trace._tinfo[handle]
        self.__shape = None
        self.__data = None
        self.__value = None

    @property
    def dtype(self):
        return self.__info.varnode.dtype

    @property
    def device(self):
        return self.__info.varnode.device

    @property
    def shape(self):
        if self.__shape is None:
            if self.__info.shape_read:
                self.__shape = self.__info.shape_reader.get_value().shape
            elif self.__info.data_read:
                self.__shape = self._dev_tensor().shape
            else:
                raise TraceMismatchError("shape of this tensor was not read during tracing")
        return self.__shape

    def numpy(self):
        if self.__value is None:
            if self.__info.value_read:
                self.__value = self.__info.value_reader.get_value()
            elif self.__info.data_read:
                self.__value = self._dev_tensor().numpy()
            else:
                raise TraceMismatchError("value of this tensor was not read during tracing")
        return self.__value

    def _dev_tensor(self):
        if self.__data is None:
            if not self.__info.data_read:
                raise TraceMismatchError("raw data of this tensor was not read during tracing")
            self.__data = self.__info.data_reader.get_value()
        return self.__data

    def __del__(self):
        if self.__info.shape_read and self.__shape is not None:
            self.__info.shape_reader.drop_value()
        if self.__info.value_read and self.__value is not None:
            self.__info.value_reader.drop_value()
        if self.__info.data_read and self.__data is not None:
            self.__info.data_reader.drop_value()
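
# Shape, value and data are resolved lazily from the reader nodes created in
# trace._compile; requesting something that was never marked as read during
# tracing raises TraceMismatchError instead of recomputing it.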


class LazyEvalTensor(RawTensor):
    """Duck-typed RawTensor backed by an unevaluated graph variable."""

    def __init__(self, varnode):
        self.__varnode = varnode

    @property
    def dtype(self):
        return self.__varnode.dtype

    @property
    def device(self):
        return self.__varnode.device

    @property
    def shape(self):
        return self.__varnode.shape

    def numpy(self):
        raise RuntimeError("cannot read value during symbolic tracing")

    def _dev_tensor(self):
        raise RuntimeError("cannot access data during symbolic tracing")


class TraceMixin:
    __subclass_cache = {}

    def __inject(self, handle):
        # dynamically rebase self onto a Traced* subclass so that shape/value/
        # data accesses are reported to the active trace
        cache = __class__.__subclass_cache
        cls = self.__class__
        subcls = cache.get(cls)
        if subcls is None:
            subcls = cache[cls] = type("Traced" + cls.__name__, (__class__, cls), {})
        self.__class__ = subcls
        self.__handle = handle
        self.__cls = cls
        return self

    def __restore(self):
        # undo __inject: restore the original class and drop trace bookkeeping
        cls = self.__cls
        del self.__handle
        del self.__cls
        self.__class__ = cls
        return self

    @property
    def shape(self):
        if not skip_tracing:
            active_trace._require_shape(self.__handle)
        return super().shape

    def numpy(self):
        if not skip_tracing:
            active_trace._require_value(self.__handle)
        return super().numpy()

    def _dev_tensor(self):
        if not skip_tracing:
            active_trace._require_data(self.__handle)
        return super()._dev_tensor()
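
# A sketch of the injection trick (hypothetical tensor `t`): __inject rewrites
# t.__class__ to a dynamically created subclass equivalent to TracedRawTensor
# below, so t.shape / t.numpy() / t._dev_tensor() report their reads to the
# active trace; __restore puts the original class back once tracing ends.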


class TracedRawTensor(TraceMixin, RawTensor):
    pass


class TracedLazyTensor(TraceMixin, LazyEvalTensor):
    pass


def assign_raw_tensor(lhs, rhs):
    handle = rhs._handle
    rhs.__dict__.clear()
    lhs.__dict__.clear()
    lhs.__class__ = RawTensor
    lhs.__init__(handle)
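
# Note: assign_raw_tensor mutates `lhs` in place, so every existing reference
# observes the new value; the class is reset to plain RawTensor and the
# instance re-initialized from `rhs`'s underlying handle.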


# this hook turns RawTensor into LazyEvalTensor
@apply.register()
def apply_symbolic_mode(op: OpDef, *args: RawTensor):
    graph = active_trace._lazy_eval_graph
    ivars = [
        getattr(x, "_LazyEvalTensor__varnode", None)
        or graph.make_const(x._dev_tensor())
        for x in args
    ]
    ovars = apply(op, *ivars)
    outputs = [LazyEvalTensor(v) for v in ovars]
    active_trace._lazy_eval_tensors.update(outputs)
    return outputs


apply.disable(apply_symbolic_mode)
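
# Hooks are registered up front but left disabled; trace._setup enables the
# combination appropriate for the current phase and disables them all again
# when the traced call finishes.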


@apply.register()
def apply_compiled_mode(op: OpDef, *args: RawTensor):
    if skip_tracing:
        args = [
            as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
            for x in args
        ]
        return apply.super(op, *args)
    return active_trace._apply_op(op, args)


apply.disable(apply_compiled_mode)


# this hook injects TraceMixin
@apply.register()
def apply_with_tracing(op: OpDef, *args: RawTensor):
    outputs = apply.super(op, *args)
    active_trace._record_op(op, args, outputs)
    return outputs


apply.disable(apply_with_tracing)


# @apply.register()
# def _(op: Const, *args: RawTensor):
#     return active_trace._apply_const(op, args)


class BrokenRawTensor(RawTensor):
    def __getattribute__(self, _):
        raise RuntimeError("broken due to misuse of tracing")

    def __setattr__(self, *_):
        raise RuntimeError("broken due to misuse of tracing")