- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import collections
- import contextlib
- import functools
- import itertools
- import json
- import os
- import typing
- import warnings
- import weakref
-
- import numpy as np
-
- from ..core._imperative_rt import GraphProfiler, common
- from ..core._imperative_rt.core2 import Tensor as RawTensor
- from ..core._imperative_rt.core2 import (
- TensorWeakRef,
- apply,
- set_compiled,
- set_tracing,
- skip_tracing,
- unset_compiled,
- unset_tracing,
- )
- from ..core._imperative_rt.ops import CollectiveComm, RemoteRecv, RemoteSend
- from ..core._trace_option import set_symbolic_shape
- from ..core._wrap import device as as_device
- from ..core.ops.builtin import BackwardGraph, OpDef
- from ..core.ops.special import Const
- from ..core.tensor import megbrain_graph as G
- from ..core.tensor.utils import setscalar
- from .sublinear_memory_config import SublinearMemoryConfig
-
-
- def _input_node_use_static_shape():
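- """Return whether InputNodes should use static shape when compiling a trace, controlled by the MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE environment variable."""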
- return os.environ.get("MEGENGINE_INPUT_NODE_USE_STATIC_SHAPE") is not None
-
-
- class TraceMismatchError(RuntimeError):
- pass
-
-
- active_trace = None
-
-
- def is_tracing():
- if active_trace is None:
- return False
- else:
- return not skip_tracing
-
-
- @contextlib.contextmanager
- def exclude_from_trace():
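- """
- Context manager that temporarily pauses tracing, so that statements inside
- the ``with`` block run eagerly and are not recorded into the active trace.
-
- A minimal usage sketch (the surrounding function and the printed
- intermediate value are illustrative only)::
-
-     @trace
-     def f(x):
-         x = x + 1
-         with exclude_from_trace():
-             # read an intermediate value without recording the read
-             print(x.numpy())
-         return x
- """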
- global skip_tracing
- if skip_tracing:
- yield
- return
- try:
- skip_tracing = True
- unset_tracing()
- if active_trace is not None:
- active_trace._begin_excluded_region()
- yield
- finally:
- skip_tracing = False
- set_tracing()
-
-
- class TensorInfo:
- __slots__ = (
- # collected attributes
- "external",
- "data_read",
- "shape_read",
- "value_read",
- "exported",
- "device",
- "dtype",
- "shape",
- "is_const",
- "bound_data",
- # resources for execution
- "varnode",
- "data_setter",
- "shape_reader",
- "value_reader",
- "data_reader",
- )
-
- def __init__(self):
- self.exported = None
- self.data_read = None
- self.shape_read = None
- self.value_read = None
- self.bound_data = None
-
- self.data_setter = None
- self.shape_reader = None
- self.value_reader = None
- self.data_reader = None
-
-
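- # communication ops have cross-device / cross-process side effects; they need
- # explicit dependency links so their relative execution order is preserved
- # when the trace is compiled (see ``require_links`` below)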
- _io_op_types = {CollectiveComm, RemoteSend, RemoteRecv}
-
-
- class trace:
- """
- Wraps a callable and provides:
-
- * tracing via :meth:`.trace` and :meth:`.dump`
- * accelerated evaluation via :meth:`.__call__`
-
- :param function: the function to be traced.
- :param symbolic: whether to apply symbolic execution for tracing. Default: False
- :param capture_as_const: capture global vars or closures as const value. Default: False
- :param sublinear_memory_config: configuration for sublinear memory optimization.
- If not None, it enables sublinear memory optimization with given setting.
- :param profiling: whether to profile compiled trace. Default: False
- :param opt_level: optimization level for compiling trace.
- :param symbolic_shape: whether to use symbolic shape for tracing. Default: True
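-
- A minimal usage sketch (the import paths and tensor creation below are
- illustrative)::
-
-     import numpy as np
-     import megengine as mge
-     from megengine.jit import trace
-
-     @trace(symbolic=True, capture_as_const=True)
-     def f(x):
-         return x * 2
-
-     x = mge.tensor(np.ones((2, 3), dtype="float32"))
-     y = f(x)          # the first call records the trace
-     y = f(x)          # subsequent calls replay the compiled graph
-     f.dump("f.mge")   # requires capture_as_const=True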
- """
-
- def __new__(cls, *args, **kwargs):
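- # support both ``@trace`` and ``@trace(symbolic=True, ...)``: when invoked
- # with keyword arguments only, return a partial that waits for the function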
- if not args:
- return functools.partial(cls, **kwargs)
- return super().__new__(cls)
-
- def __init__(
- self,
- function,
- symbolic=False,
- capture_as_const=False,
- sublinear_memory_config: SublinearMemoryConfig = None,
- profiling: bool = False,
- opt_level: int = None,
- symbolic_shape: bool = True,
- ):
- self.__wrapped__ = function
- self._symbolic = symbolic
- self._capture_as_const = capture_as_const
- self._sublinear_memory_config = sublinear_memory_config
- self._profiling = profiling
- self._profiler = None
- self._graph_opt_level = opt_level
- self._symbolic_shape = symbolic_shape
- self._output_handles = set()
-
- self._reset()
-
- def _reset(self):
- self._untraced = True
- self._tinfo = [] # handle -> TensorInfo
- self._seq = []
- self._pc = 0
- self._graph = None
- self._need_reset_nodes = None
- self._lazy_eval_graph = None
- self._lazy_eval_tensors = {}
- self._lazy_eval_links = None
- self._active_tensors = {}
- self._tensor_remaps = None
- self._inputs_to_restore = None
- self._arg_bindings = None
- self._kwarg_bindings = None
- self._output_bindings = None
- self._output_names = None
-
- def _new_handle(self):
- handle = len(self._tinfo)
- info = TensorInfo()
- self._tinfo.append(info)
- return handle, info
-
- def _apply_op(self, op, args):
- assert not self._untraced
- # check against trace
- if self._pc >= len(self._seq):
- raise TraceMismatchError("trace should end here, but more op observed")
- record = self._seq[self._pc]
- op_, ihandles, ohandles = record
- if (isinstance(op_, str) and op_ == "Const") or (op != op_):
- raise TraceMismatchError("op different from last time")
- if len(ihandles) != len(args):
- raise TraceMismatchError("op input size different from last time")
-
- # check all inputs of current op
- for h, x in zip(ihandles, args):
- info = self._tinfo[h]
- if info.external:
- if (
- x._compiled_info is not None
- and not self._tinfo[x._mixin_handle].exported
- ):
- raise TraceMismatchError(
- "failed to capture: input was an external tensor "
- "last time, got an internal tensor this time"
- )
- if info.bound_data:
- if x._compiled_info is not None:
- raise TraceMismatchError(
- "const capture violated: was an external tensor "
- "last time, got an internal tensor this time"
- )
- if x._handle != info.bound_data._handle:
- if not np.array_equal(x.numpy(), info.bound_data.numpy()):
- raise TraceMismatchError(
- "const capture violated: got "
- "a different tensor this time"
- )
- else:
- if info.dtype != x.dtype:
- raise TraceMismatchError(
- "failed to capture: different dtype from last time"
- )
- if info.device != x.device:
- raise TraceMismatchError(
- "failed to capture: different device from last time"
- )
- info.data_setter.set_value(x._dev_tensor())
- else:
- if x._mixin_handle == -1:
- if x._handle not in self._tensor_remaps:
- raise TraceMismatchError(
- "unexpected capture: trying to use an external tensor as "
- "input, but that input was an internal tensor last time"
- )
- else:
- x._mixin_handle = self._tensor_remaps[
- x._handle
- ]._CompiledTensorProxy__handle
- if x._mixin_handle != h:
- raise TraceMismatchError(
- "mis-wiring: input edge to an data flow "
- "graph node is different from last time"
- )
-
- self._pc += 1
- outputs = []
- for h in ohandles:
- info = self._tinfo[h]
- # generate output tensor and create compiled info
- y = RawTensor(info.varnode)
- y._compiled_info = CompiledTensorProxy(h)
- y._mixin_handle = h
- outputs += [y]
- self._active_tensors[h] = TensorWeakRef(y)
- self._output_handles.update(ohandles)
- return outputs
-
- def _apply_const(self, value, dtype, device):
- assert not self._untraced
- # check against trace
- if self._pc >= len(self._seq):
- raise TraceMismatchError("trace should end here, but more op observed")
- record = self._seq[self._pc]
- op_, ihandles, ohandles = record
- # Const op is represented by a str
- assert isinstance(op_, str) and op_ == "Const"
-
- eq = np.all(np.atleast_1d(value) == self._tinfo[ohandles[0]].bound_data.numpy())
- if not eq:
- raise TraceMismatchError(
- "const tensor violated: got a different tensor this time"
- )
-
- self._pc += 1
- (h,) = ohandles
- outputs = [self._tinfo[h].bound_data]
- return outputs
-
- # run during the first (recording) step; collect information for the trace
- def _record_op(self, op, inputs, outputs):
- if skip_tracing:
- for x in inputs:
- h = getattr(x, "_mixin_handle", -1)
- if h >= 0:
- self._tinfo[h].data_read = True
- return
-
- ihandles = []
- for x in inputs:
- h = getattr(x, "_mixin_handle", -1)
- if h < 0 or (not self._capture_as_const and self._tinfo[h].exported):
- h, info = self._new_handle()
- info.external = True
- info.device = x.device
- info.dtype = x.dtype
- info.shape = x.shape
- if self._capture_as_const:
- info.bound_data = RawTensor(x.numpy(), x.dtype, x.device, False)
-
- ihandles.append(h)
-
- ohandles = []
- for x in outputs:
- h, info = self._new_handle()
- ohandles.append(h)
- info.external = False
- x._mixin_handle = h
- x._recording = True
- x._trace_mixin_info = info
- self._active_tensors[h] = TensorWeakRef(x)
- if self._symbolic:
- self._lazy_eval_tensors[h] = TensorWeakRef(x)
-
- self._seq.append((op, tuple(ihandles), tuple(ohandles)))
-
- def _record_const(self, outputs):
- if skip_tracing:
- (x,) = outputs
- h = getattr(x, "_mixin_handle", -1)
- if h >= 0:
- self._tinfo[h].data_read = True
- return
-
- (x,) = outputs
- h, info = self._new_handle()
- ohandles = [h]
- info.external = True
- info.device = x.device
- info.dtype = x.dtype
- info.shape = x.shape
- info.bound_data = x
- info.is_const = True
- x._mixin_handle = h
- x._recording = True
- x._trace_mixin_info = info
- if self._symbolic:
- self._lazy_eval_tensors[h] = TensorWeakRef(x)
- self._seq.append(("Const", tuple(), tuple(ohandles)))
-
- def _set_active(self, active: bool):
- global active_trace
- if active:
- if active_trace:
- raise NotImplementedError("sorry, not implemented: nested trace")
- active_trace = self
- else:
- assert active_trace is self
- active_trace = None
-
- def _init_trace(self, symbolic: bool):
- if symbolic:
- self._lazy_eval_graph = G.Graph()
- self._apply_graph_options(self._lazy_eval_graph)
- self._lazy_eval_links = ()
-
- def _take_escaped_tensors(self):
- escaped_tensors = tuple(
- filter(lambda x: x() is not None, self._active_tensors.values())
- )
- self._active_tensors.clear()
- return escaped_tensors
-
- def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links):
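- """Compile and run the lazily-built graph (first symbolic run), then copy the results back into the tensors that escaped the traced function."""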
- lazy_eval_tensors = list(
- filter(lambda x: x() is not None, lazy_eval_tensors.values())
- )
- readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors]
- self._apply_graph_options(lazy_eval_graph)
- # FIXME
- if self._graph_opt_level is not None:
- lazy_eval_graph.options.graph_opt_level = self._graph_opt_level
- else:
- lazy_eval_graph.options.graph_opt_level = 2
- lazy_eval_graph._set_priority_to_id([*lazy_eval_links, *readers])
- lazy_eval_graph.compile(*lazy_eval_links, *readers)
- lazy_eval_graph()
- for r, x in zip(readers, lazy_eval_tensors):
- # get values from lazy_eval_graph and assign to lazy_eval tensor
- x()._handle = RawTensor(r.op.get_value())._handle
- x()._reset_varnode()
-
- @contextlib.contextmanager
- def _setup(self):
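- """
- Wrap one execution of the traced function: ``do_enter`` switches into
- tracing (or compiled replay) mode, ``do_exit`` checks that the recorded
- op sequence was fully consumed, and ``do_finalize`` restores global state
- and, on the first symbolic run, evaluates the lazy graph.
- """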
- interrupted = False
-
- def do_enter():
- set_tracing()
- self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape)
- self._set_active(True)
- if self._untraced:
- self._init_trace(self._symbolic)
- else:
- set_compiled()
- if self._graph is None:
- self._compile()
- self._graph.execute()
-
- def do_finalize():
- escaped_tensors = self._take_escaped_tensors()
- if self._untraced:
- for x in escaped_tensors:
- if x():
- info = self._tinfo[x()._mixin_handle]
- info.data_read = True
- x()._mixin_handle = -1
- x()._recording = False
- if self._inputs_to_restore:
- for x in self._inputs_to_restore:
- x._mixin_handle = -1
- x._recording = False
- if self._symbolic and (
- self._lazy_eval_tensors or self._lazy_eval_links
- ):
- # eval lazy eval tensors
- self._lazy_eval(
- self._lazy_eval_graph,
- self._lazy_eval_tensors,
- self._lazy_eval_links,
- )
- self._lazy_eval_graph = None
- self._lazy_eval_tensors = None
- self._lazy_eval_links = None
- self._untraced = False
- else:
- # avoid leaking compiled tensors: materialize escaped tensors as raw tensors
- if self._pc == len(self._seq):
- for x in escaped_tensors:
- try:
- assign_raw_tensor(x(), RawTensor(x()._dev_tensor()))
- except RuntimeError:
- # TraceMismatchError thrown in do_exit
- pass
- self._graph.wait()
- self._reset_exec_env()
-
- # reset status
- self._pc = 0
- self._tensor_remaps = None
- self._set_active(False)
- set_symbolic_shape(self._save_symbolic_shape)
- unset_compiled()
- unset_tracing()
-
- def do_exit():
- unset_tracing()
- if not self._untraced and self._pc != len(self._seq):
- raise TraceMismatchError("premature end")
- if not self._symbolic or not self._untraced:
- # reset output tensors
- for x in self._active_tensors.values():
- if x() is not None:
- x()._dev_tensor()
- x()._reset_varnode()
- x()._mixin_handle = -1
- x()._recording = False
- x()._trace_mixin_info = None
-
- try:
- do_enter()
- yield
- do_exit()
- except:
- interrupted = True
- raise
- finally:
- do_finalize()
- if interrupted:
- self._reset()
-
- def _begin_excluded_region(self):
- if self._capture_as_const:
- raise RuntimeError(
- "exclude_from_trace cannot be used with capture_as_const"
- )
- if self._untraced:
- # conditionally reading a compiled tensor in excluded region
- # is permitted, so we have to assume every tensor might be read
- for x in self._active_tensors.values():
- if x():
- info = self._tinfo[x()._mixin_handle]
- info.exported = True
- info.data_read = True
- else:
- for x in self._active_tensors.values():
- if x():
- x()._dev_tensor()
-
- def _apply_graph_options(self, graph):
-
- graph.options.no_force_inplace = True
- graph.options.seq_opt.enable_seq_comp_node_opt = False
- # graph opt level
- # if self._graph_opt_level is not None:
- # graph.options.graph_opt_level = self._graph_opt_level
- # FIXME
- graph.options.graph_opt_level = 0
- # sublinear
- if self._sublinear_memory_config is not None:
- graph.options.enable_sublinear_memory_opt = True
- sublinear_config = graph.options.sublinear_mem_config
- sublinear_config.lb_memory = self._sublinear_memory_config.lb_memory
- sublinear_config.genetic_nr_iter = (
- self._sublinear_memory_config.genetic_nr_iter
- )
- sublinear_config.genetic_pool_size = (
- self._sublinear_memory_config.genetic_pool_size
- )
- sublinear_config.thresh_nr_try = self._sublinear_memory_config.thresh_nr_try
- sublinear_config.num_worker = self._sublinear_memory_config.num_worker
- # profile
- if self._profiling:
- self._profiler = GraphProfiler(graph)
- if int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0")):
- graph.options.var_sanity_check_first_run = False
-
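- # build a megbrain graph from the recorded op sequence: captured/external
- # inputs become InputNode setters or constants, and every tensor whose data,
- # value or shape was read gets a matching OutputNode reader for replay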
- def _compile(self):
- graph = self._graph = G.Graph()
- graph.options.async_exec_level = 0b100
- self._apply_graph_options(graph)
- # graph.options.graph_opt_level = 0
- need_reset_nodes = self._need_reset_nodes = []
- # links enforce ordering of I/O nodes
- in_out_links = ()
- io_links = ()
- readers = []
-
- if self._capture_as_const:
- for h in itertools.chain(self._arg_bindings, self._kwarg_bindings.values()):
- info = self._tinfo[h]
- opnode = info.data_setter = G.InputNode(
- device=info.device,
- dtype=info.dtype,
- shape=info.shape or (1,),
- graph=graph,
- use_static_shape=_input_node_use_static_shape(),
- )
- need_reset_nodes.append(opnode)
- info.varnode = opnode.outputs[0]
- in_out_links += opnode.outputs[1:]
-
- for op, ihandles, ohandles in self._seq:
- if isinstance(op, str) and op == "Const":
- assert len(ihandles) == 0
- (h,) = ohandles
- info = self._tinfo[h]
- if not hasattr(info, "varnode"):
- assert info.external
- assert info.bound_data
- info.varnode = graph.make_const(
- info.bound_data.numpy(),
- info.bound_data.dtype,
- info.bound_data.device,
- )
- continue
-
- require_links = type(op) in _io_op_types
- ivars = []
- for i, h in enumerate(ihandles):
- info = self._tinfo[h]
- if not hasattr(info, "varnode"):
- assert info.external
- if info.bound_data:
- if hasattr(info, "is_const") and info.is_const:
- info.varnode = graph.make_const(
- info.bound_data.numpy(),
- info.bound_data.dtype,
- info.bound_data.device,
- )
- else:
- info.varnode = graph.make_const(
- info.bound_data._dev_tensor()
- # info.bound_data.numpy()
- )
- else:
- opnode = info.data_setter = G.InputNode(
- *in_out_links,
- device=info.device,
- dtype=info.dtype,
- shape=info.shape or (1,),
- graph=graph,
- use_static_shape=_input_node_use_static_shape(),
- )
- need_reset_nodes.append(opnode)
- info.varnode, *in_out_links = opnode.outputs
- if require_links and i == 0 and len(io_links) > 0:
- opnode = G.VirtualDepNode(
- [info.varnode, *io_links], str(io_links[0].device)
- )
- info.varnode = opnode.outputs[0]
- io_links = (info.varnode,)
-
- ivars.append(info.varnode)
-
- if isinstance(op, BackwardGraph):
- ovars = G.apply_backward_varnode(op, *ivars)
- else:
- ovars = G.apply_normal_varnode(op, *ivars)
-
- if require_links and len(ovars) > 0:
- io_links = (ovars[0],)
- assert len(ovars) == len(ohandles)
- for h, v in zip(ohandles, ovars):
- info = self._tinfo[h]
- info.varnode = v
-
- def add_reader(opnode):
- nonlocal in_out_links
- need_reset_nodes.append(opnode)
- readers.append(opnode.outputs[0])
- in_out_links = opnode.outputs
-
- if info.data_read:
- # Shape can be obtained from data so doesn't need its own
- # output node. On the other hand, value is read separately
- # to leverage eager h2d copy
- info.shape_read = False
- opnode = info.data_reader = G.OutputNode(v, *in_out_links)
- add_reader(opnode)
- if info.value_read:
- opnode = info.value_reader = G.ValueOutputNode(v, *in_out_links)
- add_reader(opnode)
- if info.shape_read:
- opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links)
- add_reader(opnode)
-
- # FIXME
- if self._graph_opt_level is not None:
- graph.options.graph_opt_level = self._graph_opt_level
- else:
- graph.options.graph_opt_level = 2
- graph._set_priority_to_id([*readers, *in_out_links, *io_links])
- graph.compile(*readers, *in_out_links, *io_links)
-
- def _reset_exec_env(self):
- for opnode in self._need_reset_nodes:
- opnode.reset()
-
- def __call__(self, *args, **kwargs):
- if is_tracing():
- return self.__wrapped__(*args, **kwargs)
- with self._setup():
- if self._capture_as_const:
- self._process_inputs(*args, **kwargs)
- outputs = self.__wrapped__(*args, **kwargs)
- transform = False
- # outputs can be None
- if outputs is not None:
- if not isinstance(outputs, collections.abc.Sequence):
- transform = True
- outputs = (outputs,)
- for o in outputs:
- # if outputs are copied, then use the newest info in trace data structure
- if o._copied:
- self._active_tensors[o._mixin_handle] = TensorWeakRef(o)
- if self._untraced and self._symbolic:
- self._lazy_eval_tensors[o._mixin_handle] = TensorWeakRef(o)
- if self._capture_as_const:
- self._process_outputs(outputs)
- if transform:
- outputs = outputs[0]
- return outputs
-
- def dump(
- self,
- file,
- *,
- arg_names=None,
- output_names=None,
- append=False,
- optimize_for_inference=True,
- **kwargs
- ):
- r"""
- Serializes trace to file system.
-
- :param file: output file, could be file object or filename.
- :param arg_names: names of the input tensors in the traced function.
- :param output_names: names of the output tensors in the traced function,
- use the default name if not specified.
- :param append: whether output is appended to ``file``.
- Only works when ``file`` is str.
- :param optimize_for_inference: enable optimizations;
- all optimization options below are skipped if this is False. Default: True
-
- :Keyword Arguments:
-
- * enable_io16xc32 --
- whether to use float16 for I/O between oprs and use
- float32 as internal computation precision. Note the output var would be
- changed to float16.
- * enable_ioc16 --
- whether to use float16 for both I/O and computation
- precision.
-
- * enable_hwcd4 --
- whether to use NHWCD4 data layout. This is faster on some
- OpenCL backends.
- * enable_nchw88 --
- whether to use NCHW88 data layout, currently
- used in the x86 AVX backend.
- * enable_nchw44 --
- whether to use NCHW44 data layout, currently
- used in the ARM backend.
- * enable_nchw44_dot --
- whether to use NCHW44_dot data layout, currently
- used in the ARMv8.2+dotprod backend.
- * enable_nchw4 --
- whether to use NCHW4 data layout, currently
- used in the NVIDIA backend (based on cuDNN).
- * enable_nchw32 --
- whether to use NCHW32 data layout, currently
- used in the NVIDIA backend with TensorCore (based on cuDNN).
- * enable_chwn4 --
- whether to use CHWN4 data layout, currently
- used in the NVIDIA backend with TensorCore.
-
- * enable_fuse_conv_bias_nonlinearity: whether to fuse conv+bias+nonlinearity
- into one opr.
- * enable_fuse_conv_bias_with_z: whether to fuse conv_bias with the z
- input for inference on NVIDIA backends (this optimization pass will
- make the output precision of training and inference mismatch)
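-
- A minimal sketch (continuing the example in the class docstring; the
- file name and tensor names are illustrative)::
-
-     f(x)   # run the traced function at least once before dumping
-     f.dump("model.mge", arg_names=["data"], output_names=["pred"])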
- """
- if not self._capture_as_const:
- raise ValueError(
- "you must specify capture_as_const=True at __init__ to use dump"
- )
- if self._untraced:
- raise RuntimeError("should run at least once before calling dump")
- if self._output_names and output_names:
- raise TypeError(
- "cannot specify output_names when output is already in dict format"
- )
- if output_names and not isinstance(output_names, collections.abc.Sequence):
- output_names = (output_names,)
- if output_names and len(output_names) != len(self._output_bindings):
- raise ValueError(
- "wrong number of output_names, should be {} values".format(
- len(self._output_bindings)
- )
- )
- if arg_names is None:
- arg_names = ["arg_%d" % i for i in range(len(self._arg_bindings))]
- if arg_names and not isinstance(arg_names, collections.abc.Sequence):
- arg_names = (arg_names,)
- if arg_names and len(arg_names) != len(self._arg_bindings):
- raise ValueError(
- "wrong number of arg_names, should be {} values".format(
- len(self._arg_bindings)
- )
- )
- output_names = output_names or self._output_names
-
- dumped_device = as_device("xpux")
-
- h2v = {}
- graph = G.Graph()
- # only graph_opt_level takes effect in dump
- self._apply_graph_options(graph)
-
- for i, h in enumerate(self._arg_bindings):
- info = self._tinfo[h]
- h2v[h] = graph.make_h2d(
- dtype=info.dtype,
- device=dumped_device,
- shape=info.shape or (1,),
- name=arg_names[i] if arg_names else None,
- )
- for k, h in self._kwarg_bindings.items():
- info = self._tinfo[h]
- h2v[h] = graph.make_h2d(
- dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k
- )
-
- for op, ihandles, ohandles in self._seq:
- if isinstance(op, str) and op == "Const":
- assert len(ihandles) == 0
- (h,) = ohandles
- info = self._tinfo[h]
- if h not in h2v:
- assert info.external
- assert info.bound_data
- h2v[h] = graph.make_const(
- info.bound_data.numpy(), dtype=info.dtype, device=info.device,
- )
- continue
- ivars = []
- for h in ihandles:
- info = self._tinfo[h]
- if h not in h2v:
- assert info.external
- assert info.bound_data
- h2v[h] = graph.make_const(
- info.bound_data.numpy(), dtype=info.dtype, device=dumped_device
- )
- ivars.append(h2v[h])
- ovars = G.apply_normal_varnode(op, *ivars)
- assert len(ovars) == len(ohandles)
- h2v.update(zip(ohandles, ovars))
-
- dest_vars = []
- for i, h in enumerate(self._output_bindings):
- v = h2v[h]
- if output_names:
- v.name = output_names[i]
- dest_vars.append(v)
-
- if optimize_for_inference:
- dest_vars = G.optimize_for_inference(dest_vars, **kwargs)
-
- if isinstance(file, str):
- permission = "wb" if append == False else "ab"
- file = open(file, permission)
- dump_content, dump_info = G.dump_graph(dest_vars)
- file.write(dump_content)
- return dump_info
-
- def _process_inputs(self, *args, **kwargs):
- if self._untraced:
- self._inputs_to_restore = []
-
- def record_input(x):
- if x is None:
- return
- h, info = self._new_handle()
- info.external = False
- info.device = x.device
- info.dtype = x.dtype
- info.shape = x.numpy().shape
- x._mixin_handle = h
- x._recording = True
- x._trace_mixin_info = info
- self._inputs_to_restore.append(x)
- return h
-
- self._arg_bindings = []
- for i, x in enumerate(args):
- if not isinstance(x, RawTensor):
- raise TypeError(
- "positional arguments should all be tensor "
- "but args[%d] cannot be recognized as one" % i
- )
- self._arg_bindings.append(record_input(x))
-
- self._kwarg_bindings = {}
- for k, x in kwargs.items():
- if isinstance(x, RawTensor):
- self._kwarg_bindings[k] = record_input(x)
- else:
- if len(args) != len(self._arg_bindings):
- raise TraceMismatchError("positional argument length mismatch")
-
- self._tensor_remaps = {}
-
- for i, (h, x) in enumerate(zip(self._arg_bindings, args)):
- if not isinstance(x, RawTensor):
- raise TypeError(
- "positional arguments should all be tensor "
- "but args[%d] cannot be recognized as one" % i
- )
- info = self._tinfo[h]
- if x.dtype != info.dtype:
- raise TypeError("args[%d].dtype different from last time" % i)
- if x.device != info.device:
- raise TypeError("args[%d].device different from last time" % i)
- info.data_setter.set_value(x._dev_tensor())
- self._tensor_remaps[x._handle] = CompiledTensorProxy(h)
-
- kwargs_tensors = {}
- for k, x in kwargs.items():
- if isinstance(x, RawTensor):
- kwargs_tensors[k] = x
- if set(kwargs_tensors) != set(self._kwarg_bindings):
- too_many = set(kwargs_tensors) - set(self._kwarg_bindings)
- too_few = set(self._kwarg_bindings) - set(kwargs_tensors)
- if too_many:
- raise TraceMismatchError(
- "keyword arguments found to be tensor this time "
- "but were non-tensor previously: %s" % " ".join(too_many)
- )
- if too_few:
- raise TraceMismatchError(
- "keyword arguments found to be non-tensor this time "
- "but were tensor previously: %s" % " ".join(too_few)
- )
- for k, h in self._kwarg_bindings.items():
- x = kwargs_tensors[k]
- info = self._tinfo[h]
- if x.dtype != info.dtype:
- raise TypeError("kwargs[%s].dtype different from last time" % k)
- if x.device != info.device:
- raise TypeError("kwargs[%s].device different from last time" % k)
- info.data_setter.set_value(x._dev_tensor())
- self._tensor_remaps[x._handle] = CompiledTensorProxy(h)
-
- def _process_outputs(self, outputs):
- output_names = None
- if isinstance(outputs, collections.abc.Mapping):
- output_names, outputs = zip(*sorted(outputs.items()))
- elif not isinstance(outputs, collections.abc.Sequence):
- outputs = (outputs,)
-
- if not self._untraced:
- if output_names != self._output_names:
- too_many = set(output_names) - set(self._output_names)
- too_few = set(self._output_names) - set(output_names)
- if too_many:
- raise TraceMismatchError(
- "output has more keys than last time: %s" % " ".join(too_many)
- )
- if too_few:
- raise TraceMismatchError(
- "output has less keys than last time: %s" % " ".join(too_few)
- )
- if len(outputs) != len(self._output_bindings):
- raise TraceMismatchError("output size differs from last time")
- else:
- self._output_names = output_names
- self._output_bindings = []
-
- for i, x in enumerate(outputs):
- if not isinstance(x, RawTensor):
- raise TypeError("every item of return value should be tensor")
- if self._untraced:
- h = x._mixin_handle
- if h < 0:
- raise RuntimeError("output is not computed from inputs")
- self._output_bindings.append(h)
- else:
- h = x._mixin_handle
- if h not in self._output_handles:
- raise RuntimeError("output is not computed from inputs")
- if h != self._output_bindings[i]:
- raise TraceMismatchError(
- "retval[%s] is a different tensor than last time"
- % (output_names and output_names[i] or i)
- )
-
- def get_profile(self):
- """
- Get profiling result for compiled trace.
-
- :return: a json compatible object.
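-
- A sketch of typical use (``my_func`` and ``x`` are illustrative; the trace
- must be constructed with ``profiling=True`` and called at least once)::
-
-     import json
-
-     f = trace(my_func, profiling=True)
-     f(x)
-     with open("profile.json", "w") as fout:
-         json.dump(f.get_profile(), fout)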
- """
- if not self._profiler:
- raise RuntimeError("trace is not set with profiling=True")
- return json.loads(self._profiler.get())
-
- def __del__(self):
- for x in self._tinfo:
- if getattr(x, "bound_data", None):
- x.bound_data = None
-
- def trace(self, *args, **kwargs):
- raise NotImplementedError(
- "trace is deemed unbeneficial with the new "
- "tracing mechanism. You should alwasy use __call__."
- )
-
-
- class CompiledTensorProxy:
- """
- Duck-typed RawTensor
- """
-
- def __init__(self, handle):
- self.__handle = handle
- self._isscalar = False
- self.__info = active_trace._tinfo[handle]
- self.__shape = None
- self.__data = None
- self.__value = None
-
- @property
- def dtype(self):
- return self.__info.varnode.dtype
-
- @property
- def device(self):
- return self.__info.varnode.device
-
- @property
- def shape(self):
- if self._isscalar:
- return ()
- if self.__shape is None:
- if self.__info.shape_read:
- self.__shape = self.__info.shape_reader.get_value().shape
- elif self.__info.data_read:
- self.__shape = self._dev_tensor().shape
- else:
- # c++ will throw TraceReadError
- return None
- return self.__shape
-
- def numpy(self):
- if self.__value is None:
- if self.__info.value_read:
- self.__value = self.__info.value_reader.get_value()
- elif self.__info.data_read:
- self.__value = self._dev_tensor().numpy()
- else:
- # c++ will throw TraceReadError
- return None
- # c++ side will handle scalar case
- return self.__value
-
- def _dev_tensor(self):
- if self.__data is None:
- if not self.__info.data_read:
- # c++ will throw TraceReadError
- return None
- self.__data = self.__info.data_reader.get_value()
- return self.__data
-
- def __del__(self):
- if self.__info.shape_read and self.__shape is not None:
- self.__info.shape_reader.drop_value()
- if self.__info.value_read and self.__value is not None:
- self.__info.value_reader.drop_value()
- if self.__info.data_read and self.__data is not None:
- self.__info.data_reader.drop_value()
-
-
- def assign_raw_tensor(lhs, rhs):
- lhs.__init__(rhs)
-
-
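- # The helpers below are invoked while a trace is active:
- # * apply_with_tracing / apply_const_with_tracing handle the first
- #   (recording) run; with symbolic=True they delegate to the *_symbolic_mode
- #   variants, which build varnodes in the lazy-eval graph instead of
- #   executing eagerly;
- # * apply_compiled_mode / apply_const_compiled_mode handle replay of an
- #   already-compiled trace and check each op against the recorded sequence.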
- def apply_symbolic_mode(op: OpDef, *args: RawTensor):
- graph = active_trace._lazy_eval_graph
- ivars = []
- for x in args:
- var = getattr(x, "_varnode", None)
- if var:
- ivars.append(var)
- else:
- data_setter = G.InputNode(
- device=x.device,
- dtype=x.dtype,
- shape=x.numpy().shape or (1,),
- graph=graph,
- use_static_shape=True,
- )
- var = data_setter.outputs[0]
- ivars.append(var)
- data_setter.set_value(x._dev_tensor())
-
- require_links = type(op) in _io_op_types
-
- if require_links and active_trace._lazy_eval_links:
- assert len(ivars) > 0, "op should have at least one input"
- opnode = G.VirtualDepNode(
- [ivars[0], *active_trace._lazy_eval_links],
- str(active_trace._lazy_eval_links[0].device),
- )
- ivars[0] = opnode.outputs[0]
- active_trace._lazy_eval_links = (ivars[0],)
-
- if isinstance(op, BackwardGraph):
- ovars = G.apply_backward_varnode(op, *ivars)
- else:
- ovars = G.apply_normal_varnode(op, *ivars)
- outputs = [RawTensor(o) for o in ovars]
-
- if require_links:
- active_trace._lazy_eval_links = (G.VarNode(outputs[0]._varnode),)
-
- return outputs
-
-
- def apply_const_symbolic_mode(value, dtype, device):
- graph = active_trace._lazy_eval_graph
- # no need to unset tracing here because varnode construction ignores the tracing flag
- ret = RawTensor(graph.make_const(value, dtype=dtype, device=device))
- if np.array(value).ndim == 0:
- setscalar(ret)
- return (ret,)
-
-
- def apply_compiled_mode(op: OpDef, *args: RawTensor):
- if skip_tracing:
- args = [
- RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
- for x in args
- ]
- unset_tracing()
- ret = apply(op, *args)
- set_tracing()
- return ret
- return active_trace._apply_op(op, args)
-
-
- def apply_const_compiled_mode(value, dtype, device, is_const, no_cache):
- if skip_tracing:
- # value, dtype and device are plain host values, so there are no
- # CompiledTensorProxy arguments to unwrap here
- unset_tracing()
- ret = RawTensor(value, dtype, device, False)
- set_tracing()
- return ret
- return active_trace._apply_const(value, dtype, device)
-
-
- def apply_with_tracing(op: OpDef, *args: RawTensor):
- if active_trace._symbolic:
- outputs = apply_symbolic_mode(op, *args)
- else:
- unset_tracing()
- outputs = apply(op, *args)
- set_tracing()
-
- active_trace._record_op(op, args, outputs)
- return list(outputs)
-
-
- def apply_const_with_tracing(value, dtype, device, is_const, no_cache):
- if active_trace._symbolic:
- outputs = apply_const_symbolic_mode(value, dtype, device)
- else:
- unset_tracing()
- outputs = (RawTensor(value, dtype, device, False),)
- set_tracing()
- active_trace._record_const(outputs)
- return list(outputs)