GitOrigin-RevId: 4edc38eaf2
release-1.2
@@ -20,4 +20,4 @@ class Const: | |||
def __call__(self, *reference): | |||
Wrapper = type(reference[0]) | |||
return (Wrapper(self.value, self.dtype, self.device),) | |||
return (Wrapper(self.value, self.dtype, self.device, True),) |
@@ -19,10 +19,11 @@ import numpy as np | |||
from ...utils.comp_graph_tools import set_priority_to_id as _set_priority_to_id | |||
from .. import _imperative_rt | |||
from .._imperative_rt import GraphOptimizeOptions | |||
from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode | |||
from .._imperative_rt.ops import BackwardGraph | |||
from .._wrap import device as as_device | |||
from ..ops.builtin import OpDef | |||
from .core import OpBase, TensorBase, apply | |||
from .core import OpBase, TensorBase | |||
class Graph(_imperative_rt.ComputingGraph): | |||
@@ -269,9 +270,8 @@ def optimize_for_inference(dest_vars, **kwargs): | |||
if kwargs: | |||
raise ValueError("unknown options: %s" % list(kwargs)) | |||
res_vars = _imperative_rt.optimize_for_inference( | |||
[i._node for i in dest_vars], inference_options | |||
) | |||
dest_vars = [var._node for var in dest_vars] | |||
res_vars = _imperative_rt.optimize_for_inference(dest_vars, inference_options) | |||
return [VarNode(i) for i in res_vars] | |||
@@ -437,19 +437,25 @@ def _unwrap(x): | |||
return x | |||
@apply.register() | |||
def _(op: OpDef, *args: VarNode): | |||
def apply_normal_op(op: OpDef, *args: VarNode): | |||
outputs = _imperative_rt.invoke_op(op, _unwrap(args)) | |||
return _wrap(outputs) | |||
@apply.register() | |||
def _(op: BackwardGraph, *args: VarNode): | |||
def apply_backward_varnode(op: BackwardGraph, *args: VarNode): | |||
assert args | |||
graph = args[0].graph | |||
return BackwardGraph.interpret( | |||
op, lambda op, args: apply(op, *args), graph._make_const_for_backward, args | |||
outputs = op.interpret( | |||
op, | |||
lambda op, args: apply_normal_op(op, *args), | |||
graph._make_const_for_backward, | |||
args, | |||
) | |||
outputs = [o._node if hasattr(o, "_node") else o for o in outputs] | |||
return outputs | |||
set_cpp_apply_backward_varnode(apply_backward_varnode) | |||
def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None): | |||
@@ -6,5 +6,23 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
from ..core._imperative_rt.core2 import ( | |||
set_cpp_apply_compiled_mode, | |||
set_cpp_apply_const_compiled_mode, | |||
set_cpp_apply_const_with_tracing, | |||
set_cpp_apply_with_tracing, | |||
) | |||
from .sublinear_memory_config import SublinearMemoryConfig | |||
from .tracing import exclude_from_trace, trace | |||
from .tracing import ( | |||
apply_compiled_mode, | |||
apply_const_compiled_mode, | |||
apply_const_with_tracing, | |||
apply_with_tracing, | |||
exclude_from_trace, | |||
trace, | |||
) | |||
set_cpp_apply_with_tracing(apply_with_tracing) | |||
set_cpp_apply_const_with_tracing(apply_const_with_tracing) | |||
set_cpp_apply_compiled_mode(apply_compiled_mode) | |||
set_cpp_apply_const_compiled_mode(apply_const_compiled_mode) |
@@ -18,8 +18,20 @@ import weakref | |||
import numpy as np | |||
from ..core._imperative_rt import GraphProfiler | |||
from ..core._imperative_rt.core2 import Tensor | |||
from ..core._imperative_rt import GraphProfiler, common, put | |||
from ..core._imperative_rt.core2 import Tensor as RawTensor | |||
from ..core._imperative_rt.core2 import ( | |||
TensorWeakRef, | |||
apply, | |||
call_level, | |||
set_compiled, | |||
set_symbolic, | |||
set_tracing, | |||
skip_tracing, | |||
unset_compiled, | |||
unset_symbolic, | |||
unset_tracing, | |||
) | |||
from ..core._imperative_rt.ops import ( | |||
CollectiveComm, | |||
GaussianRNG, | |||
@@ -29,10 +41,9 @@ from ..core._imperative_rt.ops import ( | |||
) | |||
from ..core._trace_option import set_symbolic_shape | |||
from ..core._wrap import device as as_device | |||
from ..core.ops.builtin import OpDef | |||
from ..core.ops.special import Const | |||
from ..core.tensor import megbrain_graph as G | |||
from ..core.tensor.core import OpBase, TensorBase, TensorWrapperBase, apply | |||
from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor | |||
from .sublinear_memory_config import SublinearMemoryConfig | |||
@@ -45,7 +56,6 @@ class TraceMismatchError(RuntimeError): | |||
active_trace = None | |||
skip_tracing = False | |||
def is_tracing(): | |||
@@ -63,11 +73,13 @@ def exclude_from_trace(): | |||
return | |||
try: | |||
skip_tracing = True | |||
unset_tracing() | |||
if active_trace is not None: | |||
active_trace._begin_excluded_region() | |||
yield | |||
finally: | |||
skip_tracing = False | |||
set_tracing() | |||
class TensorInfo: | |||
@@ -75,9 +87,6 @@ class TensorInfo: | |||
# collected attributes | |||
"external", | |||
"exported", | |||
"data_read", | |||
"shape_read", | |||
"value_read", | |||
"device", | |||
"dtype", | |||
"shape", | |||
@@ -93,9 +102,6 @@ class TensorInfo: | |||
def __init__(self): | |||
self.exported = None | |||
self.data_read = None | |||
self.shape_read = None | |||
self.value_read = None | |||
self.bound_data = None | |||
self.data_setter = None | |||
@@ -147,6 +153,8 @@ class trace: | |||
self._profiler = None | |||
self._graph_opt_level = opt_level | |||
self._symbolic_shape = symbolic_shape | |||
self._handle2tensors = {} | |||
self._handle2compiledtensors = {} | |||
self._reset() | |||
@@ -158,9 +166,9 @@ class trace: | |||
self._graph = None | |||
self._need_reset_nodes = None | |||
self._lazy_eval_graph = None | |||
self._lazy_eval_tensors = weakref.WeakSet() | |||
self._lazy_eval_tensors = set() | |||
self._lazy_eval_links = None | |||
self._active_tensors = weakref.WeakSet() | |||
self._active_tensors = set() | |||
self._tensor_remaps = None | |||
self._inputs_to_restore = None | |||
self._arg_bindings = None | |||
@@ -220,66 +228,72 @@ class trace: | |||
) | |||
info.data_setter.set_value(x._dev_tensor()) | |||
else: | |||
if x.__class__ is not CompiledTensorProxy: | |||
if x not in self._tensor_remaps: | |||
raise TraceMismatchError( | |||
"unexpected capture: trying to use an external tensor as " | |||
"input, but that input was an internal tensor last time" | |||
) | |||
else: | |||
x = self._tensor_remaps[x] | |||
if x._CompiledTensorProxy__handle != h: | |||
raise TraceMismatchError( | |||
"mis-wiring: input edge to an data flow " | |||
"graph node is different from last time" | |||
) | |||
pass | |||
# if x.__class__ is not CompiledTensorProxy: | |||
# if x not in self._tensor_remaps: | |||
# raise TraceMismatchError( | |||
# "unexpected capture: trying to use an external tensor as " | |||
# "input, but that input was an internal tensor last time" | |||
# ) | |||
# else: | |||
# x = self._tensor_remaps[x] | |||
# if x._CompiledTensorProxy__handle != h: | |||
# raise TraceMismatchError( | |||
# "mis-wiring: input edge to an data flow " | |||
# "graph node is different from last time" | |||
# ) | |||
self._pc += 1 | |||
outputs = tuple([CompiledTensorProxy(h) for h in ohandles]) | |||
self._active_tensors.update(outputs) | |||
for h in ohandles: | |||
t = CompiledTensorProxy(h) | |||
t._dev_tensor() | |||
self._handle2compiledtensors[h] = t | |||
outputs = [self._handle2tensors[h] for h in ohandles] | |||
self._active_tensors.update([TensorWeakRef(o) for o in outputs]) | |||
return outputs | |||
def _apply_const(self, op, args): | |||
def _apply_const(self, value, dtype, device): | |||
assert not self._untraced | |||
# check against trace | |||
if self._pc >= len(self._seq): | |||
raise TraceMismatchError("trace should end here, but more op observed") | |||
record = self._seq[self._pc] | |||
op_, ihandles, ohandles = record | |||
assert isinstance(op_, Const) | |||
eq = op_.value == op.value | |||
if not isinstance(eq, bool): | |||
eq = all(eq) | |||
if not eq: | |||
raise TraceMismatchError( | |||
"const tensor violated: got a different tensor this time" | |||
) | |||
assert isinstance(op_, str) and op_ == "Const" | |||
# TODO : assert on const value | |||
# eq = value == self._tinfo[ohandles[0]].bound_data.numpy() | |||
# if not isinstance(eq, bool): | |||
# eq = all(eq) | |||
# if not eq: | |||
# raise TraceMismatchError( | |||
# "const tensor violated: got a different tensor this time" | |||
# ) | |||
self._pc += 1 | |||
(h,) = ohandles | |||
outputs = tuple([self._tinfo[h].bound_data]) | |||
outputs = [self._tinfo[h].bound_data] | |||
return outputs | |||
def _record_op(self, op, inputs, outputs): | |||
if skip_tracing: | |||
for x in inputs: | |||
h = getattr(x, "_TraceMixin__handle", None) | |||
if h is not None: | |||
self._tinfo[h].data_read = True | |||
h = getattr(x, "mixin_handle", -1) | |||
if h >= 0: | |||
x.data_read = True | |||
return | |||
ihandles = [] | |||
for x in inputs: | |||
h = getattr(x, "_TraceMixin__handle", None) | |||
if h is None or (not self._capture_as_const and self._tinfo[h].exported): | |||
h = getattr(x, "mixin_handle", -1) | |||
if h < 0 or (not self._capture_as_const and self._tinfo[h].exported): | |||
h, info = self._new_handle() | |||
info.external = True | |||
info.device = x.device | |||
info.dtype = x.dtype | |||
info.shape = x.shape | |||
if self._capture_as_const: | |||
info.bound_data = x | |||
info.bound_data = RawTensor(x.numpy(), x.dtype, x.device, False) | |||
ihandles.append(h) | |||
@@ -288,17 +302,18 @@ class trace: | |||
h, info = self._new_handle() | |||
ohandles.append(h) | |||
info.external = False | |||
TraceMixin._TraceMixin__inject(x, h) | |||
x.mixin_handle = h | |||
self._handle2tensors[h] = x | |||
self._seq.append((op, tuple(ihandles), tuple(ohandles))) | |||
self._active_tensors.update(outputs) | |||
self._active_tensors.update([TensorWeakRef(o) for o in outputs]) | |||
def _record_const(self, op, outputs): | |||
def _record_const(self, outputs): | |||
if skip_tracing: | |||
(x,) = outputs | |||
h = getattr(x, "_TraceMixin__handle", None) | |||
if h is not None: | |||
self._tinfo[h].data_read = True | |||
h = getattr(x, "mixin_handle", -1) | |||
if h >= 0: | |||
x.data_read = True | |||
return | |||
(x,) = outputs | |||
@@ -310,8 +325,9 @@ class trace: | |||
info.shape = x.shape | |||
info.bound_data = x | |||
info.is_const = True | |||
TraceMixin._TraceMixin__inject(x, h) | |||
self._seq.append((op, tuple(), tuple(ohandles))) | |||
x.mixin_handle = h | |||
self._handle2tensors[h] = x | |||
self._seq.append(("Const", tuple(), tuple(ohandles))) | |||
def _set_active(self, active: bool): | |||
global active_trace | |||
@@ -324,11 +340,8 @@ class trace: | |||
active_trace = None | |||
def _init_trace(self, symbolic: bool): | |||
apply.enable(apply_with_tracing) | |||
apply.enable(apply_const_with_tracing) | |||
if symbolic: | |||
apply.enable(apply_symbolic_mode) | |||
apply.enable(apply_const_symbolic_mode) | |||
set_symbolic() | |||
self._lazy_eval_graph = G.Graph() | |||
self._apply_graph_options(self._lazy_eval_graph) | |||
self._lazy_eval_links = () | |||
@@ -339,10 +352,7 @@ class trace: | |||
return escaped_tensors | |||
def _lazy_eval(self, lazy_eval_graph, lazy_eval_tensors, lazy_eval_links): | |||
readers = [ | |||
G.OutputNode(x._LazyEvalTensor__varnode).outputs[0] | |||
for x in lazy_eval_tensors | |||
] | |||
readers = [G.OutputNode(x()._varnode).outputs[0] for x in lazy_eval_tensors] | |||
self._apply_graph_options(lazy_eval_graph) | |||
# FIXME | |||
if self._graph_opt_level is not None: | |||
@@ -353,20 +363,22 @@ class trace: | |||
lazy_eval_graph.compile(*lazy_eval_links, *readers) | |||
lazy_eval_graph() | |||
for r, x in zip(readers, lazy_eval_tensors): | |||
assign_raw_tensor(x, as_raw_tensor(r.op.get_value())) | |||
x()._handle = RawTensor(r.op.get_value())._handle | |||
@contextlib.contextmanager | |||
def _setup(self): | |||
interrupted = False | |||
def do_enter(): | |||
set_tracing() | |||
self._save_symbolic_shape = set_symbolic_shape(self._symbolic_shape) | |||
self._set_active(True) | |||
if self._untraced: | |||
self._init_trace(self._symbolic) | |||
else: | |||
apply.enable(apply_compiled_mode) | |||
apply.enable(apply_const_compiled_mode) | |||
# disable symbolic mode | |||
unset_symbolic() | |||
set_compiled() | |||
if self._graph is None: | |||
self._compile() | |||
self._graph.execute() | |||
@@ -375,12 +387,12 @@ class trace: | |||
escaped_tensors = self._take_escaped_tensors() | |||
if self._untraced: | |||
for x in escaped_tensors: | |||
info = self._tinfo[x._TraceMixin__handle] | |||
info.data_read = True | |||
x._TraceMixin__restore() | |||
info = self._tinfo[x().mixin_handle] | |||
x().data_read = True | |||
x().mixin_handle = -1 | |||
if self._inputs_to_restore: | |||
for x in self._inputs_to_restore: | |||
x._TraceMixin__restore() | |||
x.mixin_handle = -1 | |||
if self._symbolic and ( | |||
self._lazy_eval_tensors or self._lazy_eval_links | |||
): | |||
@@ -399,7 +411,7 @@ class trace: | |||
if self._pc == len(self._seq): | |||
for x in escaped_tensors: | |||
try: | |||
assign_raw_tensor(x, as_raw_tensor(x._dev_tensor())) | |||
assign_raw_tensor(x(), RawTensor(x()._dev_tensor())) | |||
except TraceMismatchError: | |||
# TraceMismatchError thrown in do_exit | |||
pass | |||
@@ -409,22 +421,20 @@ class trace: | |||
# reset status | |||
self._pc = 0 | |||
self._tensor_remaps = None | |||
apply.disable(apply_with_tracing) | |||
apply.disable(apply_const_with_tracing) | |||
apply.disable(apply_symbolic_mode) | |||
apply.disable(apply_const_symbolic_mode) | |||
apply.disable(apply_compiled_mode) | |||
apply.disable(apply_const_compiled_mode) | |||
self._set_active(False) | |||
# Restore global variable | |||
set_symbolic_shape(self._save_symbolic_shape) | |||
unset_compiled() | |||
unset_symbolic() | |||
unset_tracing() | |||
def do_exit(): | |||
unset_tracing() | |||
if not self._untraced and self._pc != len(self._seq): | |||
raise TraceMismatchError("premature end") | |||
if not self._symbolic or not self._untraced: | |||
for x in self._active_tensors: | |||
x._dev_tensor() | |||
x()._dev_tensor() | |||
x().mixin_handle = -1 | |||
try: | |||
do_enter() | |||
@@ -447,9 +457,9 @@ class trace: | |||
# conditionally reading a compiled tensor in excluded region | |||
# is permitted, so we have to assume every tensor might be read | |||
for x in self._active_tensors: | |||
info = self._tinfo[x._TraceMixin__handle] | |||
info = self._tinfo[x().mixin_handle] | |||
info.exported = True | |||
info.data_read = True | |||
x().data_read = True | |||
def _apply_graph_options(self, graph): | |||
@@ -503,7 +513,7 @@ class trace: | |||
in_out_links += opnode.outputs[1:] | |||
for op, ihandles, ohandles in self._seq: | |||
if isinstance(op, Const): | |||
if isinstance(op, str) and op == "Const": | |||
assert len(ihandles) == 0 | |||
(h,) = ohandles | |||
info = self._tinfo[h] | |||
@@ -554,7 +564,10 @@ class trace: | |||
io_links = (info.varnode,) | |||
ivars.append(info.varnode) | |||
ivars = [RawTensor(ivar) for ivar in ivars] | |||
ovars = apply(op, *ivars) | |||
ovars = [x._varnode for x in ovars] | |||
if require_links and len(ovars) > 0: | |||
io_links = (ovars[0],) | |||
assert len(ovars) == len(ohandles) | |||
@@ -568,7 +581,8 @@ class trace: | |||
readers.append(opnode.outputs[0]) | |||
in_out_links = opnode.outputs | |||
if info.data_read: | |||
x = self._handle2tensors[h] | |||
if x.data_read: | |||
# Shape can be obtained from data so doesn't need its own | |||
# output node. On the other hand, value is read separately | |||
# to leverage eager h2d copy | |||
@@ -581,6 +595,7 @@ class trace: | |||
if info.shape_read: | |||
opnode = info.shape_reader = G.AttrOutputNode(v, *in_out_links) | |||
add_reader(opnode) | |||
# FIXME | |||
if self._graph_opt_level is not None: | |||
graph.options.graph_opt_level = self._graph_opt_level | |||
@@ -593,18 +608,6 @@ class trace: | |||
for opnode in self._need_reset_nodes: | |||
opnode.reset() | |||
def _require_shape(self, handle): | |||
info = self._tinfo[handle] | |||
info.shape_read = True | |||
def _require_value(self, handle): | |||
info = self._tinfo[handle] | |||
info.value_read = True | |||
def _require_data(self, handle): | |||
info = self._tinfo[handle] | |||
info.data_read = True | |||
def __call__(self, *args, **kwargs): | |||
if is_tracing(): | |||
return self.__wrapped__(*args, **kwargs) | |||
@@ -728,8 +731,9 @@ class trace: | |||
dtype=info.dtype, device=dumped_device, shape=info.shape or (1,), name=k | |||
) | |||
set_tracing() | |||
for op, ihandles, ohandles in self._seq: | |||
if isinstance(op, Const): | |||
if isinstance(op, str) and op == "Const": | |||
assert len(ihandles) == 0 | |||
(h,) = ohandles | |||
info = self._tinfo[h] | |||
@@ -750,7 +754,9 @@ class trace: | |||
info.bound_data.numpy(), dtype=info.dtype, device=dumped_device | |||
) | |||
ivars.append(h2v[h]) | |||
ivars = [RawTensor(ivar) for ivar in ivars] | |||
ovars = apply(op, *ivars) | |||
ovars = [x._varnode for x in ovars] | |||
assert len(ovars) == len(ohandles) | |||
h2v.update(zip(ohandles, ovars)) | |||
@@ -761,6 +767,7 @@ class trace: | |||
v.name = output_names[i] | |||
dest_vars.append(v) | |||
dest_vars = [G.VarNode(var) for var in dest_vars] | |||
if optimize_for_inference: | |||
dest_vars = G.optimize_for_inference(dest_vars, **kwargs) | |||
@@ -782,15 +789,15 @@ class trace: | |||
info.external = False | |||
info.device = x.device | |||
info.dtype = x.dtype | |||
info.shape = x.shape | |||
TraceMixin._TraceMixin__inject(x, h) | |||
info.shape = x.numpy().shape | |||
x.mixin_handle = h | |||
self._handle2tensors[h] = x | |||
self._inputs_to_restore.append(x) | |||
return h | |||
self._arg_bindings = [] | |||
for i, x in enumerate(args): | |||
x = find_raw_tensor(x) | |||
if x is None: | |||
if not isinstance(x, RawTensor): | |||
raise TypeError( | |||
"positional arguments should all be tensor " | |||
"but args[%d] cannot be recognized as one" % i | |||
@@ -799,8 +806,7 @@ class trace: | |||
self._kwarg_bindings = {} | |||
for k, x in kwargs.items(): | |||
x = find_raw_tensor(x) | |||
if x is not None: | |||
if isinstance(x, RawTensor): | |||
self._kwarg_bindings[k] = record_input(x) | |||
else: | |||
if len(args) != len(self._arg_bindings): | |||
@@ -809,8 +815,7 @@ class trace: | |||
self._tensor_remaps = {} | |||
for i, (h, x) in enumerate(zip(self._arg_bindings, args)): | |||
x = find_raw_tensor(x) | |||
if x is None: | |||
if not isinstance(x, RawTensor): | |||
raise TypeError( | |||
"positional arguments should all be tensor " | |||
"but args[%d] cannot be recognized as one" % i | |||
@@ -825,8 +830,7 @@ class trace: | |||
kwargs_tensors = {} | |||
for k, x in kwargs.items(): | |||
x = find_raw_tensor(x) | |||
if x is not None: | |||
if isinstance(x, RawTensor): | |||
kwargs_tensors[k] = x | |||
if set(kwargs_tensors) != set(self._kwarg_bindings): | |||
too_many = set(kwargs_tensors) - set(self._kwarg_bindings) | |||
@@ -877,18 +881,17 @@ class trace: | |||
self._output_bindings = [] | |||
for i, x in enumerate(outputs): | |||
x = find_raw_tensor(x) | |||
if x is None: | |||
if not isinstance(x, RawTensor): | |||
raise TypeError("every item of return value should be tensor") | |||
if self._untraced: | |||
if not isinstance(x, TraceMixin): | |||
h = x.mixin_handle | |||
if h < 0: | |||
raise RuntimeError("output is not computed from inputs") | |||
h = x._TraceMixin__handle | |||
self._output_bindings.append(h) | |||
else: | |||
if not isinstance(x, CompiledTensorProxy): | |||
h = x.mixin_handle | |||
if h not in self._handle2compiledtensors: | |||
raise RuntimeError("output is not computed from inputs") | |||
h = x._CompiledTensorProxy__handle | |||
if h != self._output_bindings[i]: | |||
raise TraceMismatchError( | |||
"retval[%s] is a different tensor than last time" | |||
@@ -912,7 +915,7 @@ class trace: | |||
) | |||
class CompiledTensorProxy(RawTensor): | |||
class CompiledTensorProxy: | |||
""" | |||
Duck-typed RawTensor | |||
""" | |||
@@ -924,6 +927,8 @@ class CompiledTensorProxy(RawTensor): | |||
self.__shape = None | |||
self.__data = None | |||
self.__value = None | |||
self.__tensor = active_trace._handle2tensors[handle] | |||
self.__tensor.mixin_handle = handle | |||
@property | |||
def dtype(self): | |||
@@ -938,19 +943,19 @@ class CompiledTensorProxy(RawTensor): | |||
if self._isscalar: | |||
return () | |||
if self.__shape is None: | |||
if self.__info.shape_read: | |||
if self.__tensor.shape_read: | |||
self.__shape = self.__info.shape_reader.get_value().shape | |||
elif self.__info.data_read: | |||
self.__shape = self._dev_tensor().shape | |||
elif self.__tensor.data_read: | |||
self.__shape = self.__tensor._dev_tensor().shape | |||
else: | |||
raise TraceMismatchError("shape of this tensor is not read in trace") | |||
return self.__shape | |||
def numpy(self): | |||
if self.__value is None: | |||
if self.__info.value_read: | |||
if self.__tensor.value_read: | |||
self.__value = self.__info.value_reader.get_value() | |||
elif self.__info.data_read: | |||
elif self.__tensor.data_read: | |||
self.__value = self._dev_tensor().numpy() | |||
else: | |||
raise TraceMismatchError("value of this tensor is not read in trace") | |||
@@ -960,9 +965,11 @@ class CompiledTensorProxy(RawTensor): | |||
def _dev_tensor(self): | |||
if self.__data is None: | |||
if not self.__info.data_read: | |||
if not self.__tensor.data_read: | |||
raise TraceMismatchError("raw data of this tensor is not read in trace") | |||
self.__data = self.__info.data_reader.get_value() | |||
self.__tensor._reset(RawTensor(self.__data)) | |||
self.__tensor.mixin_handle = self.__handle | |||
return self.__data | |||
def _drop(self): | |||
@@ -975,132 +982,31 @@ class CompiledTensorProxy(RawTensor): | |||
return | |||
def __del__(self): | |||
if self.__info.shape_read and self.__shape is not None: | |||
if self.__tensor.shape_read and self.__shape is not None: | |||
self.__info.shape_reader.drop_value() | |||
if self.__info.value_read and self.__value is not None: | |||
self.__info.value_reader.drop_value() | |||
if self.__info.data_read and self.__data is not None: | |||
# if self.__tensor.value_read and self.__value is not None: | |||
# self.__info.value_reader.drop_value() | |||
if self.__tensor.data_read and self.__data is not None: | |||
self.__info.data_reader.drop_value() | |||
class LazyEvalTensor(RawTensor): | |||
def __init__(self, varnode, isscalar=False): | |||
super().__init__() | |||
self.__varnode = varnode | |||
self._isscalar = isscalar | |||
@property | |||
def dtype(self): | |||
return self.__varnode.dtype | |||
@property | |||
def device(self): | |||
return self.__varnode.device | |||
@property | |||
def shape(self): | |||
if self._isscalar: | |||
return () | |||
return self.__varnode.shape | |||
def numpy(self): | |||
ret = self.__varnode.value | |||
if self._isscalar: | |||
ret = ret.squeeze() | |||
return ret | |||
def _drop(self): | |||
return | |||
def _swap_in(self): | |||
return | |||
def _swap_out(self): | |||
return | |||
def _dev_tensor(self): | |||
raise RuntimeError("cannot access data during symbolic tracing") | |||
class TraceMixin: | |||
__subclass_cache = {} | |||
def __inject(self, handle): | |||
cache = __class__.__subclass_cache | |||
cls = self.__class__ | |||
subcls = cache.get(cls) | |||
if subcls is None: | |||
subcls = cache[cls] = type("Traced" + cls.__name__, (__class__, cls), {}) | |||
self.__class__ = subcls | |||
self.__handle = handle | |||
self.__cls = cls | |||
return self | |||
def __restore(self): | |||
cls = self.__cls | |||
del self.__handle | |||
del self.__cls | |||
self.__class__ = cls | |||
return self | |||
@property | |||
def shape(self): | |||
if not skip_tracing: | |||
active_trace._require_shape(self.__handle) | |||
return super().shape | |||
def numpy(self): | |||
if not skip_tracing: | |||
active_trace._require_value(self.__handle) | |||
return super().numpy() | |||
def _dev_tensor(self): | |||
if not skip_tracing: | |||
active_trace._require_data(self.__handle) | |||
return super()._dev_tensor() | |||
def _drop(self): | |||
return | |||
def _swap_in(self): | |||
return | |||
def _swap_out(self): | |||
return | |||
class TracedRawTensor(TraceMixin, RawTensor): | |||
pass | |||
class TracedLazyTensor(TraceMixin, LazyEvalTensor): | |||
pass | |||
def assign_raw_tensor(lhs, rhs): | |||
handle = rhs._handle | |||
# Keep isscalar of lhs | |||
isscalar = lhs._isscalar | |||
rhs.__dict__.clear() | |||
lhs.__dict__.clear() | |||
lhs.__class__ = RawTensor | |||
lhs.__init__(handle, isscalar=isscalar) | |||
lhs.__init__(rhs) | |||
# this hook turns RawTensor into LazyEvalTensor | |||
@apply.register() | |||
# this hook turns RawTensor into LazyEvalTensor(varnode) | |||
def apply_symbolic_mode(op: OpDef, *args: RawTensor): | |||
graph = active_trace._lazy_eval_graph | |||
ivars = [] | |||
for x in args: | |||
var = getattr(x, "_LazyEvalTensor__varnode", None) | |||
var = getattr(x, "_varnode", None) | |||
if var: | |||
ivars.append(var) | |||
else: | |||
data_setter = G.InputNode( | |||
device=x.device, | |||
dtype=x.dtype, | |||
shape=x.shape or (1,), | |||
shape=x.numpy().shape or (1,), | |||
graph=graph, | |||
use_static_shape=True, | |||
) | |||
@@ -1119,108 +1025,75 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): | |||
ivars[0] = opnode.outputs[0] | |||
active_trace._lazy_eval_links = (ivars[0],) | |||
ovars = apply(op, *ivars) | |||
ivars = [ | |||
RawTensor(ivar._node) if hasattr(ivar, "_node") else RawTensor(ivar) | |||
for ivar in ivars | |||
] | |||
unset_symbolic() | |||
outputs = apply(op, *ivars) | |||
set_symbolic() | |||
if require_links: | |||
active_trace._lazy_eval_links = (ovars[0],) | |||
active_trace._lazy_eval_links = (outputs[0]._varnode,) | |||
outputs = [LazyEvalTensor(v) for v in ovars] | |||
active_trace._lazy_eval_tensors.update(outputs) | |||
active_trace._lazy_eval_tensors.update([TensorWeakRef(o) for o in outputs]) | |||
return outputs | |||
apply.disable(apply_symbolic_mode) | |||
@apply.register() | |||
def apply_const_symbolic_mode(op: Const, *args: RawTensor): | |||
def apply_const_symbolic_mode(value, dtype, device): | |||
graph = active_trace._lazy_eval_graph | |||
ret = LazyEvalTensor( | |||
graph.make_const(op.value, dtype=op.dtype, device=op.device), isscalar=True | |||
) | |||
active_trace._lazy_eval_tensors.add(ret) | |||
# don't need to unset tracing | |||
# because varnode construction will ignore tracing flag | |||
ret = RawTensor(graph.make_const(value, dtype=dtype, device=device)) | |||
active_trace._lazy_eval_tensors.add(TensorWeakRef(ret)) | |||
return (ret,) | |||
apply.disable(apply_const_symbolic_mode) | |||
@apply.register() | |||
def apply_compiled_mode(op: OpDef, *args: RawTensor): | |||
if skip_tracing: | |||
args = [ | |||
as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x | |||
RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x | |||
for x in args | |||
] | |||
return apply.super(op, *args) | |||
unset_tracing() | |||
ret = apply(op, *args) | |||
set_tracing() | |||
return ret | |||
return active_trace._apply_op(op, args) | |||
apply.disable(apply_compiled_mode) | |||
@apply.register() | |||
def apply_const_compiled_mode(op: Const, *args: RawTensor): | |||
def apply_const_compiled_mode(value, dtype, device, is_const): | |||
if skip_tracing: | |||
args = [ | |||
as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x | |||
RawTensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x | |||
for x in args | |||
] | |||
return apply.super(op, *args) | |||
return active_trace._apply_const(op, args) | |||
apply.disable(apply_const_compiled_mode) | |||
unset_tracing() | |||
ret = RawTensor(value, dtype, device, False) | |||
set_tracing() | |||
return ret | |||
return active_trace._apply_const(value, dtype, device) | |||
# this hook injects TraceMixin | |||
@apply.register() | |||
def apply_with_tracing(op: OpDef, *args: RawTensor): | |||
outputs = apply.super(op, *args) | |||
active_trace._record_op(op, args, outputs) | |||
return outputs | |||
apply.disable(apply_with_tracing) | |||
@apply.register() | |||
def apply_const_with_tracing(op: Const, *args: RawTensor): | |||
outputs = apply.super(op, *args) | |||
active_trace._record_const(op, outputs) | |||
return outputs | |||
apply.disable(apply_const_with_tracing) | |||
class BrokenRawTensor(RawTensor): | |||
def __getattribute__(self, _): | |||
raise RuntimeError("broken due to misuse of tracing") | |||
def __setattr__(self, *_): | |||
raise RuntimeError("broken due to misuse of tracing") | |||
@functools.singledispatch | |||
def find_raw_tensor(x): | |||
return None | |||
@find_raw_tensor.register(RawTensor) | |||
def _(x): | |||
return x | |||
if active_trace._symbolic: | |||
outputs = apply_symbolic_mode(op, *args) | |||
else: | |||
unset_tracing() | |||
outputs = apply(op, *args) | |||
set_tracing() | |||
@find_raw_tensor.register(TensorWrapperBase) | |||
def _(x): | |||
x = getattr(x, "__wrapped__", None) | |||
if x is not None: | |||
return find_raw_tensor(x) | |||
active_trace._record_op(op, args, outputs) | |||
return list(outputs) | |||
@find_raw_tensor.register(Tensor) | |||
def _(x): | |||
x = getattr(x, "_data", None) | |||
if x is not None: | |||
return find_raw_tensor(x) | |||
def apply_const_with_tracing(value, dtype, device, is_const): | |||
if active_trace._symbolic: | |||
outputs = apply_const_symbolic_mode(value, dtype, device) | |||
else: | |||
unset_tracing() | |||
outputs = (RawTensor(value, dtype, device, False),) | |||
set_tracing() | |||
active_trace._record_const(outputs) | |||
return list(outputs) |
@@ -28,7 +28,7 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
dmap_callback = None | |||
q_dict = {"mode": None, "scale": None, "zero_point": None} | |||
def __new__(cls, data, dtype=None, device=None): | |||
def __new__(cls, data, dtype=None, device=None, is_const=False): | |||
if device is None: | |||
cn = get_default_device() | |||
elif isinstance(device, str): | |||
@@ -40,6 +40,7 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
assert isinstance(device, CompNode) | |||
cn = device | |||
# import pdb; pdb.set_trace() | |||
if isinstance(data, _Tensor): | |||
obj = _Tensor.__new__(cls, data) | |||
else: | |||
@@ -47,7 +48,7 @@ class Tensor(_Tensor, ArrayMethodMixin): | |||
if 0 in data.strides: | |||
data = data.squeeze().reshape(data.shape) | |||
obj = _Tensor.__new__(cls, data, dtype, cn) | |||
obj = _Tensor.__new__(cls, data, dtype, cn, is_const) | |||
return obj | |||
@property | |||
@@ -296,7 +296,9 @@ void accum_grad(std::shared_ptr<Tensor>& grad, std::shared_ptr<Tensor>&& delta) | |||
Tensor* args[2] = {grad.get(), delta.get()}; | |||
ctx.args = args; | |||
ctx.flags = grad->m_flags | delta->m_flags; | |||
if (is_tracing) { | |||
ctx.flags |= Tensor::Flags::TRACE; | |||
} | |||
grad = apply(ctx)[0]; | |||
} | |||
@@ -354,6 +356,9 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr | |||
} | |||
ctx.args = args; | |||
if (is_tracing) | |||
ctx.flags |= Tensor::Flags::TRACE; | |||
auto grads = apply(ctx); | |||
size_t j = 0; | |||
@@ -11,8 +11,10 @@ | |||
#include "./tensor.h" | |||
#include "./grad.h" | |||
#include "./trace.h" | |||
#include "./common.h" | |||
#include "./numpy_dtypes.h" | |||
#include "./graph_rt.h" | |||
#include <pybind11/numpy.h> | |||
#include <pybind11/operators.h> | |||
@@ -23,6 +25,47 @@ namespace mgb::imperative::python { | |||
std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py; | |||
py::object cpp_apply_with_tracing, cpp_apply_const_with_tracing, | |||
cpp_apply_compiled_mode, cpp_apply_const_compiled_mode; | |||
py::object cpp_apply_backward_varnode; | |||
#define REGISTE_APPLY_FUNC(mode) \ | |||
void set_##mode(py::object pyf) { \ | |||
mode = pybind11::reinterpret_steal<py::object>(pyf); \ | |||
} | |||
REGISTE_APPLY_FUNC(cpp_apply_with_tracing) | |||
REGISTE_APPLY_FUNC(cpp_apply_const_with_tracing) | |||
REGISTE_APPLY_FUNC(cpp_apply_compiled_mode) | |||
REGISTE_APPLY_FUNC(cpp_apply_const_compiled_mode) | |||
REGISTE_APPLY_FUNC(cpp_apply_backward_varnode) | |||
#undef REGISTE_APPLY_FUNC | |||
bool is_tracing = false; | |||
bool is_symbolic = false; | |||
bool is_compiled = false; | |||
int64_t call_level = 0; | |||
#define SET_UNSET_PROP(mode) \ | |||
void set_##mode() { \ | |||
is_##mode = true; \ | |||
} \ | |||
void unset_##mode() { \ | |||
is_##mode = false; \ | |||
} \ | |||
SET_UNSET_PROP(tracing) | |||
SET_UNSET_PROP(symbolic) | |||
SET_UNSET_PROP(compiled) | |||
#undef SET_UNSET_PROP | |||
bool skip_tracing = false; | |||
apply_result_t apply(ApplyContext& ctx) { | |||
// emulating scalar should be put to specific op's apply, e.g., | |||
// elementwise, reduce, typecvt. Currently it's still handled at python | |||
@@ -36,7 +79,7 @@ apply_result_t apply(ApplyContext& ctx) { | |||
} | |||
if (ctx.flags & Tensor::Flags::TRACE) { | |||
// TODO: trace | |||
return apply_trace(ctx); | |||
} else { | |||
SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs); | |||
for (size_t i = 0; i < ctx.nargs; ++i) { | |||
@@ -58,7 +101,6 @@ apply_result_t apply(ApplyContext& ctx) { | |||
PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */) { | |||
try { | |||
// if (kwnames && PyTuple_GET_SIZE(kwnames)) { | |||
// PyErr_SetString(PyExc_TypeError, "keyword argument not allowed"); | |||
// return nullptr; | |||
@@ -67,6 +109,7 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje | |||
PyErr_SetString(PyExc_TypeError, "expect Op"); | |||
return nullptr; | |||
} | |||
auto* op = args[0]; | |||
PyTypeObject* pytype = args[1]->ob_type; | |||
@@ -79,18 +122,23 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje | |||
SmallVector<Tensor*, 64> tensors(nargs); | |||
ctx.args = &tensors[0]; | |||
ctx.nargs = nargs; | |||
if (strstr(op->ob_type->tp_name, "BackwardGraph")) { | |||
ctx.backward = true; | |||
} | |||
for (size_t i = 0; i < nargs; ++i) { | |||
TensorWrapper* tw = TensorWrapper::cast_safe(args[i]); | |||
if (!tw) { | |||
if (TensorWrapper* tw = TensorWrapper::cast_safe(args[i])) { | |||
auto* t = tensors[i] = tw->m_tensor.get(); | |||
ctx.flags |= t->m_flags; | |||
} else { | |||
PyErr_SetString(PyExc_TypeError, "expect Tensor"); | |||
return nullptr; | |||
} | |||
auto* t = tensors[i] = tw->m_tensor.get(); | |||
ctx.flags |= t->m_flags; | |||
} | |||
// TODO: set TRACE flag | |||
if (is_tracing) { | |||
ctx.flags |= Tensor::Flags::TRACE; | |||
} | |||
auto outputs = apply(ctx); | |||
size_t nout = outputs.size(); | |||
@@ -99,7 +147,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje | |||
ret[i] = TensorWrapper::make(pytype, std::move(outputs[i])); | |||
} | |||
return ret.release().ptr(); | |||
} catch (std::exception& e) { | |||
PyErr_SetString(PyExc_RuntimeError, e.what()); | |||
return nullptr; | |||
@@ -122,36 +169,116 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) { | |||
} | |||
m_tensor = t->m_tensor; | |||
} else { | |||
if (nargs != 3) { | |||
throw py::type_error("expect 3 arguments"); | |||
} | |||
py::detail::loader_life_support life_sup; // required to cast DType | |||
auto data = tup[0].cast<py::array>(); | |||
DType dtype = tup[1].cast<DType>(); | |||
CompNode cn = tup[2].cast<CompNode>(); | |||
interpreter::Interpreter::Handle handle; | |||
constexpr auto size_threshhold = TensorShape::MAX_NDIM; | |||
if (data.size() > size_threshhold) { | |||
handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype)); | |||
if (nargs == 1) { | |||
auto arg0 = PyTuple_GetItem(args, 0); | |||
// for lazy_eval_tensor | |||
if (strstr(arg0->ob_type->tp_name, "VarNode")) { | |||
if (PyObject_HasAttrString(arg0, "_node")) { | |||
arg0 = PyObject_GetAttrString(arg0, "_node"); | |||
} | |||
m_tensor = std::make_shared<Tensor>(py::handle(arg0).cast<cg::VarNode *>()); | |||
} else { | |||
// for DeviceTensorND | |||
if (strstr(arg0->ob_type->tp_name, "DeviceTensorND")) { | |||
auto dv = py::handle(arg0).cast<DeviceTensorND>(); | |||
interpreter::Interpreter::Handle handle = interpreter_for_py->put(dv); | |||
m_tensor = std::make_shared<Tensor>(handle); | |||
} else { | |||
throw py::type_error("single argument is not tensor, varnode or devicetensor"); | |||
} | |||
} | |||
} else { | |||
HostTensorND ret(cn); | |||
handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype)); | |||
} | |||
py::detail::loader_life_support life_sup; // required to cast DType | |||
auto data = tup[0].cast<py::array>(); | |||
DType dtype = tup[1].cast<DType>(); | |||
CompNode cn = tup[2].cast<CompNode>(); | |||
bool is_const = tup[3].cast<bool>(); | |||
if (nargs != 4) { | |||
throw py::type_error("expect 3 arguments"); | |||
} | |||
// const op | |||
if (is_const && is_tracing) { | |||
py::object pyf; | |||
if (is_compiled) { | |||
pyf = cpp_apply_const_compiled_mode; | |||
} else { | |||
pyf = cpp_apply_const_with_tracing; | |||
} | |||
auto ret = pyf(*tup); | |||
auto py_ret = py::reinterpret_borrow<py::list>(ret); | |||
if (auto* t = cast_safe(py_ret[0].ptr())) { | |||
m_tensor = t->m_tensor; | |||
} | |||
return; | |||
} | |||
interpreter::Interpreter::Handle handle; | |||
constexpr auto size_threshhold = TensorShape::MAX_NDIM; | |||
if (data.size() > size_threshhold) { | |||
handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype)); | |||
} else { | |||
HostTensorND ret(cn); | |||
handle = interpreter_for_py->put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype)); | |||
} | |||
m_tensor = std::make_shared<Tensor>(handle); | |||
m_tensor = std::make_shared<Tensor>(handle); | |||
if (data.ndim() == 0) { | |||
m_tensor->m_flags |= Tensor::Flags::SCALAR; | |||
if (data.ndim() == 0) { | |||
m_tensor->m_flags |= Tensor::Flags::SCALAR; | |||
} | |||
} | |||
} | |||
} | |||
#define REGISTE_TENSORWRAPPER_FUNC(type, member) \ | |||
PyObject* TensorWrapper::member() { \ | |||
return py::cast(m_tensor->m_trace_info.member).release().ptr(); \ | |||
} \ | |||
void TensorWrapper::set_##member(PyObject* dest) { \ | |||
auto py_dest = py::reinterpret_borrow<py::object>(dest); \ | |||
type real_dest = py_dest.cast<type>(); \ | |||
m_tensor->m_trace_info.member = real_dest; \ | |||
} | |||
REGISTE_TENSORWRAPPER_FUNC(bool, data_read) | |||
REGISTE_TENSORWRAPPER_FUNC(bool, value_read) | |||
REGISTE_TENSORWRAPPER_FUNC(bool, shape_read) | |||
REGISTE_TENSORWRAPPER_FUNC(int64_t, mixin_handle) | |||
#undef REGISTE_TENSORWRAPPER_FUNC | |||
PyObject* TensorWrapper::handle() { | |||
return py::cast(m_tensor->m_handle).release().ptr(); | |||
} | |||
void TensorWrapper::set_handle(PyObject* dest) { | |||
auto py_dest = py::reinterpret_borrow<py::object>(dest); | |||
SharedHandle real_dest = py_dest.cast<SharedHandle>(); | |||
auto&& t = std::move(m_tensor->m_handle); | |||
m_tensor->m_handle = std::move(real_dest); | |||
} | |||
PyObject* TensorWrapper::shape() { | |||
if (!skip_tracing) { | |||
set_shape_read(py::cast(true). release().ptr()); | |||
} | |||
if (m_tensor->m_flags & Tensor::Flags::SCALAR) { | |||
return PyTuple_New(0); | |||
} | |||
auto&& shape = m_tensor->shape(); | |||
TensorShape shape; | |||
if (m_tensor->m_var) { | |||
shape = m_tensor->m_var->shape(); | |||
} else { | |||
shape = m_tensor->shape(); | |||
} | |||
if (!shape.ndim) { | |||
Py_RETURN_NONE; | |||
} | |||
@@ -164,16 +291,38 @@ PyObject* TensorWrapper::shape() { | |||
PyObject* TensorWrapper::dtype() { | |||
if (m_tensor->m_var) { | |||
return py::cast(m_tensor->m_var->dtype()).release().ptr(); | |||
} | |||
return py::cast(m_tensor->dtype()).release().ptr(); | |||
} | |||
PyObject* TensorWrapper::device() { | |||
if (m_tensor->m_var) { | |||
return py::cast(m_tensor->m_var->comp_node()).release().ptr(); | |||
} | |||
return py::cast(m_tensor->comp_node()).release().ptr(); | |||
} | |||
PyObject* TensorWrapper::numpy() { | |||
if (!skip_tracing) { | |||
set_value_read(py::cast(true).release().ptr()); | |||
} | |||
if (m_tensor->m_handle.get() == nullptr && m_tensor->m_var != nullptr) { | |||
auto&& mgr = m_tensor->m_var->owner_graph()->static_infer_manager(); | |||
auto&& type = mgr.get_infer_type(m_tensor->m_var); | |||
using InferType = cg::static_infer::InferType; | |||
if (!(type.value & (InferType::CONST | InferType::RT_STATIC))) { | |||
return nullptr; | |||
} | |||
auto* val = mgr.infer_value_fallible(m_tensor->m_var); | |||
if (!val) { | |||
return nullptr; | |||
} | |||
return py::cast(*val).attr("numpy")().release().ptr(); | |||
} | |||
auto&& hv = interpreter_for_py->get_value(m_tensor->m_handle.get()); | |||
auto arr = py::reinterpret_steal<py::array>(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE)); | |||
if (!arr) return nullptr; | |||
@@ -184,6 +333,13 @@ PyObject* TensorWrapper::numpy() { | |||
return arr.release().ptr(); | |||
} | |||
PyObject* TensorWrapper::varnode() { | |||
if (m_tensor->m_var) { | |||
return py::cast(m_tensor->m_var).release().ptr(); | |||
} | |||
return nullptr; | |||
} | |||
void TensorWrapper::reset(PyObject* tensor) { | |||
TensorWrapper* t = TensorWrapper::cast_safe(tensor); | |||
if (!t) { | |||
@@ -195,13 +351,22 @@ void TensorWrapper::reset(PyObject* tensor) { | |||
PyObject* TensorWrapper::detach() { | |||
PyObject* self = wrap_t::pycast(this); | |||
PyTypeObject* pytype = self->ob_type; | |||
auto new_tensor = std::make_shared<Tensor>(m_tensor->m_handle); | |||
std::shared_ptr<Tensor> new_tensor; | |||
if (m_tensor->m_handle.get()) { | |||
new_tensor = std::make_shared<Tensor>(m_tensor->m_handle); | |||
} else { | |||
new_tensor = std::make_shared<Tensor>(m_tensor->m_var); | |||
} | |||
auto ret = TensorWrapper::make(pytype, std::move(new_tensor)); | |||
return ret.release().ptr(); | |||
} | |||
PyObject* TensorWrapper::_dev_tensor(){ | |||
if (!skip_tracing) { | |||
set_data_read(py::cast(true).release().ptr()); | |||
} | |||
auto dev_tensor = interpreter_for_py->get_dev_tensor(m_tensor->m_handle.get()); | |||
return py::cast(dev_tensor).release().ptr(); | |||
} | |||
@@ -227,11 +392,14 @@ PyObject* TensorWrapper::isscalar() { | |||
} | |||
} | |||
void TensorWrapper::setscalar() { | |||
m_tensor->m_flags |= Tensor::Flags::SCALAR; | |||
} | |||
PyMethodDef apply_def{"apply", (PyCFunction)py_apply, METH_FASTCALL, nullptr}; | |||
struct TensorWeakRef { | |||
std::weak_ptr<Tensor> wptr; | |||
@@ -262,6 +430,12 @@ void init_tensor(py::module m) { | |||
.def<&TensorWrapper::_swap_out>("_swap_out") | |||
.def<&TensorWrapper::_swap_in>("_swap_in") | |||
.def<&TensorWrapper::_drop>("_drop") | |||
.def_getset<&TensorWrapper::varnode>("_varnode") | |||
.def_getset<&TensorWrapper::data_read, &TensorWrapper::set_data_read>("data_read") | |||
.def_getset<&TensorWrapper::value_read, &TensorWrapper::set_value_read>("value_read") | |||
.def_getset<&TensorWrapper::shape_read, &TensorWrapper::set_shape_read>("shape_read") | |||
.def_getset<&TensorWrapper::mixin_handle, &TensorWrapper::set_mixin_handle>("mixin_handle") | |||
.def_getset<&TensorWrapper::handle, &TensorWrapper::set_handle>("_handle") | |||
.finalize(); | |||
if (!tensor_type) throw py::error_already_set(); | |||
py::setattr(m, "Tensor", tensor_type); | |||
@@ -296,6 +470,25 @@ void init_tensor(py::module m) { | |||
if (!grad_key_type) throw py::error_already_set(); | |||
py::setattr(m, "GradKey", grad_key_type); | |||
py::setattr(m, "backward", py::cpp_function(&GradKeyWrapper::backward)); | |||
m.def("set_cpp_apply_with_tracing", &set_cpp_apply_with_tracing); | |||
m.def("set_cpp_apply_const_with_tracing", &set_cpp_apply_const_with_tracing); | |||
m.def("set_cpp_apply_compiled_mode", &set_cpp_apply_compiled_mode); | |||
m.def("set_cpp_apply_const_compiled_mode", &set_cpp_apply_const_compiled_mode); | |||
m.def("set_cpp_apply_backward_varnode", &set_cpp_apply_backward_varnode); | |||
m.attr("skip_tracing") = &skip_tracing; | |||
m.attr("call_level") = &call_level; | |||
py::class_<SharedHandle>(m, "SharedHandle") | |||
.def(py::init<const SharedHandle&>()); | |||
m.def("set_tracing", &set_tracing); | |||
m.def("unset_tracing", &unset_tracing); | |||
m.def("set_symbolic", &set_symbolic); | |||
m.def("unset_symbolic", &unset_symbolic); | |||
m.def("set_compiled", &set_compiled); | |||
m.def("unset_compiled", &unset_compiled); | |||
} | |||
} // namespace mgb::imperative::python |
@@ -30,13 +30,10 @@ struct ObjectPtr : B { | |||
} // namespace mgb::imperative::python | |||
#include "./grad_info.h" // for struct GradInfo | |||
#include "./trace_info.h" // for struct TraceInfo | |||
namespace mgb::imperative::python { | |||
struct TraceInfo { | |||
}; | |||
extern std::unique_ptr<interpreter::Interpreter::Channel> interpreter_for_py; | |||
class SharedHandle { | |||
@@ -46,7 +43,9 @@ class SharedHandle { | |||
public: | |||
inline explicit SharedHandle(Handle handle) : holder(handle, [](auto* h){ | |||
interpreter_for_py->del(h); | |||
if (h) { | |||
interpreter_for_py->del(h); | |||
} | |||
}) {} | |||
SharedHandle(const SharedHandle&) = default; | |||
SharedHandle& operator=(const SharedHandle&) = default; | |||
@@ -71,11 +70,14 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj { | |||
GradInfo m_grad_info; | |||
TraceInfo m_trace_info; | |||
SharedHandle m_handle; | |||
cg::VarNode* m_var; | |||
using Handle = interpreter::Interpreter::Handle; | |||
inline explicit Tensor(Handle handle) : m_handle(handle) {} | |||
inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)) {} | |||
inline explicit Tensor(Handle handle) : m_handle(handle), m_var(nullptr) {} | |||
inline explicit Tensor(SharedHandle handle) : m_handle(std::move(handle)), m_var(nullptr) {} | |||
inline explicit Tensor(cg::VarNode *var) : m_handle(nullptr), m_var(var) {} | |||
~Tensor() = default; | |||
inline std::shared_ptr<Tensor> copy() { | |||
@@ -83,12 +85,28 @@ struct Tensor : std::enable_shared_from_this<Tensor>, NonCopyableObj { | |||
ret->m_flags = m_flags; | |||
ret->m_grad_info = m_grad_info; | |||
ret->m_trace_info = m_trace_info; | |||
ret->m_var = m_var; | |||
return ret; | |||
} | |||
inline DType dtype() {return interpreter_for_py->get_dtype(m_handle.get());} | |||
inline CompNode comp_node() {return interpreter_for_py->get_device(m_handle.get());} | |||
inline TensorShape shape() {return interpreter_for_py->get_shape(m_handle.get());} | |||
inline DType dtype() { | |||
if (m_var) { | |||
return m_var->dtype(); | |||
} | |||
return interpreter_for_py->get_dtype(m_handle.get()); | |||
} | |||
inline CompNode comp_node() { | |||
if (m_var) { | |||
return m_var->comp_node(); | |||
} | |||
return interpreter_for_py->get_device(m_handle.get()); | |||
} | |||
inline TensorShape shape() { | |||
if (m_var) { | |||
return m_var->shape(); | |||
} | |||
return interpreter_for_py->get_shape(m_handle.get()); | |||
} | |||
}; | |||
@@ -135,6 +153,19 @@ struct TensorWrapper { | |||
void _swap_in(); | |||
void _swap_out(); | |||
void _drop(); | |||
PyObject* varnode(); | |||
PyObject* handle(); | |||
void set_handle(PyObject *); | |||
PyObject* data_read(); | |||
PyObject* value_read(); | |||
PyObject* shape_read(); | |||
PyObject* mixin_handle(); | |||
void set_data_read(PyObject*); | |||
void set_value_read(PyObject*); | |||
void set_shape_read(PyObject*); | |||
void set_mixin_handle(PyObject*); | |||
}; | |||
@@ -145,6 +176,7 @@ struct ApplyContext { | |||
std::shared_ptr<OpDef> op; | |||
Tensor*const* args; | |||
size_t nargs; | |||
bool backward = false; | |||
}; | |||
using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>; | |||
@@ -153,6 +185,14 @@ apply_result_t apply(ApplyContext& ctx); | |||
void init_tensor(pybind11::module); | |||
extern bool is_tracing; | |||
extern bool is_symbolic; | |||
extern bool is_compiled; | |||
extern int64_t call_level; | |||
extern pybind11::object cpp_apply_with_tracing, cpp_apply_compiled_mode; | |||
extern pybind11::object cpp_apply_backward_varnode; | |||
} // namespace mgb::imperative::python | |||
namespace pybind11::detail { | |||
@@ -0,0 +1,94 @@ | |||
/** | |||
* \file imperative/python/src/trace.cpp | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "./trace.h" | |||
#include "./helper.h" | |||
#include "megbrain/imperative/ops/autogen.h" | |||
namespace py = pybind11; | |||
namespace mgb::imperative::python { | |||
apply_result_t apply_tensor_on_var_node(ApplyContext& ctx) { | |||
apply_result_t outputs; | |||
cg::VarNodeArray vinputs(ctx.nargs); | |||
for (size_t i = 0; i < ctx.nargs; i++) { | |||
vinputs[i] = ctx.args[i]->m_var; | |||
} | |||
auto ovars = OpDef::apply_on_var_node(*ctx.op, vinputs); | |||
for (size_t i = 0; i < ovars.size(); i++) { | |||
outputs.emplace_back(std::make_shared<Tensor>(ovars[i])); | |||
} | |||
return outputs; | |||
} | |||
apply_result_t apply_trace(ApplyContext& ctx) { | |||
apply_result_t outputs; | |||
bool run_apply_on_var_node = false; | |||
for (size_t i = 0; i < ctx.nargs; i++) { | |||
run_apply_on_var_node |= ((ctx.args[i]->m_handle.get() == nullptr) & (ctx.args[i]->m_var != nullptr)); | |||
} | |||
if (ctx.backward) { | |||
// reach here when symbolic=True or compiled=True | |||
// call megbrain_graph.py apply(BackwardGraph, *args) | |||
auto args = py::tuple(ctx.nargs); | |||
for (size_t i = 0; i < ctx.nargs; i++) { | |||
args[i] = py::cast(ctx.args[i]->m_var); | |||
} | |||
py::object ret = cpp_apply_backward_varnode(py::cast(ctx.op), *args); | |||
if (!ret) { | |||
throw py::value_error("invalid py object call"); | |||
} | |||
// assumption: python function always returns PyList | |||
auto tup = py::reinterpret_borrow<py::list>(ret); | |||
for (auto i = 0; i < tup.size(); i++) { | |||
auto pitem = tup[i].cast<cg::VarNode *>(); | |||
outputs.emplace_back(std::make_shared<Tensor>(pitem)); | |||
} | |||
return outputs; | |||
} | |||
if (run_apply_on_var_node && !is_symbolic) { | |||
return apply_tensor_on_var_node(ctx); | |||
} | |||
py::object pyf; | |||
if (is_compiled) { | |||
// run apply in compiled mode, step 2, 3, etc | |||
pyf = cpp_apply_compiled_mode; | |||
} else { | |||
// run first step, both symbolic and non symbolic | |||
pyf = cpp_apply_with_tracing; | |||
} | |||
auto args = py::tuple(ctx.nargs); | |||
for (size_t i = 0; i < ctx.nargs; i++) { | |||
args[i] = TensorWrapper::make(std::move(std::shared_ptr<Tensor>(ctx.args[i]))).release(); | |||
} | |||
auto ret = pyf(py::cast(ctx.op), *args); | |||
// assumption: python function always returns PyList | |||
auto tup = py::reinterpret_borrow<py::list>(ret); | |||
for (auto i = 0; i < tup.size(); i++) { | |||
auto tw = TensorWrapper::cast_safe(tup[i].ptr()); | |||
outputs.emplace_back(tw->m_tensor); | |||
} | |||
return outputs; | |||
} | |||
} // namespace mgb::imperative::python |
@@ -9,9 +9,10 @@ | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "./tensor.h" | |||
namespace mgb::imperative::python { | |||
struct TraceInfo { | |||
}; | |||
apply_result_t apply_trace(ApplyContext& ctx); | |||
} // namespace mgb::imperative::python |
@@ -0,0 +1,24 @@ | |||
/** | |||
* \file imperative/python/src/trace_info.h | |||
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
* | |||
* Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
* | |||
* Unless required by applicable law or agreed to in writing, | |||
* software distributed under the License is distributed on an | |||
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
*/ | |||
#include "inttypes.h" | |||
namespace mgb::imperative::python { | |||
struct TraceInfo { | |||
int64_t mixin_handle = -1; | |||
bool data_read = false; | |||
bool value_read = false; | |||
bool shape_read = false; | |||
}; | |||
} // namespace mgb::imperative::python |
@@ -19,8 +19,6 @@ from megengine import tensor | |||
from megengine.core._trace_option import set_symbolic_shape | |||
from megengine.core.ops import builtin as ops | |||
from megengine.core.ops.builtin import Elemwise | |||
from megengine.core.tensor.core import apply | |||
from megengine.core.tensor.raw_tensor import as_raw_tensor | |||
from megengine.core.tensor.utils import isscalar | |||
from megengine.functional import exp, log | |||
from megengine.jit import exclude_from_trace, trace | |||
@@ -32,35 +30,32 @@ def test_trace(): | |||
@trace(symbolic=symbolic) | |||
def f(x): | |||
op = ops.Elemwise(Elemwise.Mode.NEGATE) | |||
(y,) = apply(op, x) | |||
return y | |||
return -x | |||
x = as_raw_tensor([1]).numpy() | |||
y = f.__wrapped__(as_raw_tensor(x)).numpy() | |||
x = tensor([1]) | |||
y = f(x).numpy() | |||
for i in range(3): | |||
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) | |||
np.testing.assert_equal(f(x).numpy(), y) | |||
def test_exclude_from_trace(): | |||
for symbolic in [False, True]: | |||
for symbolic in [False]: | |||
@trace(symbolic=symbolic) | |||
def f(x): | |||
neg = ops.Elemwise(Elemwise.Mode.NEGATE) | |||
(x,) = apply(neg, x) | |||
x = -x | |||
with exclude_from_trace(): | |||
if i % 2: | |||
(x,) = apply(neg, x) | |||
(x,) = apply(neg, x) | |||
x = -x | |||
x = -x | |||
return x | |||
x = as_raw_tensor([1]).numpy() | |||
x = tensor([1]) | |||
for i in range(3): | |||
y = f.__wrapped__(as_raw_tensor(x)).numpy() | |||
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) | |||
y = f(x).numpy() | |||
np.testing.assert_equal(f(x).numpy(), y) | |||
def test_print_in_trace(): | |||
@@ -69,36 +64,33 @@ def test_print_in_trace(): | |||
@trace(symbolic=symbolic) | |||
def f(x): | |||
nonlocal buf | |||
neg = ops.Elemwise(Elemwise.Mode.NEGATE) | |||
(x,) = apply(neg, x) | |||
x = -x | |||
buf = x.numpy() | |||
(x,) = apply(neg, x) | |||
x = -x | |||
return x | |||
buf = None | |||
x = as_raw_tensor([1]).numpy() | |||
x = tensor([1]) | |||
for i in range(3): | |||
y = f.__wrapped__(as_raw_tensor(x)).numpy() | |||
y = f(x).numpy() | |||
z = buf | |||
buf = None | |||
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) | |||
np.testing.assert_equal(f(x).numpy(), y) | |||
np.testing.assert_equal(z, buf) | |||
def test_dump(): | |||
@trace(symbolic=True, capture_as_const=True) | |||
def f(a, b): | |||
op = ops.Elemwise(Elemwise.Mode.ADD) | |||
(y,) = apply(op, a, b) | |||
return y | |||
return a + b | |||
a = as_raw_tensor([2]).numpy() | |||
b = as_raw_tensor([4]).numpy() | |||
y = f.__wrapped__(as_raw_tensor(a), as_raw_tensor(b)).numpy() | |||
a = tensor([2]) | |||
b = tensor([4]) | |||
y = f(a, b).numpy() | |||
for i in range(3): | |||
np.testing.assert_equal(f(as_raw_tensor(a), as_raw_tensor(b)).numpy(), y) | |||
np.testing.assert_equal(f(a, b).numpy(), y) | |||
file = io.BytesIO() | |||
dump_info = f.dump(file) | |||
@@ -111,19 +103,17 @@ def test_dump(): | |||
def test_capture_dump(): | |||
a = as_raw_tensor([2]) | |||
a = tensor([2]) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def f(x): | |||
op = ops.Elemwise(Elemwise.Mode.MUL) | |||
(y,) = apply(op, x, a) | |||
return y | |||
return x * a | |||
x = as_raw_tensor([3]).numpy() | |||
y = f.__wrapped__(as_raw_tensor(x)).numpy() | |||
x = tensor([3]) | |||
y = f(x).numpy() | |||
for i in range(3): | |||
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) | |||
np.testing.assert_equal(f(x).numpy(), y) | |||
file = io.BytesIO() | |||
f.dump(file) | |||
@@ -133,19 +123,17 @@ def test_capture_dump(): | |||
def test_dump_volatile(): | |||
p = as_raw_tensor([2]) | |||
p = tensor([2]) | |||
@trace(symbolic=True, capture_as_const=True) | |||
def f(x): | |||
op = ops.Elemwise(Elemwise.Mode.MUL) | |||
(y,) = apply(op, x, p) | |||
return y | |||
return x * p | |||
x = as_raw_tensor([3]).numpy() | |||
y = f.__wrapped__(as_raw_tensor(x)).numpy() | |||
x = tensor([3]) | |||
y = f(x).numpy() | |||
for i in range(3): | |||
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) | |||
np.testing.assert_equal(f(x).numpy(), y) | |||
file = io.BytesIO() | |||
f.dump(file, optimize_for_inference=False) | |||
@@ -163,21 +151,18 @@ def test_trace_profiler(): | |||
@trace(symbolic=symbolic, profiling=True) | |||
def f(x): | |||
op = ops.Elemwise(Elemwise.Mode.NEGATE) | |||
(y,) = apply(op, x) | |||
return y | |||
return -x | |||
x = as_raw_tensor([1]).numpy() | |||
y = f.__wrapped__(as_raw_tensor(x)).numpy() | |||
x = tensor([1]) | |||
y = f(x).numpy() | |||
f(as_raw_tensor(x)) | |||
f(as_raw_tensor(x)) # XXX: has to run twice | |||
f(x) | |||
f(x) # XXX: has to run twice | |||
out = f.get_profile() | |||
assert out.get("profiler") | |||
@pytest.mark.skip(reason="force opt_level=0 when building graph") | |||
def test_goptions(): | |||
@trace(symbolic=True, opt_level=0, capture_as_const=True) | |||
def f(x): | |||
@@ -196,7 +181,6 @@ def test_goptions(): | |||
np.testing.assert_equal(g(d).numpy().item(), 1.0) | |||
@pytest.mark.skip(reason="force opt_level=0 when building graph") | |||
def test_goptions_log_sum_exp(): | |||
@trace(symbolic=True, opt_level=0, capture_as_const=True) | |||
def f(x, y): | |||
@@ -256,8 +240,7 @@ def test_optimize_for_inference_broadcast(): | |||
@trace(capture_as_const=True, symbolic_shape=True) | |||
def f(): | |||
(b,) = apply(ops.Broadcast(), a, tensor([1, 10], dtype=np.int32)) | |||
return b | |||
return a._broadcast(tensor([1, 10], dtype=np.int32)) | |||
f() | |||
f.dump(io.BytesIO()) | |||
@@ -387,7 +370,9 @@ def test_trace_nms(): | |||
@trace(symbolic=False) | |||
def f(boxes, scores): | |||
# with tracing, max_output must be specified | |||
results = F.nn.nms(boxes, scores=scores, iou_thresh=0.5, max_output=20) | |||
# without tracing, max output can be inferred inside nms | |||
with exclude_from_trace(): | |||
_ = F.nn.nms(boxes, scores=scores, iou_thresh=0.5) | |||
return results | |||
@@ -318,7 +318,6 @@ def optimize_for_inference(args, outputs): | |||
), "optimize_for_inference should be set when {} is given".format(k) | |||
kwargs[v] = True | |||
outputs = [G.VarNode(output) for output in outputs] | |||
if args.optimize_for_inference: | |||
outputs = [i._node for i in G.optimize_for_inference(outputs, **kwargs)] | |||
@@ -84,7 +84,7 @@ def main(): | |||
minibatch = next(val_dataset) | |||
net.eval() | |||
_, loss = val_fun(data, label) | |||
loss = loss.numpy()[0] | |||
loss = loss.numpy() | |||
val_loss.append((step, loss)) | |||
print("Step: {} loss={}".format(step, loss)) | |||
opt.step() | |||