@@ -17,15 +17,31 @@ from ..ops.builtin import OpDef
from .core import OpBase, TensorBase, apply

class CompiledFunction:
    def __init__(self, graph, function):
        self._graph = graph
        self._function = function

class Graph(_imperative_rt.ComputingGraph):
    def __init__(self):
        super().__init__()
        self._var_cache = weakref.WeakKeyDictionary()
        self._op_cache = weakref.WeakKeyDictionary()
        self._executor = ThreadPoolExecutor(1)
        self._function = None
        self._future = None

    def _wrap(self, obj):
        if type(obj) is _imperative_rt.VarNode:
            wrapper, cache = VarNode, self._var_cache
        elif type(obj) is _imperative_rt.OperatorNode:
            wrapper, cache = OpNode, self._op_cache
        if obj not in cache:
            cache[obj] = wrapper(obj)
        return cache[obj]

    def compile(self, *args):
        self._function = super().compile(_unwrap(args))
        return self

    def execute(self, *args):
        assert self._future is None
        self._future = self._graph._executor.submit(self._function.execute, *args)
        self._future = self._executor.submit(self._function.execute, *args)

    def wait(self):
        assert self._future is not None

@@ -40,30 +56,23 @@ class CompiledFunction:
        self.execute(*args)
        return self.wait()

    def make_const(self, data, dtype=None, device=None):
        if isinstance(data, _imperative_rt.DeviceTensorND):
            assert dtype is None and device is None
            return self._wrap(_imperative_rt.make_shared(self, data))
        else:
            device = as_device(device).to_c()
            return self._wrap(_imperative_rt.make_const(self, data, device, dtype))

class Graph(_imperative_rt.ComputingGraph):
    def __init__(self):
        super().__init__()
        self._var_cache = weakref.WeakKeyDictionary()
        self._op_cache = weakref.WeakKeyDictionary()
        self._executor = ThreadPoolExecutor(1)

    def _wrap(self, obj):
        if type(obj) is _imperative_rt.VarNode:
            wrapper, cache = VarNode, self._var_cache
        elif type(obj) is _imperative_rt.OperatorNode:
            wrapper, cache = OpNode, self._op_cache
        if obj not in cache:
            cache[obj] = wrapper(obj)
        return cache[obj]

    def compile(self, *args):
        return CompiledFunction(self, super().compile(_unwrap(args)))

    def make_input(self, *args: "VarNode", device=None, dtype=None, shape=None):
        opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self)
        return opnode.outputs[0]

class VarNode(TensorBase):
    def __init__(self, node: _imperative_rt.VarNode):
        self._node = node
        self.graph._var_cache[node] = self

    @property
    def graph(self) -> Graph:

@@ -81,10 +90,15 @@ class VarNode(TensorBase):
    def device(self):
        return as_device(self._node.comp_node)

    @property
    def shape(self):
        return self._node.shape

class OpNode:
    def __init__(self, node: _imperative_rt.OperatorNode):
        self._node = node
        self.graph._op_cache[node] = self

    @property
    def graph(self) -> Graph:

@@ -117,21 +131,21 @@ def _(op: OpDef, *args: VarNode):
    return _wrap(outputs)

def input_callback(callback, *args, device=None, dtype=None, graph=None):
def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None):
    outputs = _imperative_rt.input_callback(
        callback, as_device(device).to_c(), dtype, _unwrap(args), graph=graph
        callback, as_device(device).to_c(), dtype, shape, _unwrap(args), graph=graph
    )
    value, dummy = _wrap(outputs)
    return value, dummy

class InputNode(OpNode):
    def __init__(self, *args: VarNode, device=None, dtype=None, graph=None):
    def __init__(self, *args: VarNode, device=None, dtype=None, shape=None, graph=None):
        r = _imperative_rt.DeviceTensorNDRendezvous()
        if device is not None:
            device = as_device(device).to_c()
        outputs = _imperative_rt.input_callback(
            r, device, dtype, _unwrap(args), graph=graph
            r, device, dtype, shape, _unwrap(args), graph=graph
        )
        super().__init__(outputs[0].owner)
        self._rendezvous = r

@@ -169,6 +183,29 @@ class OutputNode(OpNode):
    def get_value(self):
        return self._rendezvous.get()

    def drop_value(self):
        self._rendezvous.drop()

    def reset(self):
        self._rendezvous.reset()

class ValueOutputNode(OpNode):
    def __init__(self, var, *args):
        args = (var,) + args
        r = _imperative_rt.HostTensorNDRendezvous()
        dummy = _imperative_rt.value_output_callback(r, _unwrap(args))
        super().__init__(dummy.owner)
        self._rendezvous = r

    def get_value(self):
        hostnd, event = self._rendezvous.get()
        event.wait()
        return hostnd.numpy()

    def drop_value(self):
        self._rendezvous.drop()

    def reset(self):
        self._rendezvous.reset()

@@ -192,5 +229,8 @@ class AttrOutputNode(OpNode):
        attr = self._rendezvous.get()
        return TensorAttr(attr.shape, attr.dtype, as_device(attr.comp_node))

    def drop_value(self):
        self._rendezvous.drop()

    def reset(self):
        self._rendezvous.reset()
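With this change the Graph class owns compilation and asynchronous execution directly (the separate CompiledFunction wrapper goes away), and I/O is done through InputNode/OutputNode rendezvous objects. A rough usage sketch of the reworked API, mirroring how trace._compile() later in this patch drives it; the op and values are placeholders borrowed from the tests added below, and device="xpux" is an assumed generic device specifier:

    import numpy as np
    from megengine.core.ops import builtin as ops
    from megengine.core.tensor.core import apply
    from megengine.core.tensor.raw_tensor import as_raw_tensor

    g = Graph()
    inp = InputNode(device="xpux", dtype="float32", graph=g)   # rendezvous-backed input
    x = inp.outputs[0]
    (y,) = apply(ops.Elemwise(mode="negate"), x)               # build the graph symbolically
    out = OutputNode(y)
    g.compile(out.outputs[0])                                  # stores the function on the graph itself

    inp.set_value(as_raw_tensor(np.array([1.0], dtype="float32"))._dev_tensor())
    g()                                                        # execute() + wait(), as in __call__ above
    print(out.get_value().numpy())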
@@ -31,11 +31,13 @@ class RawTensor(TensorBase):
    _init_cb = None
    _del_cb = None
    _handle = None

    def __init__(self, handle):
    def __init__(self, handle=None):
        self._handle = handle
        if self._init_cb:
            self._init_cb()
        if handle is not None:
            if self._init_cb:
                self._init_cb()

    @property
    def dtype(self):

@@ -61,9 +63,10 @@ class RawTensor(TensorBase):
        )

    def __del__(self):
        if self._del_cb:
            self._del_cb()
        delete(self._handle)
        if self._handle is not None:
            if self._del_cb:
                self._del_cb()
            delete(self._handle)

@apply.register()

@@ -89,6 +92,11 @@ def as_raw_tensor(obj, dtype=None, device=None):
    return as_raw_tensor(obj, device=device)

@as_raw_tensor.register(DeviceTensorND)
def _(data: DeviceTensorND):
    return RawTensor(put(data))

@as_raw_tensor.register(np.ndarray)
def _(array: np.ndarray, dtype=None, device=None):
    device = None if device is None else as_device(device).to_c()

@@ -0,0 +1 @@
from .tracing import exclude_from_trace, trace
@@ -0,0 +1,514 @@
import contextlib
import functools
import typing
import weakref

from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.core import OpBase, apply
from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor

class TraceMismatchError(RuntimeError):
    pass

active_trace = None
skip_tracing = False

@contextlib.contextmanager
def exclude_from_trace():
    global skip_tracing
    if skip_tracing:
        yield
        return
    try:
        skip_tracing = True
        if active_trace is not None:
            active_trace._begin_excluded_region()
        yield
    finally:
        skip_tracing = False
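A minimal usage sketch of exclude_from_trace, mirroring the test_exclude_from_trace test added later in this patch: code under the context manager runs eagerly on every call instead of being replayed from the recorded op sequence.

    from megengine.core.ops import builtin as ops
    from megengine.core.tensor.core import apply
    from megengine.jit import exclude_from_trace, trace

    @trace(symbolic=False)
    def f(x, flag):
        neg = ops.Elemwise(mode="negate")
        (x,) = apply(neg, x)
        with exclude_from_trace():
            # this branch is evaluated eagerly and is not part of the trace
            if flag:
                (x,) = apply(neg, x)
        (x,) = apply(neg, x)
        return x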
class TensorInfo:
    __slots__ = (
        # collected attributes
        "external",
        "exported",
        "data_read",
        "shape_read",
        "value_read",
        "device",
        "dtype",
        "bound_data",
        # resources for execution
        "varnode",
        "data_setter",
        "shape_reader",
        "value_reader",
        "data_reader",
    )

    def __init__(self):
        self.exported = None
        self.data_read = None
        self.shape_read = None
        self.value_read = None
        self.bound_data = None

        self.data_setter = None
        self.shape_reader = None
        self.value_reader = None
        self.data_reader = None

class trace:
    def __new__(cls, *args, **kwargs):
        if not args:
            return functools.partial(cls, **kwargs)
        self = super().__new__(cls)
        self.__init__(*args, **kwargs)
        return self
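The __new__ override lets trace be used both as a bare decorator and as a decorator factory: when called with keyword arguments only, it returns a functools.partial that waits for the function. A short sketch of the two equivalent spellings (function bodies elided):

    @trace                      # no positional args at class call time? here the function itself is the arg
    def f(x): ...

    @trace(symbolic=True)       # kwargs only: __new__ returns functools.partial(trace, symbolic=True)
    def g(x): ...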
    def __init__(self, function, symbolic=False, capture_as_const=False):
        self.__wrapped__ = function
        self._symbolic = symbolic
        self._capture_as_const = capture_as_const
        self._capture_static_shape = False

        self._untraced = True
        self._tinfo = []  # handle -> TensorInfo
        self._seq = []
        self._pc = 0
        self._graph = None
        self._need_reset_nodes = None
        self._lazy_eval_graph = None
        self._lazy_eval_tensors = weakref.WeakSet()
        self._active_tensors = weakref.WeakSet()

    def _new_handle(self):
        handle = len(self._tinfo)
        info = TensorInfo()
        self._tinfo.append(info)
        return handle, info

    def _apply_op(self, op, args):
        assert not self._untraced
        # check against trace
        if self._pc >= len(self._seq):
            raise TraceMismatchError("trace should end here, but more op observed")
        record = self._seq[self._pc]
        op_, ihandles, ohandles = record
        if op != op_:
            raise TraceMismatchError("op different from last time")
        if len(ihandles) != len(args):
            raise TraceMismatchError("op input size different from last time")

        for h, x in zip(ihandles, args):
            info = self._tinfo[h]
            if info.external:
                if (
                    x.__class__ is CompiledTensorProxy
                    and not self._tinfo[x._CompiledTensorProxy__handle].exported
                ):
                    raise TraceMismatchError(
                        "failed to capture: input was an external tensor "
                        "last time, got an internal tensor this time"
                    )
                if info.bound_data:
                    if x.__class__ is CompiledTensorProxy:
                        raise TraceMismatchError(
                            "const capture violated: was an external tensor "
                            "last time, got an internal tensor this time"
                        )
                    if x._handle != info.bound_data._handle:
                        raise TraceMismatchError(
                            "const capture violated: got "
                            "a different tensor this time"
                        )
                else:
                    if info.dtype != x.dtype:
                        raise TraceMismatchError(
                            "failed to capture: different dtype from last time"
                        )
                    if info.device != x.device:
                        raise TraceMismatchError(
                            "failed to capture: different device from last time"
                        )
                    info.data_setter.set_value(x._dev_tensor())
            else:
                if x.__class__ is not CompiledTensorProxy:
                    raise TraceMismatchError(
                        "unexpected capture: trying to use an external tensor as input, "
                        "but that input was an internal tensor last time"
                    )
                if x._CompiledTensorProxy__handle != h:
                    raise TraceMismatchError(
                        "mis-wiring: input edge to an data flow "
                        "graph node is different from last time"
                    )

        self._pc += 1
        outputs = tuple([CompiledTensorProxy(h) for h in ohandles])
        self._active_tensors.update(outputs)
        return outputs

    def _record_op(self, op, inputs, outputs):
        if skip_tracing:
            for x in inputs:
                h = getattr(x, "_TraceMixin__handle", None)
                if h is not None:
                    self._tinfo[h].data_read = True
            return

        ihandles = []
        for x in inputs:
            h = getattr(x, "_TraceMixin__handle", None)
            if h is None or (not self._capture_as_const and self._tinfo[h].exported):
                h, info = self._new_handle()
                info.external = True
                info.device = x.device
                info.dtype = x.dtype
                if self._capture_as_const:
                    info.bound_data = x

            ihandles.append(h)

        ohandles = []
        for x in outputs:
            h, info = self._new_handle()
            ohandles.append(h)
            info.external = False
            TraceMixin._TraceMixin__inject(x, h)

        self._seq.append((op, tuple(ihandles), tuple(ohandles)))
        self._active_tensors.update(outputs)

    @contextlib.contextmanager
    def _setup(self):
        global active_trace
        if active_trace:
            raise NotImplementedError("sorry, not implemented: nested trace")
        active_trace = self

        if self._untraced:
            apply.enable(apply_with_tracing)
            if self._symbolic:
                apply.enable(apply_symbolic_mode)
                self._lazy_eval_graph = G.Graph()
        else:
            apply.enable(apply_compiled_mode)
            if self._graph is None:
                self._compile()
            self._graph.execute()

        yield

        escaped_tensors = tuple(self._active_tensors)
        self._active_tensors.clear()

        if self._untraced:
            for x in escaped_tensors:
                info = self._tinfo[x._TraceMixin__handle]
                info.data_read = True
                x._TraceMixin__restore()
            if self._symbolic:
                # eval lazy eval tensors
                lazy_eval_tensors = tuple(self._lazy_eval_tensors)
                if lazy_eval_tensors:
                    readers = [
                        G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
                        for x in lazy_eval_tensors
                    ]
                    self._lazy_eval_graph.compile(*readers)
                    self._lazy_eval_graph()
                    for r, x in zip(readers, lazy_eval_tensors):
                        assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
                self._lazy_eval_graph = None
                self._lazy_eval_tensors = None
            self._untraced = False
        else:
            if self._pc != len(self._seq):
                raise TraceMismatchError("premature end")
            for x in escaped_tensors:
                assign_raw_tensor(x, as_raw_tensor(x._dev_tensor()))
            self._graph.wait()
            self._reset_exec_env()
            self._pc = 0

        apply.disable(apply_with_tracing)
        apply.disable(apply_symbolic_mode)
        apply.disable(apply_compiled_mode)
        active_trace = None

    def _begin_excluded_region(self):
        if self._untraced:
            # conditionally reading a compiled tensor in excluded region
            # is permitted, so we have to assume every tensor might be read
            for x in self._active_tensors:
                info = self._tinfo[x._TraceMixin__handle]
                info.exported = True
                info.data_read = True

    def _compile(self):
        graph = self._graph = G.Graph()
        # graph.options.graph_opt_level = 0
        need_reset_nodes = self._need_reset_nodes = []
        # links enforce ordering of I/O nodes
        links = ()

        for op, ihandles, ohandles in self._seq:
            ivars = []
            readers = []

            for h in ihandles:
                info = self._tinfo[h]
                if not hasattr(info, "varnode"):
                    assert info.external
                    if info.bound_data:
                        info.varnode = graph.make_const(info.bound_data._dev_tensor())
                    else:
                        opnode = info.data_setter = G.InputNode(
                            *links, device=info.device, dtype=info.dtype, graph=graph
                        )
                        need_reset_nodes.append(opnode)
                        info.varnode, *links = opnode.outputs

                ivars.append(info.varnode)

            ovars = apply(op, *ivars)
            assert len(ovars) == len(ohandles)
            for h, v in zip(ohandles, ovars):
                info = self._tinfo[h]
                info.varnode = v

                def add_reader(opnode):
                    nonlocal links
                    need_reset_nodes.append(opnode)
                    readers.append(opnode.outputs[0])
                    links = opnode.outputs

                if info.data_read:
                    # Shape can be obtained from data so doesn't need its own
                    # output node. On the other hand, value is read separately
                    # to leverage eager h2d copy
                    info.shape_read = False
                    opnode = info.data_reader = G.OutputNode(v, *links)
                    add_reader(opnode)
                if info.value_read:
                    opnode = info.value_reader = G.ValueOutputNode(v, *links)
                    add_reader(opnode)
                if info.shape_read:
                    opnode = info.shape_reader = G.AttrOutputNode(v, *links)
                    add_reader(opnode)

        graph.compile(*readers)

    def _reset_exec_env(self):
        for opnode in self._need_reset_nodes:
            opnode.reset()

    def _require_shape(self, handle):
        info = self._tinfo[handle]
        info.shape_read = True

    def _require_value(self, handle):
        info = self._tinfo[handle]
        info.value_read = True

    def _require_data(self, handle):
        info = self._tinfo[handle]
        info.data_read = True

    def __call__(self, *args, **kwargs):
        with self._setup():
            return self.__wrapped__(*args, **kwargs)
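For reference, a minimal end-to-end sketch of the decorator defined above, following the same pattern as the new test_trace below: the first call records the op sequence, later calls replay the compiled graph.

    import numpy as np
    from megengine.core.ops import builtin as ops
    from megengine.core.tensor.core import apply
    from megengine.core.tensor.raw_tensor import as_raw_tensor
    from megengine.jit import trace

    @trace(symbolic=False)
    def f(x):
        (y,) = apply(ops.Elemwise(mode="negate"), x)
        return y

    x = np.array([1], dtype="int32")
    y = f.__wrapped__(as_raw_tensor(x)).numpy()     # eager run, used as the reference
    for _ in range(3):                              # first call traces, later calls replay
        np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)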
class CompiledTensorProxy(RawTensor):
    """
    Duck-typed RawTensor
    """

    def __init__(self, handle):
        self.__handle = handle
        self.__info = active_trace._tinfo[handle]
        self.__shape = None
        self.__data = None
        self.__value = None

    @property
    def dtype(self):
        return self.__info.varnode.dtype

    @property
    def device(self):
        return self.__info.varnode.device

    @property
    def shape(self):
        if self.__shape is None:
            if self.__info.shape_read:
                self.__shape = self.__info.shape_reader.get_value().shape
            elif self.__info.data_read:
                self.__shape = self._dev_tensor().shape
            else:
                raise TraceMismatchError("shape of this tensor is not read in trace")
        return self.__shape

    def numpy(self):
        if self.__value is None:
            if self.__info.value_read:
                self.__value = self.__info.value_reader.get_value()
            elif self.__info.data_read:
                self.__value = self._dev_tensor().numpy()
            else:
                raise TraceMismatchError("value of this tensor is not read in trace")
        return self.__value

    def _dev_tensor(self):
        if self.__data is None:
            if not self.__info.data_read:
                raise TraceMismatchError("raw data of this tensor is not read in trace")
            self.__data = self.__info.data_reader.get_value()
        return self.__data

    def __del__(self):
        if self.__info.shape_read and self.__shape is not None:
            self.__info.shape_reader.drop_value()
        if self.__info.value_read and self.__value is not None:
            self.__info.value_reader.drop_value()
        if self.__info.data_read and self.__data is not None:
            self.__info.data_reader.drop_value()

class LazyEvalTensor(RawTensor):
    def __init__(self, varnode):
        self.__varnode = varnode

    @property
    def dtype(self):
        return self.__varnode.dtype

    @property
    def device(self):
        return self.__varnode.device

    @property
    def shape(self):
        return self.__varnode.shape

    def numpy(self):
        raise RuntimeError("cannot read value during symbolic tracing")

    def _dev_tensor(self):
        raise RuntimeError("cannot access data during symbolic tracing")

class TraceMixin:
    __subclass_cache = {}

    def __inject(self, handle):
        cache = __class__.__subclass_cache
        cls = self.__class__
        subcls = cache.get(cls)
        if subcls is None:
            subcls = cache[cls] = type("Traced" + cls.__name__, (__class__, cls), {})
        self.__class__ = subcls
        self.__handle = handle
        self.__cls = cls
        return self

    def __restore(self):
        cls = self.__cls
        del self.__handle
        del self.__cls
        self.__class__ = cls
        return self

    @property
    def shape(self):
        if not skip_tracing:
            active_trace._require_shape(self.__handle)
        return super().shape

    def numpy(self):
        if not skip_tracing:
            active_trace._require_value(self.__handle)
        return super().numpy()

    def _dev_tensor(self):
        if not skip_tracing:
            active_trace._require_data(self.__handle)
        return super()._dev_tensor()

class TracedRawTensor(TraceMixin, RawTensor):
    pass

class TracedLazyTensor(TraceMixin, LazyEvalTensor):
    pass

def assign_raw_tensor(lhs, rhs):
    handle = rhs._handle
    rhs.__dict__.clear()
    lhs.__dict__.clear()
    lhs.__class__ = RawTensor
    lhs.__init__(handle)

# this hook turns RawTensor into LazyEvalTensor
@apply.register()
def apply_symbolic_mode(op: OpDef, *args: RawTensor):
    graph = active_trace._lazy_eval_graph
    ivars = [
        getattr(x, "_LazyEvalTensor__varnode", None)
        or graph.make_const(x._dev_tensor())
        for x in args
    ]
    ovars = apply(op, *ivars)
    outputs = [LazyEvalTensor(v) for v in ovars]
    active_trace._lazy_eval_tensors.update(outputs)
    return outputs

apply.disable(apply_symbolic_mode)

@apply.register()
def apply_compiled_mode(op: OpDef, *args: RawTensor):
    if skip_tracing:
        args = [
            as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
            for x in args
        ]
        return apply.super(op, *args)
    return active_trace._apply_op(op, args)

apply.disable(apply_compiled_mode)

# this hook injects TraceMixin
@apply.register()
def apply_with_tracing(op: OpDef, *args: RawTensor):
    outputs = apply.super(op, *args)
    active_trace._record_op(op, args, outputs)
    return outputs

apply.disable(apply_with_tracing)

# @apply.register()
# def _(op: Const, *args: RawTensor):
#     return active_trace._apply_const(op, args)

class BrokenRawTensor(RawTensor):
    def __getattribute__(self, _):
        raise RuntimeError("broken due to misuse of tracing")

    def __setattr__(self, *_):
        raise RuntimeError("broken due to misuse of tracing")
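The three hooks above share one dispatch pattern: each is registered once with apply.register(), immediately disabled, and then selectively re-enabled inside trace._setup() for the duration of a call. A minimal sketch of that pattern, assuming the same imports as this file; the hook name is illustrative, not part of the patch:

    @apply.register()
    def my_mode(op: OpDef, *args: RawTensor):
        # delegate to the next lower-priority implementation
        return apply.super(op, *args)

    apply.disable(my_mode)   # dormant until a trace explicitly enables it
    # ... later, inside trace._setup():  apply.enable(my_mode)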
@@ -23,10 +23,29 @@ namespace py = pybind11;
using namespace mgb;
using namespace imperative;

namespace {
template<typename XTensorND>
auto def_TensorND(py::object parent, const char* name) {
    return py::class_<XTensorND>(parent, name)
        .def_property_readonly("shape", py::overload_cast<>(&XTensorND::shape, py::const_))
        .def_property_readonly("dtype", py::overload_cast<>(&XTensorND::dtype, py::const_))
        .def_property_readonly("comp_node", py::overload_cast<>(&XTensorND::comp_node, py::const_))
        .def("copy_from", &XTensorND::template copy_from<DeviceTensorStorage>)
        .def("copy_from", &XTensorND::template copy_from<HostTensorStorage>)
        .def("copy_from_fixlayout", py::overload_cast<const DeviceTensorND&>(
            &XTensorND::template copy_from_fixlayout<DeviceTensorStorage>))
        .def("copy_from_fixlayout", py::overload_cast<const HostTensorND&>(
            &XTensorND::template copy_from_fixlayout<HostTensorStorage>));
}
} // namespace

void init_common(py::module m) {
    py::class_<CompNode>(m, "CompNode")
    auto&& PyCompNode = py::class_<CompNode>(m, "CompNode")
        .def(py::init())
        .def(py::init(py::overload_cast<const std::string&>(&CompNode::load)))
        .def("create_event", &CompNode::create_event, py::arg("flags") = 0ul)
        .def("__str__", &CompNode::to_string_logical)
        .def_static("_sync_all", &CompNode::sync_all)
        .def(py::self == py::self)

@@ -40,19 +59,30 @@ void init_common(py::module m) {
            return CompNode::load(cn);
        }));

    py::class_<CompNode::Event, std::shared_ptr<CompNode::Event>>(PyCompNode, "Event")
        .def("record", &CompNode::Event::record)
        .def("wait", &CompNode::Event::host_wait);

    py::implicitly_convertible<std::string, CompNode>();

    py::class_<DeviceTensorND>(m, "DeviceTensorND")
        .def(py::init())
        .def_property_readonly("shape", py::overload_cast<>(&DeviceTensorND::shape, py::const_))
        .def_property_readonly("dtype", py::overload_cast<>(&DeviceTensorND::dtype, py::const_))
        .def_property_readonly("comp_node", py::overload_cast<>(&DeviceTensorND::comp_node, py::const_))
    def_TensorND<DeviceTensorND>(m, "DeviceTensorND")
        .def("numpy", [](const DeviceTensorND& self) {
            HostTensorND hv;
            hv.copy_from(self).sync();
            return py::handle(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
        });

    def_TensorND<HostTensorND>(m, "HostTensorND")
        .def(py::init([](py::array data, CompNode cn, DType dtype) {
            if (!cn.valid()) {
                throw py::type_error("device must not be None");
            }
            return npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
        }))
        .def("numpy", [](const HostTensorND& self) {
            return py::reinterpret_steal<py::object>(npy::ndarray_from_tensor(self, npy::ShareType::TRY_SHARE));
        });

    py::class_<cg::OperatorNodeConfig>(m, "OperatorNodeConfig")
        .def(py::init())
        .def_property("name",

@@ -12,6 +12,7 @@
#include "./graph_rt.h"

#include "megbrain/imperative/opr_utility.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/imperative.h"
#include "./helper.h"

@@ -29,29 +30,44 @@ auto def_rendezvous(py::object m, const char* name) {
        .def(py::init([](){return std::make_shared<Rendezvous<T>>();}))
        .def("set", [](Rendezvous<T>& r, T v) {r.set(std::move(v));})
        .def("get", [](Rendezvous<T>& r) {return r.get();}, py::call_guard<py::gil_scoped_release>())
        .def("drop", &Rendezvous<T>::drop)
        .def("reset", &Rendezvous<T>::reset);
}

using TensorAttr = LogicalTensorDesc;
using HostNDWithEvent = std::pair<HostTensorND, std::shared_ptr<CompNode::Event>>;

void init_graph_rt(py::module m) {
    def_rendezvous<DeviceTensorND>(m, "DeviceTensorNDRendezvous");
    def_rendezvous<HostNDWithEvent>(m, "HostTensorNDRendezvous");
    def_rendezvous<TensorAttr>(m, "TensorAttrRendezvous");

    py::class_<cg::VarNode, GraphNodePtr<cg::VarNode>>(m, "VarNode")
        .def_property_readonly("owner", [](cg::VarNode* v) {return v->owner_opr();})
        .def_property_readonly("graph", [](cg::VarNode* v) {return v->owner_graph();})
        .def_property_readonly("name", py::overload_cast<>(&VarNode::name, py::const_))
        .def_property_readonly("dtype", [](cg::VarNode* v) {return v->dtype();})
        .def_property_readonly("comp_node", [](cg::VarNode* v) {return v->comp_node();});
        .def_property_readonly("comp_node", [](cg::VarNode* v) {return v->comp_node();})
        .def_property_readonly("shape", [](cg::VarNode* v) -> const TensorShape* {
            auto&& mgr = v->owner_graph()->static_infer_manager();
            auto&& type = mgr.get_infer_type(v);
            using InferType = cg::static_infer::InferType;
            if (!(type.shape & (InferType::CONST | InferType::RT_STATIC))) {
                return nullptr;
            }
            return mgr.infer_shape_fallible(v);
        });

    py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(m, "OperatorNode")
        .def_property_readonly("graph", [](cg::OperatorNodeBase* opr) {return opr->owner_graph();})
        .def_property_readonly("name", py::overload_cast<>(&cg::OperatorNodeBase::name, py::const_))
        .def_property_readonly("inputs", [](cg::OperatorNodeBase* opr) {
            return to_tuple(opr->input());
        })
        .def_property_readonly("outputs", [](cg::OperatorNodeBase* opr) {
            return to_tuple(opr->output());
            return to_tuple(opr->usable_output());
        });

    py::class_<cg::AsyncExecutable>(m, "AsyncExecutable")

@@ -117,7 +133,7 @@ void init_graph_rt(py::module m) {
    common.def("invoke_op", [](const OpDef& def, const std::vector<cg::VarNode*> inputs, cg::ComputingGraph* graph) {
        cg::VarNodeArray vinputs(inputs.begin(), inputs.end());
        auto opr = OpDef::apply_on_var_node(def, vinputs);
        auto outputs = opr->output();
        auto outputs = opr->usable_output();
        return to_tuple(outputs);
    },
    py::arg(), py::arg(), py::arg("graph") = py::none());

@@ -125,6 +141,7 @@ void init_graph_rt(py::module m) {
    auto input_callback = [](auto callback,
                             const CompNode& comp_node,
                             const DType& dtype,
                             const TensorShape& shape,
                             const std::vector<cg::VarNode*>& inputs,
                             cg::ComputingGraph* graph) {
        if (!graph) {

@@ -135,7 +152,7 @@ void init_graph_rt(py::module m) {
            sinputs.emplace_back(i);
        }
        static_assert(!std::is_reference<decltype(callback)>::value);
        auto soutputs = opr::InputCallback::make(*graph, std::move(callback), comp_node, dtype, sinputs);
        auto soutputs = opr::InputCallback::make(*graph, std::move(callback), comp_node, dtype, shape, sinputs);
        std::vector<VarNode*> outputs;
        outputs.reserve(soutputs.size());
        for (auto i : soutputs) {

@@ -144,26 +161,40 @@ void init_graph_rt(py::module m) {
        return outputs;
    };

    m.def("make_shared", [](cg::ComputingGraph* graph, const DeviceTensorND& data) {
        return opr::SharedDeviceTensor::make(*graph, std::make_shared<DeviceTensorND>(data)).node();
    });

    m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) {
        if (!cn.valid()) {
            throw py::type_error("device must not be None");
        }
        auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
        opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
    });

    m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback,
                                             const CompNode& comp_node,
                                             const DType& dtype,
                                             const TensorShape& shape,
                                             const std::vector<cg::VarNode*>& inputs,
                                             cg::ComputingGraph* graph) {
            return input_callback([f=std::move(callback)](){py::gil_scoped_acquire _; return f();}, comp_node, dtype, inputs, graph);
            return input_callback([f=std::move(callback)](){py::gil_scoped_acquire _; return f();}, comp_node, dtype, shape, inputs, graph);
        },
        py::arg(), py::arg(), py::arg(), py::arg() = py::tuple(), py::arg("graph") = py::none());
        py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none());

    m.def("input_callback", [input_callback](std::shared_ptr<Rendezvous<DeviceTensorND>> p,
                                             const CompNode& comp_node,
                                             const DType& dtype,
                                             const TensorShape& shape,
                                             const std::vector<cg::VarNode*>& inputs,
                                             cg::ComputingGraph* graph) {
            auto f = [p]() -> DeviceTensorND {
                return p->get();
            };
            return input_callback(std::move(f), comp_node, dtype, inputs, graph);
            return input_callback(std::move(f), comp_node, dtype, shape, inputs, graph);
        },
        py::arg(), py::arg(), py::arg(), py::arg() = py::tuple(), py::arg("graph") = py::none());
        py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none());

    auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs, bool borrow = false) {
        SymbolVarArray sinputs;

@@ -193,6 +224,17 @@ void init_graph_rt(py::module m) {
        return output_callback(std::move(f), std::move(inputs));
    });

    m.def("value_output_callback", [output_callback](std::shared_ptr<Rendezvous<HostNDWithEvent>> p, std::vector<cg::VarNode*> inputs) {
        auto f = [p](DeviceTensorND dv) {
            HostNDWithEvent hv_with_event;
            hv_with_event.first.copy_from(dv);
            hv_with_event.second = dv.comp_node().create_event();
            hv_with_event.second->record();
            p->set(std::move(hv_with_event));
        };
        return output_callback(std::move(f), std::move(inputs), true);
    });

    m.def("attr_output_callback", [output_callback](std::shared_ptr<Rendezvous<TensorAttr>> p, std::vector<cg::VarNode*> inputs) {
        auto f = [p](DeviceTensorND dv) {
            p->set(TensorAttr{TensorLayout{dv.shape(), dv.dtype()}, dv.comp_node()});

@@ -39,6 +39,7 @@ template<typename R>
class Rendezvous {
    std::mutex m_lock;
    int m_read_ahead = 0;
    bool m_drop_next = false;
    std::promise<R> m_promise;
public:
    Rendezvous() = default;

@@ -47,6 +48,7 @@ public:
    Rendezvous& operator=(const Rendezvous& rhs) = delete;
    Rendezvous& operator=(Rendezvous&& rhs) {
        MGB_LOCK_GUARD(m_lock);
        m_drop_next = rhs.m_drop_next;
        m_read_ahead = rhs.m_read_ahead;
        m_promise = std::move(rhs.m_promise);
        return *this;

@@ -67,12 +69,28 @@ public:
        return f.get();
    }

    void drop() {
        MGB_LOCK_GUARD(m_lock);
        mgb_assert(m_read_ahead <= 0);
        mgb_assert(m_read_ahead >= -1);
        if (m_read_ahead == -1) {
            m_promise = {};
        } else {
            m_drop_next = true;
        }
        ++m_read_ahead;
    }

    template<typename T>
    void set(T&& value) {
        MGB_LOCK_GUARD(m_lock);
        mgb_assert(m_read_ahead >= 0);
        mgb_assert(m_read_ahead <= 1);
        m_promise.set_value(std::forward<T>(value));
        if (m_drop_next) {
            m_drop_next = false;
        } else {
            m_promise.set_value(std::forward<T>(value));
        }
        if (m_read_ahead == 1) {
            m_promise = {};
        }

@@ -83,6 +101,7 @@ public:
        MGB_LOCK_GUARD(m_lock);
        m_promise = {};
        m_read_ahead = 0;
        m_drop_next = false;
    }
};
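The new drop()/m_drop_next machinery backs the drop_value() methods added on the Python output nodes earlier in this patch: a consumer that will not read a result can discard it (whether it was already produced or is about to be produced), leaving the promise clean for the next run. A rough, illustrative sketch of the protocol from the Python side:

    r = _imperative_rt.HostTensorNDRendezvous()
    # producer thread:                r.set(value)    -> fulfils the pending promise
    # consumer that wants the value:  v = r.get()     -> blocks until set()
    # consumer that will not read:    r.drop()        -> the already-set or the next set() value is discarded
    r.reset()                                         # clear all state between executions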
@@ -280,9 +280,12 @@ namespace detail {
public:
    bool load(handle src, bool convert) {
        auto obj = reinterpret_steal<object>(src);
        if (!isinstance<tuple>(obj)) {
        if (!convert && !isinstance<tuple>(obj)) {
            return false;
        }
        if (obj.is_none()) {
            return true;
        }
        value.ndim = len(obj);
        mgb_assert(value.ndim <= mgb::TensorShape::MAX_NDIM);
        size_t i = 0;

@@ -63,6 +63,7 @@ void init_imperative_rt(py::module m) {
                return self.put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype));
            }
        }, py::arg(), py::arg("dtype") = py::none(), py::arg("device") = py::none())
        .def("put", py::overload_cast<const DeviceTensorND&>(&Interpreter::Channel::put))
        .def("delete", [](Interpreter::Channel& self, Interpreter::Handle handle) {
            return self.del(handle);
        })

@@ -24,6 +24,12 @@ constexpr bool has_fastcall = true;
constexpr bool has_fastcall = false;
#endif

#ifdef _Py_TPFLAGS_HAVE_VECTORCALL
constexpr bool has_vectorcall = true;
#else
constexpr bool has_vectorcall = false;
#endif

template<typename... Args>
struct invocable_with {
    template<typename T>

@@ -55,6 +61,9 @@ private:
public:
    PyObject_HEAD
    std::aligned_storage_t<sizeof(T), alignof(T)> storage;
#ifdef _Py_TPFLAGS_HAVE_VECTORCALL
    PyObject* vectorcall_slot;
#endif

    inline T* inst() {
        return reinterpret_cast<T*>(&storage);

@@ -155,6 +164,51 @@ private:
    // polyfills

    struct tp_vectorcall {
        static constexpr bool valid = HAS_MEMBER(T, tp_vectorcall);
        static constexpr bool haskw = [](){if constexpr (valid)
                                               if constexpr (std::is_invocable_v<T::tp_vectorcall, T, PyObject*const*, size_t, PyObject*>)
                                                   return true;
                                           return false;}();

        template<typename = void>
        static PyObject* impl(PyObject* self, PyObject*const* args, size_t nargsf, PyObject *kwnames) {
            auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
            if constexpr (haskw) {
                CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf, kwnames));
            } else {
                if (kwnames && PyTuple_GET_SIZE(kwnames)) {
                    PyErr_SetString(PyExc_TypeError, "expect no keyword argument");
                    return nullptr;
                }
                CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf));
            }
        }

        static constexpr Py_ssize_t offset = []() {if constexpr (valid) return offsetof(wrap_t, vectorcall_slot);
                                                   else return 0;}();
    };

    struct tp_call {
        static constexpr bool provided = HAS_MEMBER(T, tp_call);
        static constexpr bool static_form = invocable_with<T, PyObject*, PyObject*, PyObject*>{}(
            [](auto&& t, auto... args) -> decltype(std::decay_t<decltype(t)>::tp_call(args...)) {});
        static constexpr bool valid = provided || tp_vectorcall::valid;

        template<typename = void>
        static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
            auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
            CVT_RET_PYOBJ(inst->tp_call(args, kwargs));
        }

        static constexpr ternaryfunc value = []() {if constexpr (static_form) return T::tp_call;
                                                   else if constexpr (provided) return impl<>;
#ifdef _Py_TPFLAGS_HAVE_VECTORCALL
                                                   else if constexpr (valid) return PyVectorcall_Call;
#endif
                                                   else return nullptr;}();
    };

    struct tp_new {
        static constexpr bool provided = HAS_MEMBER(T, tp_new);
        static constexpr bool varkw = std::is_constructible_v<T, PyObject*, PyObject*>;

@@ -163,11 +217,14 @@ private:
        template<typename = void>
        static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
            auto* self = type->tp_alloc(type, 0);
            auto* ptr = reinterpret_cast<wrap_t*>(self)->inst();
            auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
            if constexpr (has_vectorcall && tp_vectorcall::valid) {
                reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>;
            }
            if constexpr (varkw) {
                new(ptr) T(args, kwargs);
                new(inst) T(args, kwargs);
            } else {
                new(ptr) T();
                new(inst) T();
            }
            return self;
        }

@@ -190,22 +247,6 @@ private:
                                                   else return impl<>;}();
    };

    struct tp_call {
        static constexpr bool valid = HAS_MEMBER(T, tp_call);
        static constexpr bool static_form = invocable_with<T, PyObject*, PyObject*, PyObject*>{}(
            [](auto&& t, auto... args) -> decltype(std::decay_t<decltype(t)>::tp_call(args...)) {});

        template<typename = void>
        static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
            auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
            CVT_RET_PYOBJ(inst->tp_call(args, kwargs));
        }

        static constexpr ternaryfunc value = []() {if constexpr (static_form) return T::tp_call;
                                                   else if constexpr (valid) return impl<>;
                                                   else return nullptr;}();
    };

public:
    class TypeBuilder {
        std::vector<PyMethodDef> m_methods;

@@ -228,9 +269,17 @@ public:
                m_type.tp_name = T::tp_name;
            }
            m_type.tp_dealloc = tp_dealloc::value;
#ifdef _Py_TPFLAGS_HAVE_VECTORCALL
            m_type.tp_vectorcall_offset = tp_vectorcall::offset;
#endif
            m_type.tp_call = tp_call::value;
            m_type.tp_basicsize = sizeof(wrap_t);
            m_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
#ifdef _Py_TPFLAGS_HAVE_VECTORCALL
            if constexpr (tp_vectorcall::valid) {
                m_type.tp_flags |= _Py_TPFLAGS_HAVE_VECTORCALL;
            }
#endif
            m_type.tp_new = tp_new::value;
        }
@@ -0,0 +1,65 @@
import numpy as np

from megengine.core.ops import builtin as ops
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.jit import exclude_from_trace, trace

def test_trace():
    for symbolic in [False, True]:

        @trace(symbolic=symbolic)
        def f(x):
            op = ops.Elemwise(mode="negate")
            (y,) = apply(op, x)
            return y

        x = as_raw_tensor([1]).numpy()
        y = f.__wrapped__(as_raw_tensor(x)).numpy()

        for i in range(3):
            np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)

def test_exclude_from_trace():
    for symbolic in [False, True]:

        @trace(symbolic=symbolic)
        def f(x):
            neg = ops.Elemwise(mode="negate")
            (x,) = apply(neg, x)
            with exclude_from_trace():
                if i % 2:
                    (x,) = apply(neg, x)
            (x,) = apply(neg, x)
            return x

        x = as_raw_tensor([1]).numpy()

        for i in range(3):
            y = f.__wrapped__(as_raw_tensor(x)).numpy()
            np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)

def test_print_in_trace():
    for symbolic in [False]:  # cannot read value in symbolic mode

        @trace(symbolic=symbolic)
        def f(x):
            nonlocal buf
            neg = ops.Elemwise(mode="negate")
            (x,) = apply(neg, x)
            buf = x.numpy()
            (x,) = apply(neg, x)
            return x

        buf = None
        x = as_raw_tensor([1]).numpy()
        for i in range(3):
            y = f.__wrapped__(as_raw_tensor(x)).numpy()
            z = buf
            buf = None
            np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
            np.testing.assert_equal(z, buf)
@@ -37,6 +37,15 @@ void* ChannelImpl::put(const HostTensorND& value) {
    return info;
}

void* ChannelImpl::put(const DeviceTensorND& data) {
    auto info = alloc();
    info->desc.layout = data.layout();
    info->desc.comp_node = data.comp_node();
    info->ptr = Tensor::make(data);
    m_valid_handle.insert(info);
    return info;
}

void ChannelImpl::del(void* handle) {
    mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle);
    m_worker.add_task(Del{reinterpret_cast<TensorInfo*>(handle)});

@@ -55,6 +55,7 @@ struct ChannelImpl : Interpreter::Channel {
    ~ChannelImpl() override;

    Handle put(const HostTensorND& value) override;
    Handle put(const DeviceTensorND& value) override;

    void del(Handle) override;

@@ -31,9 +31,10 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(InputCallback);

InputCallback::InputCallback(cg::ComputingGraph& graph, callback_t callback,
                             const VarNodeArray& inputs,
                             const TensorShape& output_shape,
                             const OperatorNodeConfig& config)
        : Super(&graph, config, "input_callback", inputs),
          m_callback(callback) {
          m_output_shape(output_shape), m_callback(callback) {
    for (VarNode* i : inputs) {
        add_input({i});
    }

@@ -48,7 +49,8 @@ InputCallback::InputCallback(cg::ComputingGraph& graph, callback_t callback,

SymbolVarArray InputCallback::make(cg::ComputingGraph& graph,
                                   callback_t callback, CompNode comp_node,
                                   DType dtype, const SymbolVarArray& inputs) {
                                   DType dtype, const TensorShape& shape,
                                   const SymbolVarArray& inputs) {
    mgb_assert(comp_node.valid());
    mgb_assert(dtype.valid());
    OperatorNodeConfig config;

@@ -56,11 +58,22 @@ SymbolVarArray InputCallback::make(cg::ComputingGraph& graph,
    config.output_dtype(dtype);
    auto vinputs = to_var_node_array(inputs);
    auto opr = graph.insert_opr(
        std::make_unique<InputCallback>(graph, callback, vinputs, config));
        std::make_unique<InputCallback>(graph, callback, vinputs, shape, config));
    return to_symbol_var_array(opr->output());
}

void InputCallback::init_output_static_infer_desc() {}
void InputCallback::init_output_static_infer_desc() {
    if (m_output_shape.ndim) {
        using namespace cg::static_infer;
        auto &&mgr = owner_graph()->static_infer_manager();
        auto infer_shape = [this](TensorShape &dest, const InpVal &) {
            dest = m_output_shape;
            return true;
        };
        mgr.register_shape_infer(output(0),
                                 {SourceType::CONSTANT, {}, infer_shape});
    }
}

cg::OperatorNodeBase::NodeProp* InputCallback::do_make_node_prop() const {
    NodeProp* prop = Super::do_make_node_prop();

@@ -73,9 +86,23 @@ cg::OperatorNodeBase::NodeProp* InputCallback::do_make_node_prop() const {

void InputCallback::scn_do_execute() {
    auto dev_tensor = m_callback();
    if (m_output_shape.ndim) {
        mgb_assert(dev_tensor.shape().eq_shape(m_output_shape));
    }
    output(0)->reset_dev_tensor_from_tensor(dev_tensor);
}

cg::OperatorNodeBase* InputCallback::shallow_copy(
        const serialization::OprShallowCopyContext &ctx,
        const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
        const OperatorNodeConfig &config) {
    auto &&opr = opr_.cast_final_safe<InputCallback>();
    auto* graph = ctx.owner_graph(opr, inputs);
    return graph->insert_opr(std::make_unique<InputCallback>(*graph, opr.m_callback, inputs, opr.m_output_shape, config));
}

MGB_REG_OPR_SHALLOW_COPY(InputCallback, InputCallback::shallow_copy);

/* ================ OutputCallback ================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(OutputCallback);

@@ -122,6 +149,17 @@ void OutputCallback::scn_do_execute() {
    m_param.callback(input(0)->dev_tensor());
}

cg::OperatorNodeBase* OutputCallback::shallow_copy(
        const serialization::OprShallowCopyContext &ctx,
        const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
        const OperatorNodeConfig &config) {
    auto &&opr = opr_.cast_final_safe<OutputCallback>();
    auto* graph = ctx.owner_graph(opr, inputs);
    return graph->insert_opr(std::make_unique<OutputCallback>(opr.m_param, inputs, config));
}

MGB_REG_OPR_SHALLOW_COPY(OutputCallback, OutputCallback::shallow_copy);

/* ================ NopCallback ================== */

MGB_DYN_TYPE_OBJ_FINAL_IMPL(NopCallback);

@@ -22,6 +22,7 @@ struct Interpreter {
    virtual ~Channel() = default;

    virtual Handle put(const HostTensorND& value) = 0;
    virtual Handle put(const DeviceTensorND& value) = 0;

    virtual void del(Handle) = 0;

@@ -17,6 +17,7 @@
#include "megbrain/opr/internal/param_tag_defs.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/param_defs.h"
#include "megbrain/serialization/sereg.h"

#include "megdnn/oprs/utils.h"

@@ -33,17 +34,24 @@ public:
    InputCallback(cg::ComputingGraph& graph,
                  callback_t callback,
                  const VarNodeArray& inputs,
                  const TensorShape& output_shape,
                  const OperatorNodeConfig &config);
    static SymbolVarArray make(cg::ComputingGraph& graph,
                               callback_t callback,
                               CompNode comp_node,
                               DType dtype,
                               const TensorShape& shape,
                               const SymbolVarArray& inputs = {});
    static cg::OperatorNodeBase* shallow_copy(
            const serialization::OprShallowCopyContext &ctx,
            const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
            const OperatorNodeConfig &config);
protected:
    void scn_do_execute() override;
    void init_output_static_infer_desc() override;
    NodeProp* do_make_node_prop() const override;
private:
    TensorShape m_output_shape;
    callback_t m_callback;
};

@@ -63,6 +71,10 @@ public:
                           SymbolVar input) {
        return make(std::move(param), SymbolVarArray{input});
    }
    static cg::OperatorNodeBase* shallow_copy(
            const serialization::OprShallowCopyContext &ctx,
            const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
            const OperatorNodeConfig &config);
protected:
    void scn_do_execute() override;
    void init_output_static_infer_desc() override;