diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py
index 3ecda5cf..37a86369 100644
--- a/imperative/python/megengine/core/tensor/megbrain_graph.py
+++ b/imperative/python/megengine/core/tensor/megbrain_graph.py
@@ -17,15 +17,31 @@ from ..ops.builtin import OpDef
 from .core import OpBase, TensorBase, apply
 
 
-class CompiledFunction:
-    def __init__(self, graph, function):
-        self._graph = graph
-        self._function = function
+class Graph(_imperative_rt.ComputingGraph):
+    def __init__(self):
+        super().__init__()
+        self._var_cache = weakref.WeakKeyDictionary()
+        self._op_cache = weakref.WeakKeyDictionary()
+        self._executor = ThreadPoolExecutor(1)
+        self._function = None
         self._future = None
 
+    def _wrap(self, obj):
+        if type(obj) is _imperative_rt.VarNode:
+            wrapper, cache = VarNode, self._var_cache
+        elif type(obj) is _imperative_rt.OperatorNode:
+            wrapper, cache = OpNode, self._op_cache
+        if obj not in cache:
+            cache[obj] = wrapper(obj)
+        return cache[obj]
+
+    def compile(self, *args):
+        self._function = super().compile(_unwrap(args))
+        return self
+
     def execute(self, *args):
         assert self._future is None
-        self._future = self._graph._executor.submit(self._function.execute, *args)
+        self._future = self._executor.submit(self._function.execute, *args)
 
     def wait(self):
         assert self._future is not None
@@ -40,30 +56,23 @@ class CompiledFunction:
         self.execute(*args)
         return self.wait()
 
+    def make_const(self, data, dtype=None, device=None):
+        if isinstance(data, _imperative_rt.DeviceTensorND):
+            assert dtype is None and device is None
+            return self._wrap(_imperative_rt.make_shared(self, data))
+        else:
+            device = as_device(device).to_c()
+            return self._wrap(_imperative_rt.make_const(self, data, device, dtype))
 
-class Graph(_imperative_rt.ComputingGraph):
-    def __init__(self):
-        super().__init__()
-        self._var_cache = weakref.WeakKeyDictionary()
-        self._op_cache = weakref.WeakKeyDictionary()
-        self._executor = ThreadPoolExecutor(1)
-
-    def _wrap(self, obj):
-        if type(obj) is _imperative_rt.VarNode:
-            wrapper, cache = VarNode, self._var_cache
-        elif type(obj) is _imperative_rt.OperatorNode:
-            wrapper, cache = OpNode, self._op_cache
-        if obj not in cache:
-            cache[obj] = wrapper(obj)
-        return cache[obj]
-
-    def compile(self, *args):
-        return CompiledFunction(self, super().compile(_unwrap(args)))
+    def make_input(self, *args: "VarNode", device=None, dtype=None, shape=None):
+        opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self)
+        return opnode.outputs[0]
 
 
 class VarNode(TensorBase):
     def __init__(self, node: _imperative_rt.VarNode):
         self._node = node
+        self.graph._var_cache[node] = self
 
     @property
     def graph(self) -> Graph:
@@ -81,10 +90,15 @@ class VarNode(TensorBase):
     def device(self):
         return as_device(self._node.comp_node)
 
+    @property
+    def shape(self):
+        return self._node.shape
+
 
 class OpNode:
     def __init__(self, node: _imperative_rt.OperatorNode):
         self._node = node
+        self.graph._op_cache[node] = self
 
     @property
     def graph(self) -> Graph:
@@ -117,21 +131,21 @@ def _(op: OpDef, *args: VarNode):
     return _wrap(outputs)
 
 
-def input_callback(callback, *args, device=None, dtype=None, graph=None):
+def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None):
     outputs = _imperative_rt.input_callback(
-        callback, as_device(device).to_c(), dtype, _unwrap(args), graph=graph
+        callback, as_device(device).to_c(), dtype, shape, _unwrap(args), graph=graph
     )
     value, dummy = _wrap(outputs)
     return value, dummy
 
 
 class InputNode(OpNode):
-    def __init__(self, *args: VarNode, device=None, dtype=None, graph=None):
+    def __init__(self, *args: VarNode, device=None, dtype=None, shape=None, graph=None):
         r = _imperative_rt.DeviceTensorNDRendezvous()
         if device is not None:
             device = as_device(device).to_c()
         outputs = _imperative_rt.input_callback(
-            r, device, dtype, _unwrap(args), graph=graph
+            r, device, dtype, shape, _unwrap(args), graph=graph
         )
         super().__init__(outputs[0].owner)
         self._rendezvous = r
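For reference, the rewritten `Graph` now owns both compilation and execution: `compile()` stores the `AsyncExecutable` and returns the graph itself, so `execute`/`wait`/`__call__` all live on one object. A minimal sketch of the intended flow, wired by hand the same way `trace._compile` does further down; the `"xpux"` device string and the negate op are illustrative assumptions, and `set_value`/`get_value` come from the existing Input/OutputNode rendezvous helpers:

```python
import numpy as np

from megengine.core.ops import builtin as ops
from megengine.core.tensor import megbrain_graph as G
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor

g = G.Graph()
inp = G.InputNode(device="xpux", dtype="float32", graph=g)  # fed via rendezvous
(y,) = apply(ops.Elemwise(mode="negate"), inp.outputs[0])
out = G.OutputNode(y)
g.compile(out.outputs[0])  # compile() now returns the Graph itself

inp.set_value(as_raw_tensor(np.array([1.0], dtype="float32"))._dev_tensor())
g()  # __call__ = execute() + wait()
print(out.get_value().numpy())  # -> [-1.]
```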
@@ -169,6 +183,29 @@ class OutputNode(OpNode):
     def get_value(self):
         return self._rendezvous.get()
 
+    def drop_value(self):
+        self._rendezvous.drop()
+
+    def reset(self):
+        self._rendezvous.reset()
+
+
+class ValueOutputNode(OpNode):
+    def __init__(self, var, *args):
+        args = (var,) + args
+        r = _imperative_rt.HostTensorNDRendezvous()
+        dummy = _imperative_rt.value_output_callback(r, _unwrap(args))
+        super().__init__(dummy.owner)
+        self._rendezvous = r
+
+    def get_value(self):
+        hostnd, event = self._rendezvous.get()
+        event.wait()
+        return hostnd.numpy()
+
+    def drop_value(self):
+        self._rendezvous.drop()
+
     def reset(self):
         self._rendezvous.reset()
 
@@ -192,5 +229,8 @@ class AttrOutputNode(OpNode):
         attr = self._rendezvous.get()
         return TensorAttr(attr.shape, attr.dtype, as_device(attr.comp_node))
 
+    def drop_value(self):
+        self._rendezvous.drop()
+
     def reset(self):
         self._rendezvous.reset()
diff --git a/imperative/python/megengine/core/tensor/raw_tensor/__init__.py b/imperative/python/megengine/core/tensor/raw_tensor/__init__.py
index da44c689..a14fd197 100644
--- a/imperative/python/megengine/core/tensor/raw_tensor/__init__.py
+++ b/imperative/python/megengine/core/tensor/raw_tensor/__init__.py
@@ -31,11 +31,13 @@ class RawTensor(TensorBase):
 
     _init_cb = None
     _del_cb = None
+    _handle = None
 
-    def __init__(self, handle):
+    def __init__(self, handle=None):
         self._handle = handle
-        if self._init_cb:
-            self._init_cb()
+        if handle is not None:
+            if self._init_cb:
+                self._init_cb()
 
     @property
     def dtype(self):
@@ -61,9 +63,10 @@ class RawTensor(TensorBase):
         )
 
     def __del__(self):
-        if self._del_cb:
-            self._del_cb()
-        delete(self._handle)
+        if self._handle is not None:
+            if self._del_cb:
+                self._del_cb()
+            delete(self._handle)
 
 
 @apply.register()
@@ -89,6 +92,11 @@ def as_raw_tensor(obj, dtype=None, device=None):
         return as_raw_tensor(obj, device=device)
 
 
+@as_raw_tensor.register(DeviceTensorND)
+def _(data: DeviceTensorND):
+    return RawTensor(put(data))
+
+
 @as_raw_tensor.register(np.ndarray)
 def _(array: np.ndarray, dtype=None, device=None):
     device = None if device is None else as_device(device).to_c()
diff --git a/imperative/python/megengine/jit/__init__.py b/imperative/python/megengine/jit/__init__.py
new file mode 100644
index 00000000..13fb0cc0
--- /dev/null
+++ b/imperative/python/megengine/jit/__init__.py
@@ -0,0 +1 @@
+from .tracing import exclude_from_trace, trace
diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py
new file mode 100644
index 00000000..01819de1
--- /dev/null
+++ b/imperative/python/megengine/jit/tracing.py
@@ -0,0 +1,514 @@
+import contextlib
+import functools
+import typing
+import weakref
+
+from ..core.ops.special import Const
+from ..core.tensor import megbrain_graph as G
+from ..core.tensor.core import OpBase, apply
+from ..core.tensor.raw_tensor import OpDef, RawTensor, as_raw_tensor
+
+
+class TraceMismatchError(RuntimeError):
+    pass
+
+
+active_trace = None
+skip_tracing = False
+
+
+@contextlib.contextmanager
+def exclude_from_trace():
+    global skip_tracing
+    if skip_tracing:
+        yield
+        return
+    try:
+        skip_tracing = True
+        if active_trace is not None:
+            active_trace._begin_excluded_region()
+        yield
+    finally:
+        skip_tracing = False
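`trace` records the op sequence on the first call and replays the compiled graph on subsequent calls, while `exclude_from_trace` punches an eagerly-interpreted hole into that recording. A sketch of how the two exports compose (mirrors the tests at the end of this patch; the negate op is just an example):

```python
from megengine.core.ops import builtin as ops
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.jit import exclude_from_trace, trace

@trace(symbolic=False)
def step(x):
    (x,) = apply(ops.Elemwise(mode="negate"), x)
    with exclude_from_trace():
        # re-executed eagerly on every call, never baked into the graph
        (x,) = apply(ops.Elemwise(mode="negate"), x)
    return x

print(step(as_raw_tensor([1.0])).numpy())  # first call records
print(step(as_raw_tensor([2.0])).numpy())  # later calls replay
```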
+
+
+class TensorInfo:
+    __slots__ = (
+        # collected attributes
+        "external",
+        "exported",
+        "data_read",
+        "shape_read",
+        "value_read",
+        "device",
+        "dtype",
+        "bound_data",
+        # resources for execution
+        "varnode",
+        "data_setter",
+        "shape_reader",
+        "value_reader",
+        "data_reader",
+    )
+
+    def __init__(self):
+        self.exported = None
+        self.data_read = None
+        self.shape_read = None
+        self.value_read = None
+        self.bound_data = None
+
+        self.data_setter = None
+        self.shape_reader = None
+        self.value_reader = None
+        self.data_reader = None
+
+
+class trace:
+    def __new__(cls, *args, **kwargs):
+        if not args:
+            return functools.partial(cls, **kwargs)
+        self = super().__new__(cls)
+        self.__init__(*args, **kwargs)
+        return self
+
+    def __init__(self, function, symbolic=False, capture_as_const=False):
+        self.__wrapped__ = function
+        self._symbolic = symbolic
+        self._capture_as_const = capture_as_const
+        self._capture_static_shape = False
+
+        self._untraced = True
+        self._tinfo = []  # handle -> TensorInfo
+        self._seq = []
+        self._pc = 0
+        self._graph = None
+        self._need_reset_nodes = None
+        self._lazy_eval_graph = None
+        self._lazy_eval_tensors = weakref.WeakSet()
+        self._active_tensors = weakref.WeakSet()
+
+    def _new_handle(self):
+        handle = len(self._tinfo)
+        info = TensorInfo()
+        self._tinfo.append(info)
+        return handle, info
+
+    def _apply_op(self, op, args):
+        assert not self._untraced
+        # check against trace
+        if self._pc >= len(self._seq):
+            raise TraceMismatchError("trace should end here, but more ops observed")
+        record = self._seq[self._pc]
+        op_, ihandles, ohandles = record
+        if op != op_:
+            raise TraceMismatchError("op different from last time")
+        if len(ihandles) != len(args):
+            raise TraceMismatchError("op input size different from last time")
+
+        for h, x in zip(ihandles, args):
+            info = self._tinfo[h]
+            if info.external:
+                if (
+                    x.__class__ is CompiledTensorProxy
+                    and not self._tinfo[x._CompiledTensorProxy__handle].exported
+                ):
+                    raise TraceMismatchError(
+                        "failed to capture: input was an external tensor "
+                        "last time, got an internal tensor this time"
+                    )
+                if info.bound_data:
+                    if x.__class__ is CompiledTensorProxy:
+                        raise TraceMismatchError(
+                            "const capture violated: was an external tensor "
+                            "last time, got an internal tensor this time"
+                        )
+                    if x._handle != info.bound_data._handle:
+                        raise TraceMismatchError(
+                            "const capture violated: got "
+                            "a different tensor this time"
+                        )
+                else:
+                    if info.dtype != x.dtype:
+                        raise TraceMismatchError(
+                            "failed to capture: different dtype from last time"
+                        )
+                    if info.device != x.device:
+                        raise TraceMismatchError(
+                            "failed to capture: different device from last time"
+                        )
+                    info.data_setter.set_value(x._dev_tensor())
+            else:
+                if x.__class__ is not CompiledTensorProxy:
+                    raise TraceMismatchError(
+                        "unexpected capture: trying to use an external tensor as input, "
+                        "but that input was an internal tensor last time"
+                    )
+                if x._CompiledTensorProxy__handle != h:
+                    raise TraceMismatchError(
+                        "mis-wiring: input edge to a data flow "
+                        "graph node is different from last time"
+                    )
+
+        self._pc += 1
+        outputs = tuple([CompiledTensorProxy(h) for h in ohandles])
+        self._active_tensors.update(outputs)
+        return outputs
+
+    def _record_op(self, op, inputs, outputs):
+        if skip_tracing:
+            for x in inputs:
+                h = getattr(x, "_TraceMixin__handle", None)
+                if h is not None:
+                    self._tinfo[h].data_read = True
+            return
+
+        ihandles = []
+        for x in inputs:
+            h = getattr(x, "_TraceMixin__handle", None)
+            if h is None or (not self._capture_as_const and self._tinfo[h].exported):
+                h, info = self._new_handle()
+                info.external = True
+                info.device = x.device
+                info.dtype = x.dtype
+                if self._capture_as_const:
+                    info.bound_data = x
+
+            ihandles.append(h)
+
+        ohandles = []
+        for x in outputs:
+            h, info = self._new_handle()
+            ohandles.append(h)
+            info.external = False
+            TraceMixin._TraceMixin__inject(x, h)
+
+        self._seq.append((op, tuple(ihandles), tuple(ohandles)))
+        self._active_tensors.update(outputs)
+
+    @contextlib.contextmanager
+    def _setup(self):
+        global active_trace
+        if active_trace:
+            raise NotImplementedError("sorry, not implemented: nested trace")
+        active_trace = self
+
+        if self._untraced:
+            apply.enable(apply_with_tracing)
+            if self._symbolic:
+                apply.enable(apply_symbolic_mode)
+                self._lazy_eval_graph = G.Graph()
+        else:
+            apply.enable(apply_compiled_mode)
+            if self._graph is None:
+                self._compile()
+            self._graph.execute()
+
+        yield
+
+        escaped_tensors = tuple(self._active_tensors)
+        self._active_tensors.clear()
+
+        if self._untraced:
+            for x in escaped_tensors:
+                info = self._tinfo[x._TraceMixin__handle]
+                info.data_read = True
+                x._TraceMixin__restore()
+            if self._symbolic:
+                # eval lazy eval tensors
+                lazy_eval_tensors = tuple(self._lazy_eval_tensors)
+                if lazy_eval_tensors:
+                    readers = [
+                        G.OutputNode(x._LazyEvalTensor__varnode).outputs[0]
+                        for x in lazy_eval_tensors
+                    ]
+                    self._lazy_eval_graph.compile(*readers)
+                    self._lazy_eval_graph()
+                    for r, x in zip(readers, lazy_eval_tensors):
+                        assign_raw_tensor(x, as_raw_tensor(r.op.get_value()))
+                self._lazy_eval_graph = None
+                self._lazy_eval_tensors = None
+            self._untraced = False
+        else:
+            if self._pc != len(self._seq):
+                raise TraceMismatchError("premature end")
+            for x in escaped_tensors:
+                assign_raw_tensor(x, as_raw_tensor(x._dev_tensor()))
+            self._graph.wait()
+            self._reset_exec_env()
+            self._pc = 0
+
+        apply.disable(apply_with_tracing)
+        apply.disable(apply_symbolic_mode)
+        apply.disable(apply_compiled_mode)
+        active_trace = None
+
+    def _begin_excluded_region(self):
+        if self._untraced:
+            # conditionally reading a compiled tensor in excluded region
+            # is permitted, so we have to assume every tensor might be read
+            for x in self._active_tensors:
+                info = self._tinfo[x._TraceMixin__handle]
+                info.exported = True
+                info.data_read = True
+
+    def _compile(self):
+        graph = self._graph = G.Graph()
+        # graph.options.graph_opt_level = 0
+        need_reset_nodes = self._need_reset_nodes = []
+        # links enforce ordering of I/O nodes
+        links = ()
+        for op, ihandles, ohandles in self._seq:
+            ivars = []
+            readers = []
+            for h in ihandles:
+                info = self._tinfo[h]
+                if not hasattr(info, "varnode"):
+                    assert info.external
+                    if info.bound_data:
+                        info.varnode = graph.make_const(info.bound_data._dev_tensor())
+                    else:
+                        opnode = info.data_setter = G.InputNode(
+                            *links, device=info.device, dtype=info.dtype, graph=graph
+                        )
+                        need_reset_nodes.append(opnode)
+                        info.varnode, *links = opnode.outputs
+
+                ivars.append(info.varnode)
+            ovars = apply(op, *ivars)
+            assert len(ovars) == len(ohandles)
+            for h, v in zip(ohandles, ovars):
+                info = self._tinfo[h]
+                info.varnode = v
+
+                def add_reader(opnode):
+                    nonlocal links
+                    need_reset_nodes.append(opnode)
+                    readers.append(opnode.outputs[0])
+                    links = opnode.outputs
+
+                if info.data_read:
+                    # Shape can be obtained from data so doesn't need its own
+                    # output node. On the other hand, value is read separately
+                    # to leverage eager h2d copy
+                    info.shape_read = False
+                    opnode = info.data_reader = G.OutputNode(v, *links)
+                    add_reader(opnode)
+                if info.value_read:
+                    opnode = info.value_reader = G.ValueOutputNode(v, *links)
+                    add_reader(opnode)
+                if info.shape_read:
+                    opnode = info.shape_reader = G.AttrOutputNode(v, *links)
+                    add_reader(opnode)
+
+        graph.compile(*readers)
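Which reader `_compile` attaches is decided by what the traced function touched, via the `TraceMixin` hooks below: `.numpy()` flags `value_read`, `.shape` flags `shape_read`, and escaping tensors or excluded regions flag `data_read` (which subsumes the shape). A sketch of the effect, reusing the negate op from the tests:

```python
from megengine.core.ops import builtin as ops
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.jit import trace

@trace(symbolic=False)
def f(x):
    (y,) = apply(ops.Elemwise(mode="negate"), x)
    print(y.numpy())  # flags value_read on y's handle -> G.ValueOutputNode
    return y          # y escapes the trace -> data_read -> G.OutputNode

f(as_raw_tensor([1.0]))  # first call: record ops and read flags
f(as_raw_tensor([2.0]))  # second call: _compile, then replay the graph
```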
+
+    def _reset_exec_env(self):
+        for opnode in self._need_reset_nodes:
+            opnode.reset()
+
+    def _require_shape(self, handle):
+        info = self._tinfo[handle]
+        info.shape_read = True
+
+    def _require_value(self, handle):
+        info = self._tinfo[handle]
+        info.value_read = True
+
+    def _require_data(self, handle):
+        info = self._tinfo[handle]
+        info.data_read = True
+
+    def __call__(self, *args, **kwargs):
+        with self._setup():
+            return self.__wrapped__(*args, **kwargs)
+
+
+class CompiledTensorProxy(RawTensor):
+    """
+    Duck-typed RawTensor
+    """
+
+    def __init__(self, handle):
+        self.__handle = handle
+        self.__info = active_trace._tinfo[handle]
+        self.__shape = None
+        self.__data = None
+        self.__value = None
+
+    @property
+    def dtype(self):
+        return self.__info.varnode.dtype
+
+    @property
+    def device(self):
+        return self.__info.varnode.device
+
+    @property
+    def shape(self):
+        if self.__shape is None:
+            if self.__info.shape_read:
+                self.__shape = self.__info.shape_reader.get_value().shape
+            elif self.__info.data_read:
+                self.__shape = self._dev_tensor().shape
+            else:
+                raise TraceMismatchError("shape of this tensor is not read in trace")
+        return self.__shape
+
+    def numpy(self):
+        if self.__value is None:
+            if self.__info.value_read:
+                self.__value = self.__info.value_reader.get_value()
+            elif self.__info.data_read:
+                self.__value = self._dev_tensor().numpy()
+            else:
+                raise TraceMismatchError("value of this tensor is not read in trace")
+        return self.__value
+
+    def _dev_tensor(self):
+        if self.__data is None:
+            if not self.__info.data_read:
+                raise TraceMismatchError("raw data of this tensor is not read in trace")
+            self.__data = self.__info.data_reader.get_value()
+        return self.__data
+
+    def __del__(self):
+        # drop pending values that were produced but never consumed
+        if self.__info.shape_read and self.__shape is None:
+            self.__info.shape_reader.drop_value()
+        if self.__info.value_read and self.__value is None:
+            self.__info.value_reader.drop_value()
+        if self.__info.data_read and self.__data is None:
+            self.__info.data_reader.drop_value()
+
+
+class LazyEvalTensor(RawTensor):
+    def __init__(self, varnode):
+        self.__varnode = varnode
+
+    @property
+    def dtype(self):
+        return self.__varnode.dtype
+
+    @property
+    def device(self):
+        return self.__varnode.device
+
+    @property
+    def shape(self):
+        return self.__varnode.shape
+
+    def numpy(self):
+        raise RuntimeError("cannot read value during symbolic tracing")
+
+    def _dev_tensor(self):
+        raise RuntimeError("cannot access data during symbolic tracing")
+
+
+class TraceMixin:
+    __subclass_cache = {}
+
+    def __inject(self, handle):
+        cache = __class__.__subclass_cache
+        cls = self.__class__
+        subcls = cache.get(cls)
+        if subcls is None:
+            subcls = cache[cls] = type("Traced" + cls.__name__, (__class__, cls), {})
+        self.__class__ = subcls
+        self.__handle = handle
+        self.__cls = cls
+        return self
+
+    def __restore(self):
+        cls = self.__cls
+        del self.__handle
+        del self.__cls
+        self.__class__ = cls
+        return self
+
+    @property
+    def shape(self):
+        if not skip_tracing:
+            active_trace._require_shape(self.__handle)
+        return super().shape
+
+    def numpy(self):
+        if not skip_tracing:
+            active_trace._require_value(self.__handle)
+        return super().numpy()
+
+    def _dev_tensor(self):
+        if not skip_tracing:
+            active_trace._require_data(self.__handle)
+        return super()._dev_tensor()
+
+
+class TracedRawTensor(TraceMixin, RawTensor):
+    pass
+
+
+class TracedLazyTensor(TraceMixin, LazyEvalTensor):
+    pass
+
+
+def assign_raw_tensor(lhs, rhs):
+    handle = rhs._handle
+    rhs.__dict__.clear()
+    lhs.__dict__.clear()
+    lhs.__class__ = RawTensor
+    lhs.__init__(handle)
+
+
+# this hook turns RawTensor into LazyEvalTensor
+@apply.register()
+def apply_symbolic_mode(op: OpDef, *args: RawTensor):
+    graph = active_trace._lazy_eval_graph
+    ivars = [
+        getattr(x, "_LazyEvalTensor__varnode", None)
+        or graph.make_const(x._dev_tensor())
+        for x in args
+    ]
+    ovars = apply(op, *ivars)
+    outputs = [LazyEvalTensor(v) for v in ovars]
+    active_trace._lazy_eval_tensors.update(outputs)
+    return outputs
+
+
+apply.disable(apply_symbolic_mode)
+
+
+@apply.register()
+def apply_compiled_mode(op: OpDef, *args: RawTensor):
+    if skip_tracing:
+        args = [
+            as_raw_tensor(x._dev_tensor()) if x.__class__ is CompiledTensorProxy else x
+            for x in args
+        ]
+        return apply.super(op, *args)
+    return active_trace._apply_op(op, args)
+
+
+apply.disable(apply_compiled_mode)
+
+
+# this hook injects TraceMixin
+@apply.register()
+def apply_with_tracing(op: OpDef, *args: RawTensor):
+    outputs = apply.super(op, *args)
+    active_trace._record_op(op, args, outputs)
+    return outputs
+
+
+apply.disable(apply_with_tracing)
+
+
+# @apply.register()
+# def _(op: Const, *args: RawTensor):
+#     return active_trace._apply_const(op, args)
+
+
+class BrokenRawTensor(RawTensor):
+    def __getattribute__(self, _):
+        raise RuntimeError("broken due to misuse of tracing")
+
+    def __setattr__(self, *_):
+        raise RuntimeError("broken due to misuse of tracing")
diff --git a/imperative/python/src/common.cpp b/imperative/python/src/common.cpp
index 54a8549e..10807a2e 100644
--- a/imperative/python/src/common.cpp
+++ b/imperative/python/src/common.cpp
@@ -23,10 +23,29 @@ namespace py = pybind11;
 using namespace mgb;
 using namespace imperative;
 
+namespace {
+
+template<typename XTensorND>
+auto def_TensorND(py::object parent, const char* name) {
+    return py::class_<XTensorND>(parent, name)
+        .def_property_readonly("shape", py::overload_cast<>(&XTensorND::shape, py::const_))
+        .def_property_readonly("dtype", py::overload_cast<>(&XTensorND::dtype, py::const_))
+        .def_property_readonly("comp_node", py::overload_cast<>(&XTensorND::comp_node, py::const_))
+        .def("copy_from", &XTensorND::template copy_from<DeviceTensorND>)
+        .def("copy_from", &XTensorND::template copy_from<HostTensorND>)
+        .def("copy_from_fixlayout", py::overload_cast<const DeviceTensorND&>(
+            &XTensorND::template copy_from_fixlayout<DeviceTensorND>))
+        .def("copy_from_fixlayout", py::overload_cast<const HostTensorND&>(
+            &XTensorND::template copy_from_fixlayout<HostTensorND>));
+}
+
+} // namespace
+
 void init_common(py::module m) {
-    py::class_<CompNode>(m, "CompNode")
+    auto&& PyCompNode = py::class_<CompNode>(m, "CompNode")
         .def(py::init())
         .def(py::init(py::overload_cast<const std::string&>(&CompNode::load)))
+        .def("create_event", &CompNode::create_event, py::arg("flags") = 0ul)
         .def("__str__", &CompNode::to_string_logical)
         .def_static("_sync_all", &CompNode::sync_all)
         .def(py::self == py::self)
@@ -40,19 +59,30 @@ void init_common(py::module m) {
             return CompNode::load(cn);
         }));
 
+    py::class_<CompNode::Event, std::shared_ptr<CompNode::Event>>(PyCompNode, "Event")
+        .def("record", &CompNode::Event::record)
+        .def("wait", &CompNode::Event::host_wait);
+
     py::implicitly_convertible<std::string, CompNode>();
 
-    py::class_<DeviceTensorND>(m, "DeviceTensorND")
-        .def(py::init())
-        .def_property_readonly("shape", py::overload_cast<>(&DeviceTensorND::shape, py::const_))
-        .def_property_readonly("dtype", py::overload_cast<>(&DeviceTensorND::dtype, py::const_))
-        .def_property_readonly("comp_node", py::overload_cast<>(&DeviceTensorND::comp_node, py::const_))
+    def_TensorND<DeviceTensorND>(m, "DeviceTensorND")
         .def("numpy", [](const DeviceTensorND& self) {
             HostTensorND hv;
             hv.copy_from(self).sync();
             return py::handle(npy::ndarray_from_tensor(hv, npy::ShareType::TRY_SHARE));
         });
 
+    def_TensorND<HostTensorND>(m, "HostTensorND")
+        .def(py::init([](py::array data, CompNode cn, DType dtype) {
+            if (!cn.valid()) {
+                throw py::type_error("device must not be None");
+            }
+            return npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
+        }))
+        .def("numpy", [](const HostTensorND& self) {
+            return py::reinterpret_steal<py::object>(npy::ndarray_from_tensor(self, npy::ShareType::TRY_SHARE));
+        });
+
     py::class_<OperatorNodeConfig>(m, "OperatorNodeConfig")
         .def(py::init())
         .def_property("name",
diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp
index 28fdccc5..caabf439 100644
--- a/imperative/python/src/graph_rt.cpp
+++ b/imperative/python/src/graph_rt.cpp
@@ -12,6 +12,7 @@
 
 #include "./graph_rt.h"
 #include "megbrain/imperative/opr_utility.h"
+#include "megbrain/opr/io.h"
 #include "megbrain/opr/basic_arith.h"
 #include "megbrain/imperative.h"
 #include "./helper.h"
@@ -29,29 +30,44 @@ auto def_rendezvous(py::object m, const char* name) {
     return py::class_<Rendezvous<T>, std::shared_ptr<Rendezvous<T>>>(m, name)
         .def(py::init([](){return std::make_shared<Rendezvous<T>>();}))
         .def("set", [](Rendezvous<T>& r, T v) {r.set(std::move(v));})
         .def("get", [](Rendezvous<T>& r) {return r.get();}, py::call_guard<py::gil_scoped_release>())
+        .def("drop", &Rendezvous<T>::drop)
         .def("reset", &Rendezvous<T>::reset);
 }
 
 using TensorAttr = LogicalTensorDesc;
+using HostNDWithEvent = std::pair<HostTensorND, std::shared_ptr<CompNode::Event>>;
 
 void init_graph_rt(py::module m) {
     def_rendezvous<DeviceTensorND>(m, "DeviceTensorNDRendezvous");
+    def_rendezvous<HostNDWithEvent>(m, "HostTensorNDRendezvous");
+    def_rendezvous<TensorAttr>(m, "TensorAttrRendezvous");
 
     py::class_<cg::VarNode, GraphNodePtr<cg::VarNode>>(m, "VarNode")
         .def_property_readonly("owner", [](cg::VarNode* v) {return v->owner_opr();})
         .def_property_readonly("graph", [](cg::VarNode* v) {return v->owner_graph();})
+        .def_property_readonly("name", py::overload_cast<>(&VarNode::name, py::const_))
         .def_property_readonly("dtype", [](cg::VarNode* v) {return v->dtype();})
-        .def_property_readonly("comp_node", [](cg::VarNode* v) {return v->comp_node();});
+        .def_property_readonly("comp_node", [](cg::VarNode* v) {return v->comp_node();})
+        .def_property_readonly("shape", [](cg::VarNode* v) -> const TensorShape* {
+            auto&& mgr = v->owner_graph()->static_infer_manager();
+            auto&& type = mgr.get_infer_type(v);
+            using InferType = cg::static_infer::InferType;
+            if (!(type.shape & (InferType::CONST | InferType::RT_STATIC))) {
+                return nullptr;
+            }
+            return mgr.infer_shape_fallible(v);
+        });
 
     py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(m, "OperatorNode")
         .def_property_readonly("graph", [](cg::OperatorNodeBase* opr) {return opr->owner_graph();})
+        .def_property_readonly("name", py::overload_cast<>(&cg::OperatorNodeBase::name, py::const_))
         .def_property_readonly("inputs", [](cg::OperatorNodeBase* opr) {
             return to_tuple(opr->input());
         })
         .def_property_readonly("outputs", [](cg::OperatorNodeBase* opr) {
-            return to_tuple(opr->output());
+            return to_tuple(opr->usable_output());
         });
 
     py::class_<cg::AsyncExecutable>(m, "AsyncExecutable")
@@ -117,7 +133,7 @@ void init_graph_rt(py::module m) {
     common.def("invoke_op", [](const OpDef& def, const std::vector<cg::VarNode*> inputs, cg::ComputingGraph* graph) {
         cg::VarNodeArray vinputs(inputs.begin(), inputs.end());
         auto opr = OpDef::apply_on_var_node(def, vinputs);
-        auto outputs = opr->output();
+        auto outputs = opr->usable_output();
         return to_tuple(outputs);
     },
     py::arg(), py::arg(), py::arg("graph") = py::none());
@@ -125,6 +141,7 @@ void init_graph_rt(py::module m) {
     auto input_callback = [](auto callback,
                              const CompNode& comp_node,
                              const DType& dtype,
+                             const TensorShape& shape,
                              const std::vector<cg::VarNode*>& inputs,
                              cg::ComputingGraph* graph) {
         if (!graph) {
@@ -135,7 +152,7 @@ void init_graph_rt(py::module m) {
             sinputs.emplace_back(i);
         }
         static_assert(!std::is_reference<decltype(callback)>::value);
-        auto soutputs = opr::InputCallback::make(*graph, std::move(callback), comp_node, dtype, sinputs);
+        auto soutputs = opr::InputCallback::make(*graph, std::move(callback), comp_node, dtype, shape, sinputs);
         std::vector<VarNode*> outputs;
         outputs.reserve(soutputs.size());
         for (auto i : soutputs) {
@@ -144,26 +161,40 @@ void init_graph_rt(py::module m) {
         return outputs;
     };
 
+    m.def("make_shared", [](cg::ComputingGraph* graph, const DeviceTensorND& data) {
+        return opr::SharedDeviceTensor::make(*graph, std::make_shared<DeviceTensorND>(data)).node();
+    });
+
+    m.def("make_const", [](cg::ComputingGraph* graph, py::array data, CompNode cn, DType dtype) {
+        if (!cn.valid()) {
+            throw py::type_error("device must not be None");
+        }
+        auto hv = npy::np2tensor(data.ptr(), npy::Meth::borrow(cn), dtype);
+        return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
+    });
+
     m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback,
                                              const CompNode& comp_node,
                                              const DType& dtype,
+                                             const TensorShape& shape,
                                              const std::vector<cg::VarNode*>& inputs,
                                              cg::ComputingGraph* graph) {
-            return input_callback([f=std::move(callback)](){py::gil_scoped_acquire _; return f();}, comp_node, dtype, inputs, graph);
+            return input_callback([f=std::move(callback)](){py::gil_scoped_acquire _; return f();}, comp_node, dtype, shape, inputs, graph);
         },
-        py::arg(), py::arg(), py::arg(), py::arg() = py::tuple(), py::arg("graph") = py::none());
+        py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none());
 
     m.def("input_callback", [input_callback](std::shared_ptr<Rendezvous<DeviceTensorND>> p,
                                              const CompNode& comp_node,
                                              const DType& dtype,
+                                             const TensorShape& shape,
                                              const std::vector<cg::VarNode*>& inputs,
                                              cg::ComputingGraph* graph) {
             auto f = [p]() -> DeviceTensorND {
                 return p->get();
             };
-            return input_callback(std::move(f), comp_node, dtype, inputs, graph);
+            return input_callback(std::move(f), comp_node, dtype, shape, inputs, graph);
         },
-        py::arg(), py::arg(), py::arg(), py::arg() = py::tuple(), py::arg("graph") = py::none());
+        py::arg(), py::arg(), py::arg(), py::arg() = py::none(), py::arg() = py::tuple(), py::arg("graph") = py::none());
 
     auto output_callback = [](auto callback, const std::vector<cg::VarNode*>& inputs, bool borrow = false) {
         SymbolVarArray sinputs;
@@ -193,6 +224,17 @@ void init_graph_rt(py::module m) {
         return output_callback(std::move(f), std::move(inputs));
     });
 
+    m.def("value_output_callback", [output_callback](std::shared_ptr<Rendezvous<HostNDWithEvent>> p, std::vector<cg::VarNode*> inputs) {
+        auto f = [p](DeviceTensorND dv) {
+            HostNDWithEvent hv_with_event;
+            hv_with_event.first.copy_from(dv);
+            hv_with_event.second = dv.comp_node().create_event();
+            hv_with_event.second->record();
+            p->set(std::move(hv_with_event));
+        };
+        return output_callback(std::move(f), std::move(inputs), true);
+    });
+
     m.def("attr_output_callback", [output_callback](std::shared_ptr<Rendezvous<TensorAttr>> p, std::vector<cg::VarNode*> inputs) {
         auto f = [p](DeviceTensorND dv) {
             p->set(TensorAttr{TensorLayout{dv.shape(), dv.dtype()}, dv.comp_node()});
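`value_output_callback` pairs the host copy with a `CompNode::Event` recorded right after the copy is queued, so `ValueOutputNode.get_value()` only host-waits on that one transfer instead of syncing the whole graph; `drop` lets a reader discard a value it never consumed (see `CompiledTensorProxy.__del__` above). A sketch of direct use, with the `"xpux"` device string assumed as before:

```python
import numpy as np

from megengine.core.ops import builtin as ops
from megengine.core.tensor import megbrain_graph as G
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor

g = G.Graph()
inp = G.InputNode(device="xpux", dtype="float32", graph=g)
(y,) = apply(ops.Elemwise(mode="negate"), inp.outputs[0])
val = G.ValueOutputNode(y)
g.compile(val.outputs[0])

inp.set_value(as_raw_tensor(np.array([3.0], dtype="float32"))._dev_tensor())
g.execute()             # returns immediately; work runs on the graph executor
print(val.get_value())  # rendezvous get() + event host-wait + numpy() -> [-3.]
g.wait()
```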
diff --git a/imperative/python/src/graph_rt.h b/imperative/python/src/graph_rt.h
index 251e370a..85d84bd4 100644
--- a/imperative/python/src/graph_rt.h
+++ b/imperative/python/src/graph_rt.h
@@ -39,6 +39,7 @@ template<typename R>
 class Rendezvous {
     std::mutex m_lock;
     int m_read_ahead = 0;
+    bool m_drop_next = false;
     std::promise<R> m_promise;
 public:
     Rendezvous() = default;
@@ -47,6 +48,7 @@ public:
     Rendezvous& operator=(const Rendezvous& rhs) = delete;
     Rendezvous& operator=(Rendezvous&& rhs) {
         MGB_LOCK_GUARD(m_lock);
+        m_drop_next = rhs.m_drop_next;
         m_read_ahead = rhs.m_read_ahead;
         m_promise = std::move(rhs.m_promise);
         return *this;
@@ -67,12 +69,28 @@ public:
         return f.get();
     }
 
+    void drop() {
+        MGB_LOCK_GUARD(m_lock);
+        mgb_assert(m_read_ahead <= 0);
+        mgb_assert(m_read_ahead >= -1);
+        if (m_read_ahead == -1) {
+            m_promise = {};
+        } else {
+            m_drop_next = true;
+        }
+        ++m_read_ahead;
+    }
+
     template<typename T>
     void set(T&& value) {
         MGB_LOCK_GUARD(m_lock);
         mgb_assert(m_read_ahead >= 0);
         mgb_assert(m_read_ahead <= 1);
-        m_promise.set_value(std::forward<T>(value));
+        if (m_drop_next) {
+            m_drop_next = false;
+        } else {
+            m_promise.set_value(std::forward<T>(value));
+        }
         if (m_read_ahead == 1) {
             m_promise = {};
         }
@@ -83,6 +101,7 @@ public:
         MGB_LOCK_GUARD(m_lock);
         m_promise = {};
         m_read_ahead = 0;
+        m_drop_next = false;
     }
 };
diff --git a/imperative/python/src/helper.h b/imperative/python/src/helper.h
index 294826ef..82cd083f 100644
--- a/imperative/python/src/helper.h
+++ b/imperative/python/src/helper.h
@@ -280,9 +280,12 @@ namespace detail {
 public:
     bool load(handle src, bool convert) {
         auto obj = reinterpret_steal<object>(src);
-        if (!isinstance<tuple>(obj)) {
+        if (!convert && !isinstance<tuple>(obj)) {
             return false;
         }
+        if (obj.is_none()) {
+            return true;
+        }
         value.ndim = len(obj);
         mgb_assert(value.ndim <= mgb::TensorShape::MAX_NDIM);
         size_t i = 0;
diff --git a/imperative/python/src/imperative_rt.cpp b/imperative/python/src/imperative_rt.cpp
index 90c28761..6055b411 100644
--- a/imperative/python/src/imperative_rt.cpp
+++ b/imperative/python/src/imperative_rt.cpp
@@ -63,6 +63,7 @@ void init_imperative_rt(py::module m) {
                 return self.put(npy::np2tensor(data.ptr(), npy::Meth::copy_into(&ret), dtype));
             }
         }, py::arg(), py::arg("dtype") = py::none(), py::arg("device") = py::none())
+        .def("put", py::overload_cast<const DeviceTensorND&>(&Interpreter::Channel::put))
         .def("delete", [](Interpreter::Channel& self, Interpreter::Handle handle) {
             return self.del(handle);
         })
diff --git a/imperative/python/src/pyext17.h b/imperative/python/src/pyext17.h
index a28a0f6c..c28a474d 100644
--- a/imperative/python/src/pyext17.h
+++ b/imperative/python/src/pyext17.h
@@ -24,6 +24,12 @@
 constexpr bool has_fastcall = true;
 #else
 constexpr bool has_fastcall = false;
 #endif
 
+#ifdef _Py_TPFLAGS_HAVE_VECTORCALL
+constexpr bool has_vectorcall = true;
+#else
+constexpr bool has_vectorcall = false;
+#endif
+
 template<typename... Args>
 struct invocable_with {
     template<typename T>
@@ -55,6 +61,9 @@ private:
 public:
     PyObject_HEAD
     std::aligned_storage_t<sizeof(T), alignof(T)> storage;
+    #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
+    vectorcallfunc vectorcall_slot;
+    #endif
 
     inline T* inst() {
         return reinterpret_cast<T*>(&storage);
@@ -155,6 +164,51 @@ private:
 
     // polyfills
 
+    struct tp_vectorcall {
+        static constexpr bool valid = HAS_MEMBER(T, tp_vectorcall);
+        static constexpr bool haskw = [](){if constexpr (valid)
+                                               if constexpr (std::is_invocable_v<decltype(&T::tp_vectorcall),
+                                                                                 T, PyObject*const*, size_t, PyObject*>)
+                                                   return true;
+                                           return false;}();
+
+        template<typename = void>
+        static PyObject* impl(PyObject* self, PyObject*const* args, size_t nargsf, PyObject *kwnames) {
+            auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
+            if constexpr (haskw) {
+                CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf, kwnames));
+            } else {
+                if (kwnames && PyTuple_GET_SIZE(kwnames)) {
+                    PyErr_SetString(PyExc_TypeError, "expect no keyword argument");
+                    return nullptr;
+                }
+                CVT_RET_PYOBJ(inst->tp_vectorcall(args, nargsf));
+            }
+        }
+
+        static constexpr Py_ssize_t offset = []() {if constexpr (valid) return offsetof(wrap_t, vectorcall_slot);
+                                                   else return 0;}();
+    };
+
+    struct tp_call {
+        static constexpr bool provided = HAS_MEMBER(T, tp_call);
+        static constexpr bool static_form = invocable_with<T, PyObject*, PyObject*, PyObject*>{}(
+            [](auto&& t, auto... args) -> decltype(std::decay_t<decltype(t)>::tp_call(args...)) {});
+        static constexpr bool valid = provided || tp_vectorcall::valid;
+
+        template<typename = void>
+        static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
+            auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
+            CVT_RET_PYOBJ(inst->tp_call(args, kwargs));
+        }
+
+        static constexpr ternaryfunc value = []() {if constexpr (static_form) return T::tp_call;
+                                                   else if constexpr (provided) return impl<>;
+                                                   #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
+                                                   else if constexpr (valid) return PyVectorcall_Call;
+                                                   #endif
+                                                   else return nullptr;}();
+    };
+
     struct tp_new {
         static constexpr bool provided = HAS_MEMBER(T, tp_new);
         static constexpr bool varkw = std::is_constructible_v<T, PyObject*, PyObject*>;
@@ -163,11 +217,14 @@ private:
         template<typename = void>
         static PyObject* impl(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
             auto* self = type->tp_alloc(type, 0);
-            auto* ptr = reinterpret_cast<wrap_t*>(self)->inst();
+            auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
+            if constexpr (has_vectorcall && tp_vectorcall::valid) {
+                reinterpret_cast<wrap_t*>(self)->vectorcall_slot = &tp_vectorcall::template impl<>;
+            }
             if constexpr (varkw) {
-                new(ptr) T(args, kwargs);
+                new(inst) T(args, kwargs);
             } else {
-                new(ptr) T();
+                new(inst) T();
             }
             return self;
         }
@@ -190,22 +247,6 @@ private:
         else return impl<>;}();
     };
 
-    struct tp_call {
-        static constexpr bool valid = HAS_MEMBER(T, tp_call);
-        static constexpr bool static_form = invocable_with<T, PyObject*, PyObject*, PyObject*>{}(
-            [](auto&& t, auto... args) -> decltype(std::decay_t<decltype(t)>::tp_call(args...)) {});
-
-        template<typename = void>
-        static PyObject* impl(PyObject* self, PyObject* args, PyObject* kwargs) {
-            auto* inst = reinterpret_cast<wrap_t*>(self)->inst();
-            CVT_RET_PYOBJ(inst->tp_call(args, kwargs));
-        }
-
-        static constexpr ternaryfunc value = []() {if constexpr (static_form) return T::tp_call;
-                                                   else if constexpr (valid) return impl<>;
-                                                   else return nullptr;}();
-    };
-
 public:
     class TypeBuilder {
         std::vector<PyMethodDef> m_methods;
@@ -228,9 +269,17 @@ public:
             m_type.tp_name = T::tp_name;
         }
         m_type.tp_dealloc = tp_dealloc::value;
+        #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
+        m_type.tp_vectorcall_offset = tp_vectorcall::offset;
+        #endif
         m_type.tp_call = tp_call::value;
         m_type.tp_basicsize = sizeof(wrap_t);
         m_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
+        #ifdef _Py_TPFLAGS_HAVE_VECTORCALL
+        if constexpr (tp_vectorcall::valid) {
+            m_type.tp_flags |= _Py_TPFLAGS_HAVE_VECTORCALL;
+        }
+        #endif
         m_type.tp_new = tp_new::value;
     }
diff --git a/imperative/python/test/unit/test_tracing.py b/imperative/python/test/unit/test_tracing.py
new file mode 100644
index 00000000..83b61fec
--- /dev/null
+++ b/imperative/python/test/unit/test_tracing.py
@@ -0,0 +1,65 @@
+import numpy as np
+
+from megengine.core.ops import builtin as ops
+from megengine.core.tensor.core import apply
+from megengine.core.tensor.raw_tensor import as_raw_tensor
+from megengine.jit import exclude_from_trace, trace
+
+
+def test_trace():
+    for symbolic in [False, True]:
+
+        @trace(symbolic=symbolic)
+        def f(x):
+            op = ops.Elemwise(mode="negate")
+            (y,) = apply(op, x)
+            return y
+
+        x = as_raw_tensor([1]).numpy()
+        y = f.__wrapped__(as_raw_tensor(x)).numpy()
+
+        for i in range(3):
+            np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
+
+
+def test_exclude_from_trace():
+    for symbolic in [False, True]:
+
+        @trace(symbolic=symbolic)
+        def f(x):
+            neg = ops.Elemwise(mode="negate")
+            (x,) = apply(neg, x)
+            with exclude_from_trace():
+                if i % 2:
+                    (x,) = apply(neg, x)
+            (x,) = apply(neg, x)
+            return x
+
+        x = as_raw_tensor([1]).numpy()
+
+        for i in range(3):
+            y = f.__wrapped__(as_raw_tensor(x)).numpy()
+            np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
+
+
+def test_print_in_trace():
+    for symbolic in [False]:  # cannot read value in symbolic mode
+
+        @trace(symbolic=symbolic)
+        def f(x):
+            nonlocal buf
+            neg = ops.Elemwise(mode="negate")
+            (x,) = apply(neg, x)
+            buf = x.numpy()
+            (x,) = apply(neg, x)
+            return x
+
+        buf = None
+        x = as_raw_tensor([1]).numpy()
+
+        for i in range(3):
+            y = f.__wrapped__(as_raw_tensor(x)).numpy()
+            z = buf
+            buf = None
+            np.testing.assert_equal(f(as_raw_tensor(x)).numpy(), y)
+            np.testing.assert_equal(z, buf)
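The tests lean on `f.__wrapped__` to compute a reference result with plain eager execution, since `trace.__init__` stores the original function under that attribute. The same trick is handy when debugging a traced function (sketch):

```python
import numpy as np

from megengine.core.ops import builtin as ops
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.jit import trace

@trace(symbolic=True)
def f(x):
    (y,) = apply(ops.Elemwise(mode="negate"), x)
    return y

x = as_raw_tensor([1.0])
expected = f.__wrapped__(x).numpy()  # plain eager run, bypasses tracing
np.testing.assert_equal(f(x).numpy(), expected)
```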
diff --git a/imperative/src/impl/interpreter_impl.cpp b/imperative/src/impl/interpreter_impl.cpp
index 91520341..de521f22 100644
--- a/imperative/src/impl/interpreter_impl.cpp
+++ b/imperative/src/impl/interpreter_impl.cpp
@@ -37,6 +37,15 @@ void* ChannelImpl::put(const HostTensorND& value) {
     return info;
 }
 
+void* ChannelImpl::put(const DeviceTensorND& data) {
+    auto info = alloc();
+    info->desc.layout = data.layout();
+    info->desc.comp_node = data.comp_node();
+    info->ptr = Tensor::make(data);
+    m_valid_handle.insert(info);
+    return info;
+}
+
 void ChannelImpl::del(void* handle) {
     mgb_assert(m_valid_handle.erase(handle), "invalid handle: %p", handle);
     m_worker.add_task(Del{reinterpret_cast<TensorInfo*>(handle)});
diff --git a/imperative/src/impl/interpreter_impl.h b/imperative/src/impl/interpreter_impl.h
index 71c55515..4676d27a 100644
--- a/imperative/src/impl/interpreter_impl.h
+++ b/imperative/src/impl/interpreter_impl.h
@@ -55,6 +55,7 @@ struct ChannelImpl : Interpreter::Channel {
     ~ChannelImpl() override;
 
     Handle put(const HostTensorND& value) override;
+    Handle put(const DeviceTensorND& value) override;
 
     void del(Handle) override;
diff --git a/imperative/src/impl/opr_utility.cpp b/imperative/src/impl/opr_utility.cpp
index 8a88b048..d5f9a1c0 100644
--- a/imperative/src/impl/opr_utility.cpp
+++ b/imperative/src/impl/opr_utility.cpp
@@ -31,9 +31,10 @@ MGB_DYN_TYPE_OBJ_FINAL_IMPL(InputCallback);
 
 InputCallback::InputCallback(cg::ComputingGraph& graph, callback_t callback,
                              const VarNodeArray& inputs,
+                             const TensorShape& output_shape,
                              const OperatorNodeConfig& config)
         : Super(&graph, config, "input_callback", inputs),
-          m_callback(callback) {
+          m_output_shape(output_shape), m_callback(callback) {
     for (VarNode* i : inputs) {
         add_input({i});
     }
@@ -48,7 +49,8 @@ InputCallback::InputCallback(cg::ComputingGraph& graph, callback_t callback,
 
 SymbolVarArray InputCallback::make(cg::ComputingGraph& graph,
                                    callback_t callback, CompNode comp_node,
-                                   DType dtype, const SymbolVarArray& inputs) {
+                                   DType dtype, const TensorShape& shape,
+                                   const SymbolVarArray& inputs) {
     mgb_assert(comp_node.valid());
     mgb_assert(dtype.valid());
     OperatorNodeConfig config;
@@ -56,11 +58,22 @@ SymbolVarArray InputCallback::make(cg::ComputingGraph& graph,
     config.output_dtype(dtype);
     auto vinputs = to_var_node_array(inputs);
     auto opr = graph.insert_opr(
-            std::make_unique<InputCallback>(graph, callback, vinputs, config));
+            std::make_unique<InputCallback>(graph, callback, vinputs, shape, config));
     return to_symbol_var_array(opr->output());
 }
 
-void InputCallback::init_output_static_infer_desc() {}
+void InputCallback::init_output_static_infer_desc() {
+    if (m_output_shape.ndim) {
+        using namespace cg::static_infer;
+        auto&& mgr = owner_graph()->static_infer_manager();
+        auto infer_shape = [this](TensorShape& dest, const InpVal&) {
+            dest = m_output_shape;
+            return true;
+        };
+        mgr.register_shape_infer(output(0),
+                                 {SourceType::CONSTANT, {}, infer_shape});
+    }
+}
 
 cg::OperatorNodeBase::NodeProp* InputCallback::do_make_node_prop() const {
     NodeProp* prop = Super::do_make_node_prop();
@@ -73,9 +86,23 @@ cg::OperatorNodeBase::NodeProp* InputCallback::do_make_node_prop() const {
 
 void InputCallback::scn_do_execute() {
     auto dev_tensor = m_callback();
+    if (m_output_shape.ndim) {
+        mgb_assert(dev_tensor.shape().eq_shape(m_output_shape));
+    }
     output(0)->reset_dev_tensor_from_tensor(dev_tensor);
 }
 
+cg::OperatorNodeBase* InputCallback::shallow_copy(
+        const serialization::OprShallowCopyContext &ctx,
+        const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
+        const OperatorNodeConfig &config) {
+    auto &&opr = opr_.cast_final_safe<InputCallback>();
+    auto* graph = ctx.owner_graph(opr, inputs);
+    return graph->insert_opr(std::make_unique<InputCallback>(*graph, opr.m_callback, inputs, opr.m_output_shape, config));
+}
+
+MGB_REG_OPR_SHALLOW_COPY(InputCallback, InputCallback::shallow_copy);
+
 /* ================ OutputCallback ================== */
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(OutputCallback);
@@ -122,6 +149,17 @@ void OutputCallback::scn_do_execute() {
     m_param.callback(input(0)->dev_tensor());
 }
 
+cg::OperatorNodeBase* OutputCallback::shallow_copy(
+        const serialization::OprShallowCopyContext &ctx,
+        const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
+        const OperatorNodeConfig &config) {
+    auto &&opr = opr_.cast_final_safe<OutputCallback>();
+    auto* graph = ctx.owner_graph(opr, inputs);
+    return graph->insert_opr(std::make_unique<OutputCallback>(opr.m_param, inputs, config));
+}
+
+MGB_REG_OPR_SHALLOW_COPY(OutputCallback, OutputCallback::shallow_copy);
+
 /* ================ NopCallback ================== */
 
 MGB_DYN_TYPE_OBJ_FINAL_IMPL(NopCallback);
diff --git a/imperative/src/include/megbrain/imperative/interpreter.h b/imperative/src/include/megbrain/imperative/interpreter.h
index b93fc445..016d2055 100644
--- a/imperative/src/include/megbrain/imperative/interpreter.h
+++ b/imperative/src/include/megbrain/imperative/interpreter.h
@@ -22,6 +22,7 @@ struct Interpreter {
         virtual ~Channel() = default;
 
         virtual Handle put(const HostTensorND& value) = 0;
+        virtual Handle put(const DeviceTensorND& value) = 0;
 
         virtual void del(Handle) = 0;
diff --git a/imperative/src/include/megbrain/imperative/opr_utility.h b/imperative/src/include/megbrain/imperative/opr_utility.h
index 711484b6..9054c217 100644
--- a/imperative/src/include/megbrain/imperative/opr_utility.h
+++ b/imperative/src/include/megbrain/imperative/opr_utility.h
@@ -17,6 +17,7 @@
 #include "megbrain/opr/internal/param_tag_defs.h"
 #include "megbrain/opr/internal/megdnn_opr_wrapper.h"
 #include "megbrain/opr/param_defs.h"
+#include "megbrain/serialization/sereg.h"
 
 #include "megdnn/oprs/utils.h"
 
@@ -33,17 +34,24 @@ public:
     InputCallback(cg::ComputingGraph& graph,
                   callback_t callback,
                   const VarNodeArray& inputs,
+                  const TensorShape& output_shape,
                   const OperatorNodeConfig &config);
     static SymbolVarArray make(cg::ComputingGraph& graph,
                                callback_t callback,
                                CompNode comp_node,
                                DType dtype,
+                               const TensorShape& shape,
                                const SymbolVarArray& inputs = {});
+    static cg::OperatorNodeBase* shallow_copy(
+            const serialization::OprShallowCopyContext &ctx,
+            const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
+            const OperatorNodeConfig &config);
 protected:
     void scn_do_execute() override;
     void init_output_static_infer_desc() override;
     NodeProp* do_make_node_prop() const override;
 private:
+    TensorShape m_output_shape;
     callback_t m_callback;
 };
 
@@ -63,6 +71,10 @@ public:
                            SymbolVar input) {
         return make(std::move(param), SymbolVarArray{input});
     }
+    static cg::OperatorNodeBase* shallow_copy(
+            const serialization::OprShallowCopyContext &ctx,
+            const cg::OperatorNodeBase &opr_, const VarNodeArray &inputs,
+            const OperatorNodeConfig &config);
 protected:
     void scn_do_execute() override;
     void init_output_static_infer_desc() override;
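One knob introduced above but not exercised by the tests: with `capture_as_const=True`, `_record_op` binds external inputs as `bound_data` and `_compile` turns them into graph constants, so a replay checks tensor identity instead of feeding fresh data. A hedged sketch; the `"mul"` mode string is assumed by analogy with the tests' `"negate"`:

```python
from megengine.core.ops import builtin as ops
from megengine.core.tensor.core import apply
from megengine.core.tensor.raw_tensor import as_raw_tensor
from megengine.jit import trace

w = as_raw_tensor([2.0])

@trace(symbolic=False, capture_as_const=True)
def scale(x):
    (y,) = apply(ops.Elemwise(mode="mul"), x, w)  # w is captured as a const
    return y

scale(as_raw_tensor([3.0]))  # records; w._dev_tensor() becomes a graph const
scale(as_raw_tensor([4.0]))  # ok: same w object, identity check passes
# Substituting a different tensor for w would raise
# TraceMismatchError("const capture violated: got a different tensor this time")
```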