@@ -17,15 +17,31 @@ from ..ops.builtin import OpDef | |||
from .core import OpBase, TensorBase, apply | |||
class CompiledFunction: | |||
def __init__(self, graph, function): | |||
self._graph = graph | |||
self._function = function | |||
class Graph(_imperative_rt.ComputingGraph): | |||
def __init__(self): | |||
super().__init__() | |||
self._var_cache = weakref.WeakKeyDictionary() | |||
self._op_cache = weakref.WeakKeyDictionary() | |||
self._executor = ThreadPoolExecutor(1) | |||
self._function = None | |||
self._future = None | |||
def _wrap(self, obj): | |||
if type(obj) is _imperative_rt.VarNode: | |||
wrapper, cache = VarNode, self._var_cache | |||
elif type(obj) is _imperative_rt.OperatorNode: | |||
wrapper, cache = OpNode, self._op_cache | |||
if obj not in cache: | |||
cache[obj] = wrapper(obj) | |||
return cache[obj] | |||
def compile(self, *args): | |||
self._function = super().compile(_unwrap(args)) | |||
return self | |||
def execute(self, *args): | |||
assert self._future is None | |||
self._future = self._graph._executor.submit(self._function.execute, *args) | |||
self._future = self._executor.submit(self._function.execute, *args) | |||
def wait(self): | |||
assert self._future is not None | |||
@@ -40,30 +56,23 @@ class CompiledFunction: | |||
self.execute(*args) | |||
return self.wait() | |||
def make_const(self, data, dtype=None, device=None): | |||
if isinstance(data, _imperative_rt.DeviceTensorND): | |||
assert dtype is None and device is None | |||
return self._wrap(_imperative_rt.make_shared(self, data)) | |||
else: | |||
device = as_device(device).to_c() | |||
return self._wrap(_imperative_rt.make_const(self, data, device, dtype)) | |||
class Graph(_imperative_rt.ComputingGraph): | |||
def __init__(self): | |||
super().__init__() | |||
self._var_cache = weakref.WeakKeyDictionary() | |||
self._op_cache = weakref.WeakKeyDictionary() | |||
self._executor = ThreadPoolExecutor(1) | |||
def _wrap(self, obj): | |||
if type(obj) is _imperative_rt.VarNode: | |||
wrapper, cache = VarNode, self._var_cache | |||
elif type(obj) is _imperative_rt.OperatorNode: | |||
wrapper, cache = OpNode, self._op_cache | |||
if obj not in cache: | |||
cache[obj] = wrapper(obj) | |||
return cache[obj] | |||
def compile(self, *args): | |||
return CompiledFunction(self, super().compile(_unwrap(args))) | |||
def make_input(self, *args: "VarNode", device=None, dtype=None, shape=None): | |||
opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self) | |||
return opnode.outputs[0] | |||
class VarNode(TensorBase): | |||
def __init__(self, node: _imperative_rt.VarNode): | |||
self._node = node | |||
self.graph._var_cache[node] = self | |||
@property | |||
def graph(self) -> Graph: | |||
@@ -81,10 +90,15 @@ class VarNode(TensorBase): | |||
def device(self): | |||
return as_device(self._node.comp_node) | |||
@property | |||
def shape(self): | |||
return self._node.shape | |||
class OpNode: | |||
def __init__(self, node: _imperative_rt.OperatorNode): | |||
self._node = node | |||
self.graph._op_cache[node] = self | |||
@property | |||
def graph(self) -> Graph: | |||
@@ -117,21 +131,21 @@ def _(op: OpDef, *args: VarNode): | |||
return _wrap(outputs) | |||
def input_callback(callback, *args, device=None, dtype=None, graph=None): | |||
def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None): | |||
outputs = _imperative_rt.input_callback( | |||
callback, as_device(device).to_c(), dtype, _unwrap(args), graph=graph | |||
callback, as_device(device).to_c(), dtype, shape, _unwrap(args), graph=graph | |||
) | |||
value, dummy = _wrap(outputs) | |||
return value, dummy | |||
class InputNode(OpNode): | |||
def __init__(self, *args: VarNode, device=None, dtype=None, graph=None): | |||
def __init__(self, *args: VarNode, device=None, dtype=None, shape=None, graph=None): | |||
r = _imperative_rt.DeviceTensorNDRendezvous() | |||
if device is not None: | |||
device = as_device(device).to_c() | |||
outputs = _imperative_rt.input_callback( | |||
r, device, dtype, _unwrap(args), graph=graph | |||
r, device, dtype, shape, _unwrap(args), graph=graph | |||
) | |||
super().__init__(outputs[0].owner) | |||
self._rendezvous = r | |||
@@ -169,6 +183,29 @@ class OutputNode(OpNode): | |||
def get_value(self): | |||
return self._rendezvous.get() | |||
def drop_value(self): | |||
self._rendezvous.drop() | |||
def reset(self): | |||
self._rendezvous.reset() | |||
class ValueOutputNode(OpNode): | |||
def __init__(self, var, *args): | |||
args = (var,) + args | |||
r = _imperative_rt.HostTensorNDRendezvous() | |||
dummy = _imperative_rt.value_output_callback(r, _unwrap(args)) | |||
super().__init__(dummy.owner) | |||
self._rendezvous = r | |||
def get_value(self): | |||
hostnd, event = self._rendezvous.get() | |||
event.wait() | |||
return hostnd.numpy() | |||
def drop_value(self): | |||
self._rendezvous.drop() | |||
def reset(self): | |||
self._rendezvous.reset() | |||
@@ -192,5 +229,8 @@ class AttrOutputNode(OpNode): | |||
attr = self._rendezvous.get() | |||
return TensorAttr(attr.shape, attr.dtype, as_device(attr.comp_node)) | |||
def drop_value(self): | |||
self._rendezvous.drop() | |||
def reset(self): | |||
self._rendezvous.reset() |
@@ -31,11 +31,13 @@ class RawTensor(TensorBase): | |||
_init_cb = None | |||
_del_cb = None | |||
_handle = None | |||
def __init__(self, handle): | |||
def __init__(self, handle=None): | |||
self._handle = handle | |||
if self._init_cb: | |||
self._init_cb() | |||
if handle is not None: | |||
if self._init_cb: | |||
self._init_cb() | |||
@property | |||
def dtype(self): | |||
@@ -61,9 +63,10 @@ class RawTensor(TensorBase): | |||
) | |||
def __del__(self): | |||
if self._del_cb: | |||
self._del_cb() | |||
delete(self._handle) | |||
if self._handle is not None: | |||
if self._del_cb: | |||
self._del_cb() | |||
delete(self._handle) | |||
@apply.register() | |||
@@ -89,6 +92,11 @@ def as_raw_tensor(obj, dtype=None, device=None): | |||
return as_raw_tensor(obj, device=device) | |||
@as_raw_tensor.register(DeviceTensorND) | |||
def _(data: DeviceTensorND): | |||
return RawTensor(put(data)) | |||
@as_raw_tensor.register(np.ndarray) | |||
def _(array: np.ndarray, dtype=None, device=None): | |||
device = None if device is None else as_device(device).to_c() | |||
@@ -0,0 +1 @@ | |||
from .tracing import exclude_from_trace, trace |
@@ -0,0 +1,514 @@ | |||
import contextlib | |||
import functools | |||
import typing | |||
import weakref | |||
from ..core.ops.special import Const | |||
from ..core.tensor import megbrain_graph as G | |||
from ..core.tensor.core import OpBase, apply | |||
from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor | |||
class TraceMismatchError(RuntimeError): | |||
pass | |||
active_trace = None | |||
skip_tracing = False | |||
@contextlib.contextmanager | |||
def exclude_from_trace(): | |||
global skip_tracing | |||
if skip_tracing: | |||
yield | |||
return | |||
try: | |||
skip_tracing = True | |||
if active_trace is not None: | |||
active_trace._begin_excluded_region() | |||
yield | |||
finally: | |||
skip_tracing = False | |||
class TensorInfo: | |||
__slots__ = ( | |||
# collected attributes | |||
"external", | |||
"exported", | |||
"data_read", | |||
"shape_read", | |||
"value_read", | |||
"device", | |||
"dtype", | |||
"bound_data", | |||
# resources for execution | |||
"varnode", | |||
"data_setter", | |||
"shape_reader", | |||
"value_reader", | |||
"data_reader", | |||
) | |||
def __init__(self): | |||
self.exported = None | |||
self.data_read = None | |||
self.shape_read = None | |||
self.value_read = None | |||
self.bound_data = None | |||
self.data_setter = None | |||
self.shape_reader = None | |||
self.value_reader = None | |||
self.data_reader = None | |||
class trace: | |||
def __new__(cls, *args, **kwargs): | |||
if not args: | |||
return functools.partial(cls, **kwargs) | |||
self = super().__new__(cls) | |||
self.__init__(*args, **kwargs) | |||
return self | |||
def __init__(self, function, symbolic=False, capture_as_const=False): | |||
self.__wrapped__ = function | |||
self._symbolic = symbolic | |||
self._capture_as_const = capture_as_const | |||
self._capture_static_shape = False | |||
self._untraced = True | |||
self._tinfo = [] # handle -> TensorInfo | |||
self._seq = [] | |||
self._pc = 0 | |||
self._graph = None | |||
self._need_reset_nodes = None | |||
self._lazy_eval_graph = None | |||
self._lazy_eval_tensors = weakref.WeakSet() | |||
self._active_tensors = weakref.WeakSet() | |||
def _new_handle(self): | |||
handle = len(self._tinfo) | |||
info = TensorInfo() | |||
self._tinfo.append(info) | |||
return handle, info | |||
def _apply_op(self, op, args): | |||
assert not self._untraced | |||
# check against trace | |||
if self._pc >= len(self._seq): | |||
raise TraceMismatchError("trace should end here, but more op observed") | |||
record = self._seq[self._pc] | |||
op_, ihandles, ohandles = record | |||
if op != op_: | |||
raise TraceMismatchError("op different from last time") | |||
if len(ihandles) != len(args): | |||
raise TraceMismatchError("op input size different from last time") | |||
for h, x in zip(ihandles, args): | |||
info = self._tinfo[h] | |||
if info.external: | |||
if ( | |||
x.__class__ is CompiledTensorProxy | |||
and not self._tinfo[x._CompiledTensorProxy__handle].exported | |||
): | |||
raise TraceMismatchError( | |||
"failed to capture: input was an external tensor " | |||
"last time, got an internal tensor this time" | |||
) | |||
if info.bound_data: | |||
if x.__class__ is CompiledTensorProxy: | |||
raise TraceMismatchError( | |||
"const capture violated: was an external tensor " | |||
"last time, got an internal tensor this time" | |||
) | |||
if x._handle != info.bound_data._handle: | |||
raise TraceMismatchError( | |||
"const capture violated: got " | |||
"a different tensor this time" | |||
) | |||
else: | |||
if info.dtype != x.dtype: | |||
raise TraceMismatchError( | |||
"failed to capture: different dtype from last time" | |||
) | |||
if info.device != x.device: | |||
raise TraceMismatchError( | |||
"failed to capture: different device from last time" | |||
) | |||
info.data_setter.set_value(x._dev_tensor()) | |||
else: | |||
if x.__class__ is not CompiledTensorProxy: | |||
raise TraceMismatchError( | |||
"unexpected capture: trying to use an external tensor as input, " | |||
"but that input was an internal tensor last time" | |||
) | |||
if x._CompiledTensorProxy__handle != h: | |||
raise TraceMismatchError( | |||
"mis-wiring: input edge to an data flow " | |||
"graph node is different from last time" | |||
) | |||
self._pc += 1 | |||
outputs = tuple([CompiledTensorProxy(h) for h in ohandles]) | |||
self._active_tensors.update(outputs) | |||
return outputs | |||
def _record_op(self, op, inputs, outputs): | |||
if skip_tracing: | |||
for x in inputs: | |||
h = getattr(x, "_TraceMixin__handle", None) | |||
if h is not None: | |||
self._tinfo[h].data_read = True | |||
return | |||
ihandles = [] | |||
for x in inputs: | |||
h = getattr(x, "_TraceMixin__handle", None) | |||
if h is None or (not self._capture_as_const and self._tinfo[h].exported): | |||
h, info = self._new_handle() | |||
info.external = True | |||
info.device = x.device | |||
info.dtype = x.dtype | |||
if self._capture_as_const: | |||
info.bound_data = x | |||
ihandles.append(h) | |||
ohandles = [] | |||
for x in outputs: | |||
h, info = self._new_handle() | |||
ohandles.append(h) | |||
info.external = False | |||
TraceMixin._TraceMixin__inject(x, h) | |||
self._seq.append((op, tuple(ihandles), tuple(ohandles))) | |||
self._active_tensors.update(outputs) | |||
@contextlib.contextmanager | |||
def _setup(self): | |||
global active_trace | |||
if active_trace: | |||
raise NotImplementedError("sorry, not implemented: nested trace") | |||
active_trace = self | |||
if self._untraced: | |||
apply.enable(apply_with_tracing) | |||
if self._symbolic: | |||
apply.enable(apply_symbolic_mode) | |||
self._lazy_eval_graph = G.Graph() | |||
else: | |||
apply.enable(apply_compiled_mode) | |||
if self._graph is None: | |||
self._compile() | |||
self._graph.execute() | |||
yield | |||
escaped_tensors = tuple(self._active_tensors) | |||
self._active_tensors.clear() | |||
if self._untraced: | |||
for x in escaped_tensors: | |||
info = self._tinfo[x._TraceMixin__handle] | |||
info.data_read = True | |||
x._TraceMixin__restore() | |||
if self._symbolic: | |||
# eval lazy eval tensors | |||
lazy_eval_tensors = tuple(self._lazy_eval_tensors) | |||
if lazy_eval_tensors: | |||
readers = [ | |||
G.OutputNode(x._LazyEvalTensor__varnode).outputs[0] | |||
for x in lazy_eval_tensors | |||
] | |||
self._lazy_eval_graph.compile(*readers) | |||
self._lazy_eval_graph() | |||
for r, x in zip(readers, lazy_eval_tensors): | |||
assign_raw_tensor(x, as_raw_tensor(r.op.get_value())) | |||
self._lazy_eval_graph = None | |||
self._lazy_eval_tensors = None | |||
self._untraced = False | |||
else: | |||
if self._pc != len(self._seq): | |||
raise TraceMismatchError("premature end") | |||
for x in escaped_tensors: | |||
assign_raw_tensor(x, as_raw_tensor(x._dev_tensor())) | |||
self._graph.wait() | |||
self._reset_exec_env() | |||
self._pc = 0 | |||
apply.disable(apply_with_tracing) | |||
apply.disable(apply_symbolic_mode) | |||
apply.disable(apply_compiled_mode) | |||
active_trace = None | |||
def _begin_excluded_region(self): | |||
if self._untraced: | |||
# conditionally reading a compiled tensor in excluded region | |||
# is permitted, so we have to assume every tensor might be read | |||
for x in self._active_tensors: | |||
info = self._tinfo[x._TraceMixin__handle] | |||
info.exported = True | |||
info.data_read = True | |||
def _compile(self): | |||
graph = self._graph = G.Graph() | |||
# graph.options.graph_opt_level = 0 | |||
need_reset_nodes = self._need_reset_nodes = [] | |||
# links enforce ordering of I/O nodes | |||
links = () | |||
for op, ihandles, ohandles in self._seq: | |||
ivars = [] | |||
readers = [] | |||
for h in ihandles: | |||
info = self._tinfo[h] | |||
if not hasattr(info, "varnode"): | |||
assert info.external | |||
if info.bound_data: | |||
info.varnode = graph.make_const(info.bound_data._dev_tensor()) | |||
else: | |||
opnode = info.data_setter = G.InputNode( | |||
*links, device=info.device, dtype=info.dtype, graph=graph | |||
) | |||
need_reset_nodes.append(opnode) | |||
info.varnode, *links = opnode.outputs | |||
ivars.append(info.varnode) | |||
ovars = apply(op, *ivars) | |||
assert len(ovars) == len(ohandles) | |||
for h, v in zip(ohandles, ovars): | |||
info = self._tinfo[h] | |||
info.varnode = v | |||
def add_reader(opnode): | |||
nonlocal links | |||
need_reset_nodes.append(opnode) | |||
readers.append(opnode.outputs[0]) | |||
links = opnode.outputs | |||
if info.data_read: | |||
# Shape can be obtained from data so doesn't need its own | |||
# output node. On the other hand, value is read separately | |||
# to leverage eager h2d copy | |||
info.shape_read = False | |||
opnode = info.data_reader = G.OutputNode(v, *links) | |||
add_reader(opnode) | |||
if info.value_read: | |||
opnode = info.value_reader = G.ValueOutputNode(v, *links) | |||
add_reader(opnode) | |||
if info.shape_read: | |||
opnode = info.shape_reader = G.AttrOutputNode(v, *links) | |||
add_reader(opnode) | |||
graph.compile(*readers) | |||
def _reset_exec_env(self): | |||
for opnode in self._need_reset_nodes: | |||
opnode.reset() | |||
def _require_shape(self, handle): | |||
info = self._tinfo[handle] | |||
info.shape_read = True | |||
def _require_value(self, handle): | |||
info = self._tinfo[handle] | |||
info.value_read = True | |||
def _require_data(self, handle): | |||
info = self._tinfo[handle] | |||
info.data_read = True | |||
def __call__(self, *args, **kwargs): | |||
with self._setup(): | |||
return self.__wrapped__(*args, **kwargs) | |||
class CompiledTensorProxy(RawTensor): | |||
""" | |||
Duck-typed RawTensor | |||
""" | |||
def __init__(self, handle): | |||
self.__handle = handle | |||
self.__info = active_trace._tinfo[handle] | |||
self.__shape = None | |||
self.__data = None | |||
self.__value = None | |||
@property | |||
def dtype(self): | |||
return self.__info.varnode.dtype | |||
@property | |||
def device(self): | |||
return self.__info.varnode.device | |||
@property | |||
def shape(self): | |||
if self.__shape is None: | |||
if self.__info.shape_read: | |||
self.__shape = self.__info.shape_reader.get_value().shape | |||
elif self.__info.data_read: | |||
self.__shape = self._dev_tensor().shape | |||
else: | |||
raise TraceMismatchError("shape of this tensor is not read in trace") | |||
return self.__shape | |||
def numpy(self): | |||
if self.__value is None: | |||
if self.__info.value_read: | |||
self.__value = self.__info.value_reader.get_value() | |||
elif self.__info.data_read: | |||
self.__value = self._dev_tensor().numpy() | |||
else: | |||
raise TraceMismatchError("value of this tensor is not read in trace") | |||
return self.__value | |||
def _dev_tensor(self): | |||
if self.__data is None: | |||
if not self.__info.data_read: | |||
raise TraceMismatchError("raw data of this tensor is not read in trace") | |||
self.__data = self.__info.data_reader.get_value() | |||
return self.__data | |||
def __del__(self): | |||
if self.__info.shape_read and self.__shape is not None: | |||
self.__info.shape_reader.drop_value() | |||
if self.__info.value_read and self.__value is not None: | |||
self.__info.value_reader.drop_value() | |||
if self.__info.data_read and self.__data is not None: | |||
self.__info.data_reader.drop_value() | |||
class LazyEvalTensor(RawTensor): | |||
def __init__(self, varnode): | |||
self.__varnode = varnode | |||
@property | |||
def dtype(self): | |||
return self.__varnode.dtype | |||
@property | |||
def device(self): | |||
return self.__varnode.device | |||
@property | |||
def shape(self): | |||
return self.__varnode.shape | |||
def numpy(self): | |||
raise RuntimeError("cannot read value during symbolic tracing") | |||
def _dev_tensor(self): | |||
raise RuntimeError("cannot access data during symbolic tracing") | |||
class TraceMixin: | |||
__subclass_cache = {} | |||
def __inject(self, handle): | |||
cache = __class__.__subclass_cache | |||
cls = self.__class__ | |||
subcls = cache.get(cls) | |||
if subcls is None: | |||
subcls = cache[cls] = type("Traced" + cls.__name__, (__class__, cls), {}) | |||
self.__class__ = subcls | |||
self.__handle = handle | |||
self.__cls = cls | |||
return self | |||
def __restore(self): | |||
cls = self.__cls | |||
del self.__handle | |||
del self.__cls | |||
self.__class__ = cls | |||
return self | |||
@property | |||
def shape(self): | |||
if not skip_tracing: | |||
active_trace._require_shape(self.__handle) | |||
return super().shape | |||
def numpy(self): | |||
if not skip_tracing: | |||
active_trace._require_value(self.__handle) | |||
return super().numpy() | |||
def _dev_tensor(self): | |||
if not skip_tracing: | |||
active_trace._require_data(self.__handle) | |||
return super()._dev_tensor() | |||
class TracedRawTensor(TraceMixin, RawTensor): | |||
pass | |||
class TracedLazyTensor(TraceMixin, LazyEvalTensor): | |||
pass | |||
def assign_raw_tensor(lhs, rhs): | |||
handle = rhs._handle | |||
rhs.__dict__.clear() | |||
lhs.__dict__.clear() | |||
lhs.__class__ = RawTensor | |||
lhs.__init__(handle) | |||
# this hook turns RawTensor into LazyEvalTensor | |||
@apply.register() | |||
def apply_symbolic_mode(op: OpDef, *args: RawTensor): | |||
graph = active_trace._lazy_eval_graph | |||
ivars = [ | |||
getattr(x, "_LazyEvalTensor__varnode", None) | |||
or graph.make_const(x._dev_tensor()) | |||
for x in args | |||
] | |||
ovars = apply(op, *ivars) | |||
outputs = [LazyEvalTensor(v) for v in ovars] | |||
active_trace._lazy_eval_tensors.update(outputs) | |||
return outputs | |||
apply.disable(apply_symbolic_mode) | |||
@apply.register() | |||
def apply_compiled_mode(op: OpDef, *args: RawTensor): | |||
if skip_tracing: | |||
args = [ | |||
as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x | |||
for x in args | |||
] | |||
return apply.super(op, *args) | |||
return active_trace._apply_op(op, args) | |||
apply.disable(apply_compiled_mode) | |||
# this hook injects TraceMixin | |||
@apply.register() | |||
def apply_with_tracing(op: OpDef, *args: RawTensor): | |||
outputs = apply.super(op, *args) | |||
active_trace._record_op(op, args, outputs) | |||
return outputs | |||
apply.disable(apply_with_tracing) | |||
# @apply.register() | |||
# def _(op: Const, *args: RawTensor): | |||
# return active_trace._apply_const(op, args) | |||
class BrokenRawTensor(RawTensor): | |||
def __getattribute__(self, _): | |||
raise RuntimeError("broken due to misuse of tracing") | |||
def __setattr__(self, *_): | |||
raise RuntimeError("broken due to misuse of tracing") |
@@ -23,10 +23,29 @@ namespace py = pybind11; | |||
using namespace mgb; | |||
using namespace imperative; | |||
namespace { | |||
template<typename XTensorND> | |||
auto def_TensorND(py::object parent, const char* name) { | |||
return py::class_<XTensorND>(parent, name) | |||
.def_property_readonly("shape", py::overload_cast<>(&XTensorND::shape, py::const_)) | |||
.def_property_readonly("dtype", py::overload_cast<>(&XTensorND::dtype, py::const_)) | |||
.def_property_readonly("comp_node", py::overload_cast<>(&XTensorND::comp_node, py::const_)) | |||
.def("copy_from", &XTensorND::template copy_from<DeviceTensorStorage>) | |||
.def("copy_from", &XTensorND::template copy_from<HostTensorStorage>) | |||
.def("copy_from_fixlayout", py::overload_cast<const DeviceTensorND&>( | |||
&XTensorND::template copy_from_fixlayout<DeviceTensorStorage>)) | |||
.def("copy_from_fixlayout", py::overload_cast<const HostTensorND&>( | |||
&XTensorND::template copy_from_fixlayout<HostTensorStorage>)); | |||
} | |||
} // namespace | |||
void init_common(py::module m) { | |||
py::class_<CompNode>(m, "CompNode") | |||
auto&& PyCompNode = py::class_<CompNode>(m, "CompNode") | |||
.def(py::init()) | |||
.def(py::init(py::overload_cast<const std::string&>(&CompNode::load))) | |||
.def("create_event", &CompNode::create_event, py::arg("flags") = 0ul) | |||
.def("__str__", &CompNode::to_string_logical) | |||
.def_static("_sync_all", &CompNode::sync_all) | |||
.def(py::self == py::self) | |||
@@ -40,19 +59,30 @@ void init_common(py::module m) { | |||
return CompNode::load(cn); | |||
})); | |||
py::class_<CompNode::Event, std::shared_ptr<CompNode::Event>>(PyCompNode, "Event") | |||
.def("record", &CompNode::Event::record) | |||
.def("wait", &CompNode::Event::host_wait); | |||
py::implicitly_convertible<std::string, CompNode>(); | |||
py::class_<DeviceTensorND>(m, "DeviceTensorND") | |||
.def(py::init()) | |||
.def_property_readonly("shape", py::overload_cast<>(&DeviceTensorND::shape, py::const_)) | |||
.def_property_readonly("dtype", py::overload_cast<>(&DeviceTensorND::dtype, py::const_)) | |||
.def_property_readonly("comp_node", py::overload_cast<>(&DeviceTensorND::comp_node, py::const_)) | |||
def_TensorND<DeviceTensorND>(m, "DeviceTensorND") | |||
.def("numpy", [](const DeviceTensorND& self) { | |||
HostTensorND hv; | |||
hv.copy_from(self).sync(); | |||
return py::handle(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE)); | |||
}); | |||
def_TensorND<HostTensorND>(m, "HostTensorND") | |||
.def(py::init([](py::array data, CompNode cn, DType dtype) { | |||
if (!cn.valid()) { | |||
throw py::type_error("device must not be None"); | |||
} | |||
return npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype); | |||
})) | |||
.def("numpy", [](const HostTensorND& self) { | |||
return py::reinterpret_steal<py::object>(npy::ndarray_from_tensor(self, npy::ShareType::TRY_SHARE)); | |||
}); | |||
py::class_<cg::OperatorNodeConfig>(m, "OperatorNodeConfig") | |||
.def(py::init()) | |||
.def_property("name", | |||
@@ -12,6 +12,7 @@ | |||
#include "./graph_rt.h" | |||
#include "megbrain/imperative/opr_utility.h" | |||
#include "megbrain/opr/io.h" | |||
#include "megbrain/opr/basic_arith.h" | |||
#include "megbrain/imperative.h" | |||
#include "./helper.h" | |||
@@ -29,29 +30,44 @@ auto def_rendezvous(py::object m, const char* name) { | |||
.def(py::init([](){return std::make_shared<Rendezvous<T>>();})) | |||
.def("set", [](Rendezvous<T>& r, T v) {r.set(std::move(v));}) | |||
.def("get", [](Rendezvous<T>& r) {return r.get();}, py::call_guard<py::gil_scoped_release>()) | |||
.def("drop", &Rendezvous<T>::drop) | |||
.def("reset", &Rendezvous<T>::reset); | |||
} | |||
using TensorAttr = LogicalTensorDesc; | |||
using HostNDWithEvent = std::pair<HostTensorND, std::shared_ptr<CompNode::Event>>; | |||
void init_graph_rt(py::module m) { | |||
def_rendezvous<DeviceTensorND>(m, "DeviceTensorNDRendezvous"); | |||
def_rendezvous<HostNDWithEvent>(m, "HostTensorNDRendezvous"); | |||
def_rendezvous<TensorAttr>(m, "TensorAttrRendezvous"); | |||
py::class_<cg::VarNode, GraphNodePtr<cg::VarNode>>(m, "VarNode") | |||
.def_property_readonly("owner", [](cg::VarNode* v) {return v->owner_opr();}) | |||
.def_property_readonly("graph", [](cg::VarNode* v) {return v->owner_graph();}) | |||
.def_property_readonly("name", py::overload_cast<>(&VarNode::name, py::const_)) | |||
.def_property_readonly("dtype", [](cg::VarNode* v) {return v->dtype();}) | |||
.def_property_readonly("comp_node", [](cg::VarNode* v) {return v->comp_node();}); | |||
.def_property_readonly("comp_node", [](cg::VarNode* v) {return v->comp_node();}) | |||
.def_property_readonly("shape", [](cg::VarNode* v) -> const TensorShape* { | |||
auto&& mgr = v->owner_graph()->static_infer_manager(); | |||
auto&& type = mgr.get_infer_type(v); | |||
using InferType = cg::static_infer::InferType; | |||
if (!(type.shape & (InferType::CONST | InferType::RT_STATIC))) { | |||
return nullptr; | |||
} | |||
return mgr.infer_shape_fallible(v); | |||
}); | |||
py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(m, "OperatorNode") | |||
.def_property_readonly("graph", [](cg::OperatorNodeBase* opr) {return opr->owner_graph();}) | |||
.def_property_readonly("name", py::overload_cast<>(&cg::OperatorNodeBase::name, py::const_)) | |||
.def_property_readonly("inputs", [](cg::OperatorNodeBase* opr) { | |||
return to_tuple(opr->input()); | |||
}) | |||
.def_property_readonly("outputs", [](cg::OperatorNodeBase* opr) { | |||
return to_tuple(opr->output()); | |||
return to_tuple(opr->usable_output()); | |||
}); | |||
py::class_<cg::AsyncExecutable>(m, "AsyncExecutable") | |||
@@ -117,7 +133,7 @@ void init_graph_rt(py::module m) { | |||
common.def("invoke_op", [](const OpDef& def, const std::vector<cg::VarNode*> inputs, cg::ComputingGraph* graph) { | |||
cg::VarNodeArray vinputs(inputs.begin(), inputs.end()); | |||
auto opr = OpDef::apply_on_var_node(def, vinputs); | |||
auto outputs = opr->output(); | |||
auto outputs = opr->usable_output(); | |||
return to_tuple(outputs); | |||
}, | |||
py::arg(), py::arg(), py::arg("graph") = py::none()); | |||
@@ -125,6 +141,7 @@ void init_graph_rt(py::module m) { | |||
auto input_callback = [](auto callback, | |||
const CompNode& comp_node, | |||
const DType& dtype, | |||
const TensorShape& shape, | |||
const std::vector<cg::VarNode*>& inputs, | |||
cg::ComputingGraph* graph) { | |||
if (!graph) { | |||
@@ -135,7 +152,7 @@ void init_graph_rt(py::module m) { | |||
sinputs.emplace_back(i); | |||
} | |||
static_assert(!std::is_reference<decltype(callback)>::value); | |||
auto soutputs = opr::InputCallback::make(*graph, std::move(callback), comp_node, dtype, sinputs); | |||
auto soutputs = opr::InputCallback::make(*graph, std::move(callback), comp_node, dtype, shape, sinputs); | |||
std::vector<VarNode*> outputs; | |||
outputs.reserve(soutputs.size()); | |||
for (auto i : soutputs) { | |||
@@ -144,26 +161,40 @@ void init_graph_rt(py::module m) { | |||
return outputs; | |||
}; | |||
m.def("make_shared", [](cg::ComputingGraph* graph, const DeviceTensorND& data) { | |||
return opr::SharedDeviceTensor::make(*graph, std::make_shared<DeviceTensorND>(data)).node(); | |||
}); | |||
m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) { | |||
if (!cn.valid()) { | |||
throw py::type_error("device must not be None"); | |||
} | |||
auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype); | |||
opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node(); | |||
}); | |||
m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback, | |||
const CompNode& comp_node, | |||
const DType& dtype, | |||
const TensorShape& shape, | |||
const std::vector<cg::VarNode*>& inputs, | |||
cg::ComputingGraph* graph) { | |||
return input_callback([f=std::move(callback)](){py::gil_scoped_acquire _; return f();}, comp_node, dtype, inputs, graph); | |||
return input_callback([f=std::move(callback)](){py::gil_scoped_acquire _; return f();}, comp_node, dtype, shape, inputs, graph); | |||
}, | |||
py::arg(), py::arg(), py::arg(), py::arg() = py::tuple(), py::arg("graph") = py::none()); | |||
py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none()); | |||
m.def("input_callback", [input_callback](std::shared_ptr<Rendezvous<DeviceTensorND>> p, | |||
const CompNode& comp_node, | |||
const DType& dtype, | |||
const TensorShape& shape, | |||
const std::vector<cg::VarNode*>& inputs, | |||
cg::ComputingGraph* graph) { | |||
auto f = [p]() -> DeviceTensorND { | |||
return p->get(); | |||
}; | |||
return input_callback(std::move(f), comp_node, dtype, inputs, graph); | |||
return input_callback(std::move(f), comp_node, dtype, shape, inputs, graph); | |||
}, | |||
py::arg(), py::arg(), py::arg(), py::arg() = py::tuple(), py::arg("graph") = py::none()); | |||
py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none()); | |||
auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs, bool borrow = false) { | |||
SymbolVarArray sinputs; | |||
@@ -193,6 +224,17 @@ void init_graph_rt(py::module m) { | |||
return output_callback(std::move(f), std::move(inputs)); | |||
}); | |||
m.def("value_output_callback", [output_callback](std::shared_ptr<Rendezvous<HostNDWithEvent>> p, std::vector<cg::VarNode*> inputs) { | |||
auto f = [p](DeviceTensorND dv) { | |||
HostNDWithEvent hv_with_event; | |||
hv_with_event.first.copy_from(dv); | |||
hv_with_event.second = dv.comp_node().create_event(); | |||
hv_with_event.second->record(); | |||
p->set(std::move(hv_with_event)); | |||
}; | |||
return output_callback(std::move(f), std::move(inputs), true); | |||
}); | |||
m.def("attr_output_callback", [output_callback](std::shared_ptr<Rendezvous<TensorAttr>> p, std::vector<cg::VarNode*> inputs) { | |||
auto f = [p](DeviceTensorND dv) { | |||
p->set(TensorAttr{TensorLayout{dv.shape(), dv.dtype()}, dv.comp_node()}); | |||
@@ -39,6 +39,7 @@ template<typename R> | |||
class Rendezvous { | |||
std::mutex m_lock; | |||
int m_read_ahead = 0; | |||
bool m_drop_next = false; | |||
std::promise<R> m_promise; | |||
public: | |||
Rendezvous() = default; | |||
@@ -47,6 +48,7 @@ public: | |||
Rendezvous& operator=(const Rendezvous& rhs) = delete; | |||
Rendezvous& operator=(Rendezvous&& rhs) { | |||
MGB_LOCK_GUARD(m_lock); | |||
m_drop_next = rhs.m_drop_next; | |||
m_read_ahead = rhs.m_read_ahead; | |||
m_promise = std::move(rhs.m_promise); | |||
return *this; | |||
@@ -67,12 +69,28 @@ public: | |||
return f.get(); | |||
} | |||
void drop() { | |||
MGB_LOCK_GUARD(m_lock); | |||
mgb_assert(m_read_ahead <= 0); | |||
mgb_assert(m_read_ahead >= -1); | |||
if (m_read_ahead == -1) { | |||
m_promise = {}; | |||
} else { | |||
m_drop_next = true; | |||
} | |||
++m_read_ahead; | |||
} | |||
template<typename T> | |||
void set(T&& value) { | |||
MGB_LOCK_GUARD(m_lock); | |||
mgb_assert(m_read_ahead >= 0); | |||
mgb_assert(m_read_ahead <= 1); | |||
m_promise.set_value(std::forward<T>(value)); | |||
if (m_drop_next) { | |||
m_drop_next = false; | |||
} else { | |||
m_promise.set_value(std::forward<T>(value)); | |||
} | |||
if (m_read_ahead == 1) { | |||
m_promise = {}; | |||
} | |||
@@ -83,6 +101,7 @@ public: | |||
MGB_LOCK_GUARD(m_lock); | |||
m_promise = {}; | |||
m_read_ahead = 0; | |||
m_drop_next = false; | |||
} | |||
}; | |||
@@ -280,9 +280,12 @@ namespace detail { | |||
public: | |||
bool load(handle src, bool convert) { | |||
auto obj = reinterpret_steal<object>(src); | |||
if (!isinstance<tuple>(obj)) { | |||
if (!convert && !isinstance<tuple>(obj)) { | |||
return false; | |||
} | |||
if (obj.is_none()) { | |||
return true; | |||
} | |||
value.ndim = len(obj); | |||
mgb_assert(value.ndim <= mgb::TensorShape::MAX_NDIM); | |||
size_t i = 0; | |||
@@ -63,6 +63,7 @@ void init_imperative_rt(py::module m) { | |||
return self.put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype)); | |||
} | |||
}, py::arg(), py::arg("dtype") = py::none(), py::arg("device") = py::none()) | |||
.def("put", py::overload_cast<const DeviceTensorND&>(&Interpreter::Channel::put)) | |||
.def("delete", [](Interpreter::Channel& self, Interpreter::Handle handle) { | |||
return self.del(handle); | |||
}) | |||
@@ -24,6 +24,12 @@ constexpr bool has_fastcall = true; | |||
constexpr bool has_fastcall = false; | |||
#endif | |||
#ifdef _Py_TPFLAGS_HAVE_VECTORCALL | |||
constexpr bool has_vectorcall = true; | |||
#else | |||
constexpr bool has_vectorcall = false; | |||
#endif | |||
template<typename... Args> | |||
struct invocable_with { | |||
template<typename T> | |||
@@ -55,6 +61,9 @@ private: | |||
public: | |||
PyObject_HEAD | |||
std::aligned_storage_t<sizeof(T), alignof(T)> storage; | |||
#ifdef _Py_TPFLAGS_HAVE_VECTORCALL | |||
PyObject* vectorcall_slot; | |||
#endif | |||
inline T* inst() { | |||
return reinterpret_cast<T*>(&storage); | |||
@@ -155,6 +164,51 @@ private: | |||
// polyfills | |||
struct tp_vectorcall { | |||
static constexpr bool valid = HAS_MEMBER(T, tp_vectorcall); | |||
static constexpr bool haskw = [](){if constexpr (valid) | |||
if constexpr (std::is_invocable_v<T::tp_vectorcall, T, PyObject*const*, size_t, PyObject*>) | |||
return true; | |||
return false;}(); | |||
template<typename = void> | |||
static PyObject* impl(PyObject* self, PyObject*const* args, size_t nargsf, PyObject *kwnames) { | |||
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
if constexpr (haskw) { | |||
CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf, kwnames)); | |||
} else { | |||
if (kwnames && PyTuple_GET_SIZE(kwnames)) { | |||
PyErr_SetString(PyExc_TypeError, "expect no keyword argument"); | |||
return nullptr; | |||
} | |||
CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf)); | |||
} | |||
} | |||
static constexpr Py_ssize_t offset = []() {if constexpr (valid) return offsetof(wrap_t, vectorcall_slot); | |||
else return 0;}(); | |||
}; | |||
struct tp_call { | |||
static constexpr bool provided = HAS_MEMBER(T, tp_call); | |||
static constexpr bool static_form = invocable_with<T, PyObject*, PyObject*, PyObject*>{}( | |||
[](auto&& t, auto... args) -> decltype(std::decay_t<decltype(t)>::tp_call(args...)) {}); | |||
static constexpr bool valid = provided || tp_vectorcall::valid; | |||
template<typename = void> | |||
static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) { | |||
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
CVT_RET_PYOBJ(inst->tp_call(args, kwargs)); | |||
} | |||
static constexpr ternaryfunc value = []() {if constexpr (static_form) return T::tp_call; | |||
else if constexpr (provided) return impl<>; | |||
#ifdef _Py_TPFLAGS_HAVE_VECTORCALL | |||
else if constexpr (valid) return PyVectorcall_Call; | |||
#endif | |||
else return nullptr;}(); | |||
}; | |||
struct tp_new { | |||
static constexpr bool provided = HAS_MEMBER(T, tp_new); | |||
static constexpr bool varkw = std::is_constructible_v<T, PyObject*, PyObject*>; | |||
@@ -163,11 +217,14 @@ private: | |||
template<typename = void> | |||
static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) { | |||
auto* self = type->tp_alloc(type, 0); | |||
auto* ptr = reinterpret_cast<wrap_t*>(self)->inst(); | |||
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
if constexpr (has_vectorcall && tp_vectorcall::valid) { | |||
reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>; | |||
} | |||
if constexpr (varkw) { | |||
new(ptr) T(args, kwargs); | |||
new(inst) T(args, kwargs); | |||
} else { | |||
new(ptr) T(); | |||
new(inst) T(); | |||
} | |||
return self; | |||
} | |||
@@ -190,22 +247,6 @@ private: | |||
else return impl<>;}(); | |||
}; | |||
struct tp_call { | |||
static constexpr bool valid = HAS_MEMBER(T, tp_call); | |||
static constexpr bool static_form = invocable_with<T, PyObject*, PyObject*, PyObject*>{}( | |||
[](auto&& t, auto... args) -> decltype(std::decay_t<decltype(t)>::tp_call(args...)) {}); | |||
template<typename = void> | |||
static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) { | |||
auto* inst = reinterpret_cast<wrap_t*>(self)->inst(); | |||
CVT_RET_PYOBJ(inst->tp_call(args, kwargs)); | |||
} | |||
static constexpr ternaryfunc value = []() {if constexpr (static_form) return T::tp_call; | |||
else if constexpr (valid) return impl<>; | |||
else return nullptr;}(); | |||
}; | |||
public: | |||
class TypeBuilder { | |||
std::vector<PyMethodDef> m_methods; | |||
@@ -228,9 +269,17 @@ public: | |||
m_type.tp_name = T::tp_name; | |||
} | |||
m_type.tp_dealloc = tp_dealloc::value; | |||
#ifdef _Py_TPFLAGS_HAVE_VECTORCALL | |||
m_type.tp_vectorcall_offset = tp_vectorcall::offset; | |||
#endif | |||
m_type.tp_call = tp_call::value; | |||
m_type.tp_basicsize = sizeof(wrap_t); | |||
m_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
#ifdef _Py_TPFLAGS_HAVE_VECTORCALL | |||
if constexpr (tp_vectorcall::valid) { | |||
m_type.tp_flags |= _Py_TPFLAGS_HAVE_VECTORCALL; | |||
} | |||
#endif | |||
m_type.tp_new = tp_new::value; | |||
} | |||
@@ -0,0 +1,65 @@ | |||
import numpy as np | |||
from megengine.core.ops import builtin as ops | |||
from megengine.core.tensor.core import apply | |||
from megengine.core.tensor.raw_tensor import as_raw_tensor | |||
from megengine.jit import exclude_from_trace, trace | |||
def test_trace(): | |||
for symbolic in [False, True]: | |||
@trace(symbolic=symbolic) | |||
def f(x): | |||
op = ops.Elemwise(mode="negate") | |||
(y,) = apply(op, x) | |||
return y | |||
x = as_raw_tensor([1]).numpy() | |||
y = f.__wrapped__(as_raw_tensor(x)).numpy() | |||
for i in range(3): | |||
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) | |||
def test_exclude_from_trace(): | |||
for symbolic in [False, True]: | |||
@trace(symbolic=symbolic) | |||
def f(x): | |||
neg = ops.Elemwise(mode="negate") | |||
(x,) = apply(neg, x) | |||
with exclude_from_trace(): | |||
if i % 2: | |||
(x,) = apply(neg, x) | |||
(x,) = apply(neg, x) | |||
return x | |||
x = as_raw_tensor([1]).numpy() | |||
for i in range(3): | |||
y = f.__wrapped__(as_raw_tensor(x)).numpy() | |||
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) | |||
def test_print_in_trace(): | |||
for symbolic in [False]: # cannot read value in symbolic mode | |||
@trace(symbolic=symbolic) | |||
def f(x): | |||
nonlocal buf | |||
neg = ops.Elemwise(mode="negate") | |||
(x,) = apply(neg, x) | |||
buf = x.numpy() | |||
(x,) = apply(neg, x) | |||
return x | |||
buf = None | |||
x = as_raw_tensor([1]).numpy() | |||
for i in range(3): | |||
y = f.__wrapped__(as_raw_tensor(x)).numpy() | |||
z = buf | |||
buf = None | |||
np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y) | |||
np.testing.assert_equal(z, buf) |
@@ -37,6 +37,15 @@ void* ChannelImpl::put(const HostTensorND& value) { | |||
return info; | |||
} | |||
void* ChannelImpl::put(const DeviceTensorND& data) { | |||
auto info = alloc(); | |||
info->desc.layout = data.layout(); | |||
info->desc.comp_node = data.comp_node(); | |||
info->ptr = Tensor::make(data); | |||
m_valid_handle.insert(info); | |||
return info; | |||
} | |||
void ChannelImpl::del(void* handle) { | |||
mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle); | |||
m_worker.add_task(Del{reinterpret_cast<TensorInfo*>(handle)}); | |||
@@ -55,6 +55,7 @@ struct ChannelImpl : Interpreter::Channel { | |||
~ChannelImpl() override; | |||
Handle put(const HostTensorND& value) override; | |||
Handle put(const DeviceTensorND& value) override; | |||
void del(Handle) override; | |||
@@ -31,9 +31,10 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(InputCallback); | |||
InputCallback::InputCallback(cg::ComputingGraph& graph, callback_t callback, | |||
const VarNodeArray& inputs, | |||
const TensorShape& output_shape, | |||
const OperatorNodeConfig& config) | |||
: Super(&graph, config, "input_callback", inputs), | |||
m_callback(callback) { | |||
m_output_shape(output_shape), m_callback(callback) { | |||
for (VarNode* i : inputs) { | |||
add_input({i}); | |||
} | |||
@@ -48,7 +49,8 @@ InputCallback::InputCallback(cg::ComputingGraph& graph, callback_t callback, | |||
SymbolVarArray InputCallback::make(cg::ComputingGraph& graph, | |||
callback_t callback, CompNode comp_node, | |||
DType dtype, const SymbolVarArray& inputs) { | |||
DType dtype, const TensorShape& shape, | |||
const SymbolVarArray& inputs) { | |||
mgb_assert(comp_node.valid()); | |||
mgb_assert(dtype.valid()); | |||
OperatorNodeConfig config; | |||
@@ -56,11 +58,22 @@ SymbolVarArray InputCallback::make(cg::ComputingGraph& graph, | |||
config.output_dtype(dtype); | |||
auto vinputs = to_var_node_array(inputs); | |||
auto opr = graph.insert_opr( | |||
std::make_unique<InputCallback>(graph, callback, vinputs, config)); | |||
std::make_unique<InputCallback>(graph, callback, vinputs, shape, config)); | |||
return to_symbol_var_array(opr->output()); | |||
} | |||
void InputCallback::init_output_static_infer_desc() {} | |||
void InputCallback::init_output_static_infer_desc() { | |||
if (m_output_shape.ndim) { | |||
using namespace cg::static_infer; | |||
auto &&mgr = owner_graph()->static_infer_manager(); | |||
auto infer_shape = [this](TensorShape &dest, const InpVal &) { | |||
dest = m_output_shape; | |||
return true; | |||
}; | |||
mgr.register_shape_infer(output(0), | |||
{SourceType::CONSTANT, {}, infer_shape}); | |||
} | |||
} | |||
cg::OperatorNodeBase::NodeProp* InputCallback::do_make_node_prop() const { | |||
NodeProp* prop = Super::do_make_node_prop(); | |||
@@ -73,9 +86,23 @@ cg::OperatorNodeBase::NodeProp* InputCallback::do_make_node_prop() const { | |||
void InputCallback::scn_do_execute() { | |||
auto dev_tensor = m_callback(); | |||
if (m_output_shape.ndim) { | |||
mgb_assert(dev_tensor.shape().eq_shape(m_output_shape)); | |||
} | |||
output(0)->reset_dev_tensor_from_tensor(dev_tensor); | |||
} | |||
cg::OperatorNodeBase* InputCallback::shallow_copy( | |||
const serialization::OprShallowCopyContext &ctx, | |||
const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs, | |||
const OperatorNodeConfig &config) { | |||
auto &&opr = opr_.cast_final_safe<InputCallback>(); | |||
auto* graph = ctx.owner_graph(opr, inputs); | |||
return graph->insert_opr(std::make_unique<InputCallback>(*graph, opr.m_callback, inputs, opr.m_output_shape, config)); | |||
} | |||
MGB_REG_OPR_SHALLOW_COPY(InputCallback, InputCallback::shallow_copy); | |||
/* ================ OutputCallback ================== */ | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(OutputCallback); | |||
@@ -122,6 +149,17 @@ void OutputCallback::scn_do_execute() { | |||
m_param.callback(input(0)->dev_tensor()); | |||
} | |||
cg::OperatorNodeBase* OutputCallback::shallow_copy( | |||
const serialization::OprShallowCopyContext &ctx, | |||
const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs, | |||
const OperatorNodeConfig &config) { | |||
auto &&opr = opr_.cast_final_safe<OutputCallback>(); | |||
auto* graph = ctx.owner_graph(opr, inputs); | |||
return graph->insert_opr(std::make_unique<OutputCallback>(opr.m_param, inputs, config)); | |||
} | |||
MGB_REG_OPR_SHALLOW_COPY(OutputCallback, OutputCallback::shallow_copy); | |||
/* ================ NopCallback ================== */ | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(NopCallback); | |||
@@ -22,6 +22,7 @@ struct Interpreter { | |||
virtual ~Channel() = default; | |||
virtual Handle put(const HostTensorND& value) = 0; | |||
virtual Handle put(const DeviceTensorND& value) = 0; | |||
virtual void del(Handle) = 0; | |||
@@ -17,6 +17,7 @@ | |||
#include "megbrain/opr/internal/param_tag_defs.h" | |||
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" | |||
#include "megbrain/opr/param_defs.h" | |||
#include "megbrain/serialization/sereg.h" | |||
#include "megdnn/oprs/utils.h" | |||
@@ -33,17 +34,24 @@ public: | |||
InputCallback(cg::ComputingGraph& graph, | |||
callback_t callback, | |||
const VarNodeArray& inputs, | |||
const TensorShape& output_shape, | |||
const OperatorNodeConfig &config); | |||
static SymbolVarArray make(cg::ComputingGraph& graph, | |||
callback_t callback, | |||
CompNode comp_node, | |||
DType dtype, | |||
const TensorShape& shape, | |||
const SymbolVarArray& inputs = {}); | |||
static cg::OperatorNodeBase* shallow_copy( | |||
const serialization::OprShallowCopyContext &ctx, | |||
const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs, | |||
const OperatorNodeConfig &config); | |||
protected: | |||
void scn_do_execute() override; | |||
void init_output_static_infer_desc() override; | |||
NodeProp* do_make_node_prop() const override; | |||
private: | |||
TensorShape m_output_shape; | |||
callback_t m_callback; | |||
}; | |||
@@ -63,6 +71,10 @@ public: | |||
SymbolVar input) { | |||
return make(std::move(param), SymbolVarArray{input}); | |||
} | |||
static cg::OperatorNodeBase* shallow_copy( | |||
const serialization::OprShallowCopyContext &ctx, | |||
const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs, | |||
const OperatorNodeConfig &config); | |||
protected: | |||
void scn_do_execute() override; | |||
void init_output_static_infer_desc() override; | |||