Browse Source

refactor(mge/imperative): add BackwardGraph.interpret

GitOrigin-RevId: bb3a59380e
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
4101d5bca0
4 changed files with 55 additions and 29 deletions
  1. +8
    -0
      imperative/python/megengine/core/tensor/megbrain_graph.py
  2. +12
    -1
      imperative/python/src/ops.cpp
  3. +4
    -28
      imperative/src/impl/ops/backward_graph.cpp
  4. +31
    -0
      imperative/src/include/megbrain/imperative/ops/backward_graph.h

+ 8
- 0
imperative/python/megengine/core/tensor/megbrain_graph.py View File

@@ -12,6 +12,7 @@ import weakref
from concurrent.futures import Future, ThreadPoolExecutor

from .. import _imperative_rt
from .._imperative_rt.ops import BackwardGraph
from .._wrap import device as as_device
from ..ops.builtin import OpDef
from .core import OpBase, TensorBase, apply
@@ -131,6 +132,13 @@ def _(op: OpDef, *args: VarNode):
return _wrap(outputs)


@apply.register()
def _(op: BackwardGraph, *args: VarNode):
    """Apply a ``BackwardGraph`` op symbolically on graph ``VarNode``s.

    The backward graph is evaluated by ``op.interpret``: every sub-op is
    dispatched back through the generic ``apply``, and constants are
    lifted into the graph owning the first input via ``graph.make_const``.
    """
    assert args
    graph = args[0].graph

    # Named helper instead of an inline lambda; forwards one sub-op
    # application back into the generic dispatcher.
    def run_op(sub_op, sub_inputs):
        return apply(sub_op, *sub_inputs)

    return op.interpret(run_op, graph.make_const, args)


def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None):
outputs = _imperative_rt.input_callback(
callback, as_device(device).to_c(), dtype, shape, _unwrap(args), graph=graph


+ 12
- 1
imperative/python/src/ops.cpp View File

@@ -40,6 +40,18 @@ void init_ops(py::module m) {
attr.param.insert(attr.param.end(), s.begin(), s.end());
});

// Bind BackwardGraph with an ``interpret`` method. Python supplies two
// callables:
//   pyf(op, inputs) -> outputs  — applies one sub-op to its inputs, and
//   pyc(dev_tensor) -> object   — lifts a constant tensor into the caller's
//                                 representation,
// while the C++ side drives the traversal of the internal graph.
py::class_<BackwardGraph, std::shared_ptr<BackwardGraph>, OpDef>(m, "BackwardGraph")
.def("interpret", [](BackwardGraph& self, py::object pyf, py::object pyc,
const mgb::SmallVector<py::object>& inputs) {
// Adapt the Python applier to the C++ functor signature; op.copy()
// presumably hands Python an independent OpDef handle — confirm.
auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) {
return py::cast<mgb::SmallVector<py::object>>(pyf(op.copy(), inputs));
};
// Adapt the Python constant-lifter: it receives the device tensor.
auto c = [pyc](const TensorPtr& tensor) {
return pyc(tensor->dev_tensor());
};
// Instantiate the generic interpreter with py::object as value type.
return self.graph().interpret<py::object>(f, c, inputs);
});

py::class_<GetVarShape, std::shared_ptr<GetVarShape>, OpDef>(m, "GetVarShape")
.def(py::init());

@@ -98,7 +110,6 @@ void init_ops(py::module m) {
.def(py::init<>())
.def_readwrite("offsets", &ParamPackConcat::offsets);

// NOTE(review): this bare BackwardGraph registration appears superseded by
// the fuller binding (the one exposing ``interpret``) earlier in init_ops;
// presumably this duplicate is slated for removal — confirm only one
// registration of "BackwardGraph" remains, as pybind11 rejects duplicates.
py::class_<BackwardGraph, std::shared_ptr<BackwardGraph>, OpDef>(m, "BackwardGraph");
py::class_<CondTake, std::shared_ptr<CondTake>, OpDef>(m, "CondTake")
.def(py::init<>());



+ 4
- 28
imperative/src/impl/ops/backward_graph.cpp View File

@@ -18,34 +18,10 @@ namespace imperative {
SmallVector<TensorPtr>
BackwardGraph::InternalGraph::apply(
        const SmallVector<TensorPtr>& inputs) const {
    // Evaluate the recorded backward graph on concrete (physical) tensors.
    //
    // The span previously carried a full inline copy of the interpretation
    // loop above an unconditional second return, leaving the old body as
    // unreachable dead code; only the delegation below is live, so the dead
    // copy is removed. All graph-walking logic lives in the generic
    // interpret<T> template in the header.
    return interpret<TensorPtr>(
            // Applies one sub-op to already-materialized tensors.
            &OpDef::apply_on_physical_tensor,
            // Constants are already TensorPtrs: pass them through unchanged.
            [](const TensorPtr& x) { return x; },
            inputs);
}

SmallVector<LogicalTensorDesc>


+ 31
- 0
imperative/src/include/megbrain/imperative/ops/backward_graph.h View File

@@ -40,6 +40,37 @@ public:

SmallVector<LogicalTensorDesc>
infer_attrs(const SmallVector<LogicalTensorDesc>& inputs) const;

/**
 * \brief Generic interpreter over the internal graph.
 *
 * Walks the recorded expressions in order, evaluating each with \p f and
 * lifting constants with \p c, then collects the values of the graph's
 * output nodes.
 *
 * \tparam T value type a node evaluates to (e.g. TensorPtr, py::object)
 * \param f functor (OpDef&, SmallVector<T>) -> SmallVector<T> applying one op
 * \param c functor (TensorPtr) -> T lifting a constant into a T
 * \param inputs values for the graph's input nodes, positionally aligned
 * \return values of the graph's output nodes, in declaration order
 */
template <typename T, typename F, typename C>
SmallVector<T> interpret(F&& f, C&& c, const SmallVector<T>& inputs) const {
    // Maps node id -> current value. Filled from inputs and constants,
    // then extended as each expression is evaluated.
    ThinHashMap<size_t, T> node2tensor;
    auto&& input_nodes = this->inputs;
    mgb_assert(inputs.size() == input_nodes.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        node2tensor[input_nodes[i]] = inputs[i];
    }
    for (auto&& kv : constants) {
        node2tensor[kv.first] = c(kv.second);
    }
    for (auto&& expr : exprs) {
        // Renamed from `inputs`/`outputs`: the originals shadowed the
        // function parameter and the member `outputs` used below.
        SmallVector<T> expr_inputs;
        for (auto&& in : std::get<1>(expr)) {
            expr_inputs.push_back(node2tensor.at(in));
        }
        auto&& expr_outputs = f(*std::get<0>(expr), std::move(expr_inputs));
        auto&& output_nodes = std::get<2>(expr);
        mgb_assert(expr_outputs.size() == output_nodes.size());
        for (size_t j = 0; j < expr_outputs.size(); ++j) {
            node2tensor[output_nodes[j]] = std::move(expr_outputs[j]);
        }
    }
    // Gather the member `outputs` node ids into the result vector.
    SmallVector<T> ret;
    for (auto&& node : outputs) {
        ret.push_back(node2tensor.at(node));
    }
    return ret;
}
};

const InternalGraph& graph() const {


Loading…
Cancel
Save