@@ -11,6 +11,7 @@
#include "./graph_rt.h"
#include "megbrain/serialization/serializer.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/basic_arith.h"

@@ -47,7 +48,8 @@ void init_graph_rt(py::module m) {
    py::class_<cg::VarNode, GraphNodePtr<cg::VarNode>>(m, "VarNode")
        .def_property_readonly("owner", [](cg::VarNode* v) {return v->owner_opr();})
        .def_property_readonly("graph", [](cg::VarNode* v) {return v->owner_graph();})
        .def_property_readonly("name", py::overload_cast<>(&VarNode::name, py::const_))
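        // "name" becomes a read/write property: the const overload is the getter,
        // the std::string overload is the setter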
        .def_property("name", py::overload_cast<>(&VarNode::name, py::const_),
                      py::overload_cast<std::string>(&VarNode::name))
        .def_property_readonly("dtype", [](cg::VarNode* v) {return v->dtype();})
        .def_property_readonly("comp_node", [](cg::VarNode* v) {return v->comp_node();})
        .def_property_readonly("shape", [](cg::VarNode* v) -> const TensorShape* {

@@ -75,7 +77,8 @@ void init_graph_rt(py::module m) {
    py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(m, "OperatorNode")
        .def_property_readonly("graph", [](cg::OperatorNodeBase* opr) {return opr->owner_graph();})
        .def_property_readonly("name", py::overload_cast<>(&cg::OperatorNodeBase::name, py::const_))
        .def_property("name", py::overload_cast<>(&cg::OperatorNodeBase::name, py::const_),
                      py::overload_cast<std::string>(&cg::OperatorNodeBase::name))
        .def_property_readonly("inputs", [](cg::OperatorNodeBase* opr) {
            return to_tuple(opr->input());
        })

@@ -99,6 +102,15 @@ void init_graph_rt(py::module m) {
        })
        .def_property_readonly("options", py::overload_cast<>(&cg::ComputingGraph::options));

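    // serialize the subgraph reachable from dest_vars into an in-memory buffer
    // and hand it back to Python as a bytes object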
m.def("dump_graph", [](const std::vector<VarNode*>& dest_vars) { |
|
|
|
using namespace mgb::serialization; |
|
|
|
std::vector<uint8_t> buf; |
|
|
|
auto dumper = GraphDumper::make(OutputFile::make_vector_proxy(&buf)); |
|
|
|
SymbolVarArray symvars(dest_vars.begin(), dest_vars.end()); |
|
|
|
dumper->dump(symvars); |
|
|
|
return py::bytes(reinterpret_cast<const char*>(&buf[0]), buf.size()); |
|
|
|
}); |
|
|
|
|
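// presumably consumed by the member-binding macros applied to Options below (not shown here)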
#define CURRENT_CLASS cg::ComputingGraph::Options

    auto PyComputingGraphOptions = py::class_<cg::ComputingGraph::Options>(PyComputingGraph, "Options")

@@ -198,6 +210,20 @@ void init_graph_rt(py::module m) {
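        // hv is wrapped as an ImmutableTensor constant placed on device cn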
        return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
    });

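    // create a Host2DeviceCopy input node backed by an empty HostTensorND on the
    // given device/dtype; device and dtype must be valid, the node name is optional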
m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, std::optional<std::string> name) { |
|
|
|
if (!cn.valid()) { |
|
|
|
throw py::type_error("device must be valid"); |
|
|
|
} |
|
|
|
if (!dtype.valid()) { |
|
|
|
throw py::type_error("dtype must be valid"); |
|
|
|
} |
|
|
|
OperatorNodeConfig config; |
|
|
|
if (name) { |
|
|
|
config.name(*name); |
|
|
|
} |
|
|
|
return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, dtype), config).node(); |
|
|
|
}, py::arg(), py::arg(), py::arg(), py::arg() = py::none()); |
|
|
|
|
|
|
|
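    // bind a graph input whose value comes from a callback returning a
    // DeviceTensorND on comp_node with the given dtype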
m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback, |
|
|
|
const CompNode& comp_node, |
|
|
|
const DType& dtype, |
|
|
|