diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py index cc4a4f4e..e8a1c31f 100644 --- a/imperative/python/megengine/core/tensor/megbrain_graph.py +++ b/imperative/python/megengine/core/tensor/megbrain_graph.py @@ -78,6 +78,14 @@ class Graph(_imperative_rt.ComputingGraph): opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self) return opnode.outputs[0] + def make_h2d(self, *, dtype, device): + device = as_device(device).to_c() + return self._wrap(_imperative_rt.make_h2d(self, device, dtype)) + + +def dump(*args): + return _imperative_rt.dump_graph([i._node for i in args]) + class VarNode(TensorBase): def __init__(self, node: _imperative_rt.VarNode): @@ -93,6 +101,14 @@ class VarNode(TensorBase): return self.graph._wrap(self._node.owner) @property + def name(self): + return self._node.name + + @name.setter + def name(self, name): + self._node.name = name + + @property def dtype(self): return self._node.dtype @@ -119,6 +135,14 @@ class OpNode: return self._node.graph @property + def name(self): + return self._node.name + + @name.setter + def name(self, name): + self._node.name = name + + @property def inputs(self): return tuple(map(self.graph._wrap, self._node.inputs)) diff --git a/imperative/python/src/graph_rt.cpp b/imperative/python/src/graph_rt.cpp index 27899e02..f827466f 100644 --- a/imperative/python/src/graph_rt.cpp +++ b/imperative/python/src/graph_rt.cpp @@ -11,6 +11,7 @@ #include "./graph_rt.h" +#include "megbrain/serialization/serializer.h" #include "megbrain/imperative/opr_utility.h" #include "megbrain/opr/io.h" #include "megbrain/opr/basic_arith.h" @@ -47,7 +48,8 @@ void init_graph_rt(py::module m) { py::class_>(m, "VarNode") .def_property_readonly("owner", [](cg::VarNode* v) {return v->owner_opr();}) .def_property_readonly("graph", [](cg::VarNode* v) {return v->owner_graph();}) - .def_property_readonly("name", py::overload_cast<>(&VarNode::name, py::const_)) + .def_property("name", py::overload_cast<>(&VarNode::name, py::const_), + py::overload_cast(&VarNode::name)) .def_property_readonly("dtype", [](cg::VarNode* v) {return v->dtype();}) .def_property_readonly("comp_node", [](cg::VarNode* v) {return v->comp_node();}) .def_property_readonly("shape", [](cg::VarNode* v) -> const TensorShape* { @@ -75,7 +77,8 @@ void init_graph_rt(py::module m) { py::class_>(m, "OperatorNode") .def_property_readonly("graph", [](cg::OperatorNodeBase* opr) {return opr->owner_graph();}) - .def_property_readonly("name", py::overload_cast<>(&cg::OperatorNodeBase::name, py::const_)) + .def_property("name", py::overload_cast<>(&cg::OperatorNodeBase::name, py::const_), + py::overload_cast(&cg::OperatorNodeBase::name)) .def_property_readonly("inputs", [](cg::OperatorNodeBase* opr) { return to_tuple(opr->input()); }) @@ -99,6 +102,15 @@ void init_graph_rt(py::module m) { }) .def_property_readonly("options", py::overload_cast<>(&cg::ComputingGraph::options)); + m.def("dump_graph", [](const std::vector& dest_vars) { + using namespace mgb::serialization; + std::vector buf; + auto dumper = GraphDumper::make(OutputFile::make_vector_proxy(&buf)); + SymbolVarArray symvars(dest_vars.begin(), dest_vars.end()); + dumper->dump(symvars); + return py::bytes(reinterpret_cast(&buf[0]), buf.size()); + }); + #define CURRENT_CLASS cg::ComputingGraph::Options auto PyComputingGraphOptions = py::class_(PyComputingGraph, "Options") @@ -198,6 +210,20 @@ void init_graph_rt(py::module m) { return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node(); }); + m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, std::optional name) { + if (!cn.valid()) { + throw py::type_error("device must be valid"); + } + if (!dtype.valid()) { + throw py::type_error("dtype must be valid"); + } + OperatorNodeConfig config; + if (name) { + config.name(*name); + } + return opr::Host2DeviceCopy::make(graph, std::make_shared(cn, dtype), config).node(); + }, py::arg(), py::arg(), py::arg(), py::arg() = py::none()); + m.def("input_callback", [input_callback](std::function callback, const CompNode& comp_node, const DType& dtype,