Browse Source

feat(mge/imperative): add name, make_h2d, dump_graph to graph runtime

GitOrigin-RevId: b8681a31a8
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
76dbaa2741
2 changed files with 52 additions and 2 deletions
  1. +24
    -0
      imperative/python/megengine/core/tensor/megbrain_graph.py
  2. +28
    -2
      imperative/python/src/graph_rt.cpp

+ 24
- 0
imperative/python/megengine/core/tensor/megbrain_graph.py View File

@@ -78,6 +78,14 @@ class Graph(_imperative_rt.ComputingGraph):
opnode = InputNode(*args, device=device, dtype=dtype, shape=shape, graph=self)
return opnode.outputs[0]

def make_h2d(self, *, dtype, device):
device = as_device(device).to_c()
return self._wrap(_imperative_rt.make_h2d(self, device, dtype))


def dump(*args):
return _imperative_rt.dump_graph([i._node for i in args])


class VarNode(TensorBase):
def __init__(self, node: _imperative_rt.VarNode):
@@ -93,6 +101,14 @@ class VarNode(TensorBase):
return self.graph._wrap(self._node.owner)

@property
def name(self):
return self._node.name

@name.setter
def name(self, name):
self._node.name = name

@property
def dtype(self):
return self._node.dtype

@@ -119,6 +135,14 @@ class OpNode:
return self._node.graph

@property
def name(self):
return self._node.name

@name.setter
def name(self, name):
self._node.name = name

@property
def inputs(self):
return tuple(map(self.graph._wrap, self._node.inputs))



+ 28
- 2
imperative/python/src/graph_rt.cpp View File

@@ -11,6 +11,7 @@

#include "./graph_rt.h"

#include "megbrain/serialization/serializer.h"
#include "megbrain/imperative/opr_utility.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/basic_arith.h"
@@ -47,7 +48,8 @@ void init_graph_rt(py::module m) {
py::class_<cg::VarNode, GraphNodePtr<cg::VarNode>>(m, "VarNode")
.def_property_readonly("owner", [](cg::VarNode* v) {return v->owner_opr();})
.def_property_readonly("graph", [](cg::VarNode* v) {return v->owner_graph();})
.def_property_readonly("name", py::overload_cast<>(&VarNode::name, py::const_))
.def_property("name", py::overload_cast<>(&VarNode::name, py::const_),
py::overload_cast<std::string>(&VarNode::name))
.def_property_readonly("dtype", [](cg::VarNode* v) {return v->dtype();})
.def_property_readonly("comp_node", [](cg::VarNode* v) {return v->comp_node();})
.def_property_readonly("shape", [](cg::VarNode* v) -> const TensorShape* {
@@ -75,7 +77,8 @@ void init_graph_rt(py::module m) {

py::class_<cg::OperatorNodeBase, GraphNodePtr<cg::OperatorNodeBase>>(m, "OperatorNode")
.def_property_readonly("graph", [](cg::OperatorNodeBase* opr) {return opr->owner_graph();})
.def_property_readonly("name", py::overload_cast<>(&cg::OperatorNodeBase::name, py::const_))
.def_property("name", py::overload_cast<>(&cg::OperatorNodeBase::name, py::const_),
py::overload_cast<std::string>(&cg::OperatorNodeBase::name))
.def_property_readonly("inputs", [](cg::OperatorNodeBase* opr) {
return to_tuple(opr->input());
})
@@ -99,6 +102,15 @@ void init_graph_rt(py::module m) {
})
.def_property_readonly("options", py::overload_cast<>(&cg::ComputingGraph::options));

m.def("dump_graph", [](const std::vector<VarNode*>& dest_vars) {
using namespace mgb::serialization;
std::vector<uint8_t> buf;
auto dumper = GraphDumper::make(OutputFile::make_vector_proxy(&buf));
SymbolVarArray symvars(dest_vars.begin(), dest_vars.end());
dumper->dump(symvars);
return py::bytes(reinterpret_cast<const char*>(&buf[0]), buf.size());
});

#define CURRENT_CLASS cg::ComputingGraph::Options

auto PyComputingGraphOptions = py::class_<cg::ComputingGraph::Options>(PyComputingGraph, "Options")
@@ -198,6 +210,20 @@ void init_graph_rt(py::module m) {
return opr::ImmutableTensor::make(*graph, hv, OperatorNodeConfig(cn)).node();
});

m.def("make_h2d", [](cg::ComputingGraph& graph, CompNode cn, DType dtype, std::optional<std::string> name) {
if (!cn.valid()) {
throw py::type_error("device must be valid");
}
if (!dtype.valid()) {
throw py::type_error("dtype must be valid");
}
OperatorNodeConfig config;
if (name) {
config.name(*name);
}
return opr::Host2DeviceCopy::make(graph, std::make_shared<HostTensorND>(cn, dtype), config).node();
}, py::arg(), py::arg(), py::arg(), py::arg() = py::none());

m.def("input_callback", [input_callback](std::function<DeviceTensorND(void)> callback,
const CompNode& comp_node,
const DType& dtype,


Loading…
Cancel
Save