Browse Source

refactor(ops): remove BackwardGraph op

GitOrigin-RevId: eda20e5760
release-1.5
Megvii Engine Team 4 years ago
parent
commit
241b35a697
21 changed files with 221 additions and 422 deletions
  1. +0
    -16
      imperative/python/megengine/core/tensor/megbrain_graph.py
  2. +8
    -17
      imperative/python/megengine/jit/tracing.py
  3. +3
    -3
      imperative/python/src/grad.cpp
  4. +9
    -21
      imperative/python/src/imperative_rt.cpp
  5. +0
    -37
      imperative/python/src/ops.cpp
  6. +0
    -3
      imperative/python/src/tensor.cpp
  7. +37
    -15
      imperative/python/src/tensor.h
  8. +1
    -1
      imperative/python/src/trace.cpp
  9. +5
    -7
      imperative/src/impl/backward_graph_opt.cpp
  10. +1
    -2
      imperative/src/impl/interpreter/interpreter_impl.cpp
  11. +64
    -0
      imperative/src/impl/op_def.cpp
  12. +0
    -141
      imperative/src/impl/ops/backward_graph.cpp
  13. +3
    -34
      imperative/src/impl/proxy_graph.cpp
  14. +1
    -1
      imperative/src/impl/proxy_graph.h
  15. +1
    -1
      imperative/src/impl/proxy_graph/common.h
  16. +1
    -24
      imperative/src/impl/proxy_graph_detail.cpp
  17. +4
    -4
      imperative/src/include/megbrain/imperative/backward_graph_opt.h
  18. +53
    -3
      imperative/src/include/megbrain/imperative/op_def.h
  19. +0
    -86
      imperative/src/include/megbrain/imperative/ops/backward_graph.h
  20. +1
    -1
      imperative/src/include/megbrain/imperative/ops/utility.h
  21. +29
    -5
      imperative/src/test/backward_graph.cpp

+ 0
- 16
imperative/python/megengine/core/tensor/megbrain_graph.py View File

@@ -18,7 +18,6 @@ import numpy as np
from .. import _imperative_rt
from .._imperative_rt import GraphOptimizeOptions
from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode
from .._imperative_rt.ops import BackwardGraph
from .._wrap import device as as_device
from ..ops.builtin import OpDef
from .core import TensorBase
@@ -481,21 +480,6 @@ def apply_normal_varnode(op: OpDef, *args: VarNode):
return _wrap(outputs)


def apply_backward_varnode(op: BackwardGraph, *args: VarNode):
assert args
graph = args[0].graph
outputs = op.interpret(
op,
lambda op, args: apply_normal_varnode(op, *args),
graph._make_const_for_backward,
args,
)
return outputs


set_cpp_apply_backward_varnode(apply_backward_varnode)


def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None):
outputs = _imperative_rt.input_callback(
callback, as_device(device).to_c(), dtype, shape, _unwrap(args), graph=graph


+ 8
- 17
imperative/python/megengine/jit/tracing.py View File

@@ -32,7 +32,7 @@ from ..core._imperative_rt.ops import (
)
from ..core._trace_option import set_symbolic_shape
from ..core._wrap import device as as_device
from ..core.ops.builtin import BackwardGraph, BatchNorm, OpDef
from ..core.ops.builtin import BatchNorm, OpDef
from ..core.ops.special import Const
from ..core.tensor import megbrain_graph as G
from ..core.tensor.utils import setscalar
@@ -587,10 +587,7 @@ class trace:

ivars.append(info.varnode)

if isinstance(op, BackwardGraph):
ovars = G.apply_backward_varnode(op, *ivars)
else:
ovars = G.apply_normal_varnode(op, *ivars)
ovars = G.apply_normal_varnode(op, *ivars)

if require_links and len(ovars) > 0:
io_links = (ovars[0],)
@@ -805,14 +802,11 @@ class trace:
name=info.name,
)
ivars.append(h2v[h])
if isinstance(op, BackwardGraph):
ovars = G.apply_backward_varnode(op, *ivars)
else:
if isinstance(op, BatchNorm):
assert (
op.fwd_mode == BatchNorm.FwdMode.INFERENCE
), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?"
ovars = G.apply_normal_varnode(op, *ivars)
if isinstance(op, BatchNorm):
assert (
op.fwd_mode == BatchNorm.FwdMode.INFERENCE
), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?"
ovars = G.apply_normal_varnode(op, *ivars)

AutoNaming.record_opnode(ovars[0].op)

@@ -1088,10 +1082,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor):
ivars[0] = opnode.outputs[0]
active_trace._lazy_eval_links = (ivars[0],)

if isinstance(op, BackwardGraph):
ovars = G.apply_backward_varnode(op, *ivars)
else:
ovars = G.apply_normal_varnode(op, *ivars)
ovars = G.apply_normal_varnode(op, *ivars)
outputs = [RawTensor(o) for o in ovars]

if require_links:


+ 3
- 3
imperative/python/src/grad.cpp View File

@@ -75,9 +75,9 @@ std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph(
input_requires_grad[i] = python::input_requires_grad(ctx, i);
}
std::shared_ptr<OptimizedBackwardGraphResult> ret;
auto bg = proxy_graph_detail::make_backward_graph(
auto bg = OpDef::make_backward_graph(
*ctx.op, inputs, input_requires_grad, output_has_grad);
if (bg.backward) {
if (!bg.backward.empty()) {
ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
}
backward_graph_cache.emplace(key, ret);
@@ -112,7 +112,7 @@ struct BackwardGraphWithClosure {
size_t count = std::count_if(save_for_backward.begin(),
save_for_backward.end(),
ranges::identity{});
if (backward_graph->precomp) {
if (!backward_graph->precomp.empty()) {
auto&& irng = ranges::span(ctx.args, ctx.nargs);
auto&& orng = views::transform(outputs, [](auto&& i){return i.get();});
auto precomp = apply(backward_graph->precomp, views::concat(irng, orng));


+ 9
- 21
imperative/python/src/imperative_rt.cpp View File

@@ -30,26 +30,14 @@ using namespace imperative;
using namespace interpreter;


namespace {

std::optional<std::tuple<std::shared_ptr<OpDef>, std::vector<bool>, std::vector<bool>>>
make_backward_graph(
const OpDef& opdef, std::vector<LogicalTensorDesc> inputs,
std::vector<bool> input_requires_grad,
std::vector<bool> output_has_grad) {
auto res = OpDef::make_backward_graph(opdef,
SmallVector<LogicalTensorDesc>(inputs.begin(), inputs.end()),
SmallVector<bool>(input_requires_grad.begin(), input_requires_grad.end()),
SmallVector<bool>(output_has_grad.begin(), output_has_grad.end()));
if (res.backward) {
return std::optional<std::tuple<std::shared_ptr<OpDef>, std::vector<bool>, std::vector<bool>>>{
std::in_place, res.backward, res.save_for_backward, res.input_has_grad};
} else {
return {};
}
}
} // namespace

void init_imperative_rt(py::module m) {
m.def("make_backward_graph", &make_backward_graph);
auto make_backward_graph = [](
const OpDef& def,
const SmallVector<LogicalTensorDesc>& inputs,
const SmallVector<bool>& input_requires_grad,
const SmallVector<bool>& output_has_grad){
auto result = OpDef::make_backward_graph(def, inputs, input_requires_grad, output_has_grad);
return std::make_tuple("backward_graph", result.save_for_backward, result.input_has_grad);
};
m.def("make_backward_graph", make_backward_graph);
}

+ 0
- 37
imperative/python/src/ops.cpp View File

@@ -367,42 +367,6 @@ void _init_py_op_def(py::module m) {
}

/*********** begin of hand-write opdefs **************/

PyOpDefBegin(BackwardGraph) // {{
// };
PyOpDefEnd(BackwardGraph)

void _init_py_backward_graph(py::module m) {
using py_op = PyOp(BackwardGraph);
auto& py_type = PyOpType(BackwardGraph);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.BackwardGraph";
py_type.tp_basicsize = sizeof(PyOp(BackwardGraph));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "BackwardGraph";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
mgb_assert(PyType_Ready(&py_type) >= 0);
// FIXME: rewrite interpret function in cpython instead wrap directly by pybind11::cppfunction
auto interpret = py::cpp_function(
[](OpDef& self, py::object pyf, py::object pyc,
const mgb::SmallVector<py::object>& inputs) {
auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) {
return py::cast<mgb::SmallVector<py::object>>(pyf(op.shared_from_this(), inputs));
};
auto c = [pyc](const TensorPtr& tensor) {
return pyc(tensor->dev_tensor());
};
return self.cast_final_safe<BackwardGraph>().graph().interpret<py::object>(f, c, inputs);
});
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "interpret", interpret.release().ptr()) >= 0);
PyType_Modified(&py_type);
m.add_object("BackwardGraph", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(BackwardGraph::typeinfo(), &py_type).second);
}

struct PyOpBase : PyOpDef {
static PyTypeObject py_type;

@@ -496,7 +460,6 @@ FOR_EACH_BIT_COMBINED_ENUM_PARAM(BIT_COMBINED_ENUM_CASTER_IMPL)

void init_ops(py::module m) {
_init_py_op_def(m);
_init_py_backward_graph(m);
_init_py_op_base(m);
INIT_ALL_OP(m)



+ 0
- 3
imperative/python/src/tensor.cpp View File

@@ -156,9 +156,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
ctx.args = &tensors[0];
ctx.nargs = nargs;
ctx.pytype = pytype;
if (ctx.op->same_type<BackwardGraph>()) {
ctx.backward = true;
}

if (py::isinstance<PySymbolVar>(py::handle(args[0]))){
SmallVector<cg::VarNode*> vinputs(nargs);


+ 37
- 15
imperative/python/src/tensor.h View File

@@ -248,31 +248,53 @@ apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) {
return apply(ctx);
}

template <typename T>
auto apply(std::shared_ptr<OpDef> op, T&& tensors)
-> std::enable_if_t<std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>,
apply_result_t> {
inline auto apply(std::shared_ptr<OpDef> op, Tensor*const* args, size_t nargs) {
ApplyContext ctx;
ctx.op = std::move(op);
ctx.nargs = tensors.size();
Tensor* args[ctx.nargs];
ctx.nargs = nargs;
ctx.args = args;
for (size_t i = 0; i < ctx.nargs; ++i) {
args[i] = resolve_arrow(tensors[i]);
for (size_t i = 0; i < nargs; ++i) {
ctx.flags |= args[i]->m_flags;
}
return apply(ctx);
}

inline auto apply(std::shared_ptr<OpDef> op, Tensor*const* args, size_t nargs) {
ApplyContext ctx;
ctx.op = std::move(op);
ctx.nargs = nargs;
ctx.args = args;
template <typename T>
auto apply(std::shared_ptr<OpDef> op, T&& tensors)
-> std::enable_if_t<std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>,
apply_result_t> {
size_t nargs = tensors.size();
Tensor* args[nargs];
for (size_t i = 0; i < nargs; ++i) {
ctx.flags |= args[i]->m_flags;
args[i] = resolve_arrow(tensors[i]);
}
return apply(ctx);
return apply(op, args, nargs);
}

inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) {
SmallVector<std::shared_ptr<Tensor>> inputs;
for (size_t i = 0; i < nargs; ++i) {
inputs.push_back(args[i]->shared_from_this());
}
auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<std::shared_ptr<Tensor>> inputs) {
return apply(op, inputs);
};
auto const_functor = [](imperative::TensorPtr value) {
return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor()));
};
return graph.apply(inputs, apply_functor, const_functor);
}

template <typename T>
auto apply(Subgraph graph, T&& tensors)
-> std::enable_if_t<std::is_same_v<decltype(tensors[0]), Tensor*>,
apply_result_t> {
size_t nargs = tensors.size();
Tensor* args[nargs];
for (size_t i = 0; i < nargs; ++i) {
args[i] = resolve_arrow(tensors[i]);
}
return apply(graph, args, nargs);
}

void init_tensor(pybind11::module);


+ 1
- 1
imperative/python/src/trace.cpp View File

@@ -22,7 +22,7 @@ apply_result_t apply_trace(ApplyContext& ctx) {
apply_result_t outputs;

if (ctx.backward) {
// call megbrain_graph.py apply(BackwardGraph, *args)
// reach here when compiled=True
auto args = py::tuple(ctx.nargs + 1);
args[0] = py::cast(ctx.op);
for (size_t i = 0; i < ctx.nargs; i++) {


+ 5
- 7
imperative/src/impl/backward_graph_opt.cpp View File

@@ -18,24 +18,22 @@ using namespace imperative;

OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src)
: input_has_grad(src.input_has_grad) {
if (!src.backward->same_type<BackwardGraph>()) {
if (src.backward.exprs.size() <= 1) {
// backward graph only contains a single op
backward = src.backward;
save_for_backward = src.save_for_backward;
return;
}
save_for_backward.resize(src.save_for_backward.size(), false);
precomp.reset(new BackwardGraph);
backward.reset(new BackwardGraph);

auto&& graph = src.backward->cast_final_safe<BackwardGraph>().graph();
auto&& graph = src.backward;
auto&& mask = src.save_for_backward;
size_t input_size = src.input_has_grad.size();
size_t output_size = (mask.size() - input_size) / 2;
mgb_assert(input_size + output_size * 2 == mask.size());

auto& fgraph = precomp->cast_final<BackwardGraph>().graph();
auto& bgraph = backward->cast_final<BackwardGraph>().graph();
auto& fgraph = precomp;
auto& bgraph = backward;

// optimization: move ops (e.g. GetVarShape) to forward to
// reduce memory footprint
@@ -113,6 +111,6 @@ OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphRe
}

if (!fgraph.outputs.size()) {
precomp.reset();
precomp = {};
}
}

+ 1
- 2
imperative/src/impl/interpreter/interpreter_impl.cpp View File

@@ -911,8 +911,7 @@ auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle {
op_type == RemoteSend::typeinfo() ||
op_type == CollectiveComm::typeinfo() ||
op_type == opr::InputCallback::typeinfo() ||
op_type == opr::OutputCallback::typeinfo() ||
op_type == BackwardGraph::typeinfo()) {
op_type == opr::OutputCallback::typeinfo()) {
return m_commands.end();
}
} else if constexpr (std::is_same_v<T, GetValue>) {


+ 64
- 0
imperative/src/impl/op_def.cpp View File

@@ -10,6 +10,9 @@
*/

#include "megbrain/imperative/op_def.h"

#include <sstream>

#include "megbrain/imperative/ops/opr_attr.h"

#include "./op_trait.h"
@@ -117,6 +120,67 @@ const std::string OpDef::make_name() const {
return m_scope + "." + trait()->make_name(*this);
}

std::string Subgraph::repr() const {
std::ostringstream buf;
buf << "(";
for (size_t i = 0; i < inputs.size(); ++i) {
if (i > 0) buf << ", ";
buf << "%" << inputs[i];
}
buf << ") => {\n";
auto fmt_const = [](size_t i, const TensorPtr& t) {
if (t->shape().ndim == 1 && t->shape()[0] == 1) {
auto&& v = t->get_value();
if (v.dtype() == dtype::Float32{}) {
return std::to_string(*v.ptr<dt_float32>());
} else if (v.dtype() == dtype::Int32{}) {
return std::to_string(*v.ptr<int32_t>());
}
}
return std::string("%c") + std::to_string(i);
};
std::unordered_map<size_t, std::string> const_reps;
for (auto&& [i, t] : constants) {
const_reps.emplace(i, fmt_const(i, t));
}
for (auto& [op, ins, outs] : exprs) {
buf << " ";
if (outs.size()) {
for (size_t i = 0; i < outs.size(); ++i) {
if (i > 0) buf << ", ";
buf << "%" << outs[i];
}
buf << " = ";
}
if (auto* p = op->try_cast_final<OprAttr>()) {
buf << p->type;
} else {
buf << op->dyn_typeinfo()->name;
}
for (size_t i : ins) {
buf << " ";
auto&& it = const_reps.find(i);
if (it != const_reps.end()) {
buf << it->second;
} else {
buf << "%" << i;
}
}
buf << "\n";
}
buf << " ";
if (outputs.size()) {
for (size_t i = 0; i < outputs.size(); ++i) {
if (i > 0) buf << ", ";
buf << "%" << outputs[i];
}
} else {
buf << "()";
}
buf << "\n}\n";
return buf.str();
}

} // namespace imperative
} // namespace mgb



+ 0
- 141
imperative/src/impl/ops/backward_graph.cpp View File

@@ -19,147 +19,6 @@
namespace mgb {
namespace imperative {

SmallVector<TensorPtr>
BackwardGraph::InternalGraph::apply(
const SmallVector<TensorPtr>& inputs) const {
return interpret<TensorPtr>(
&OpDef::apply_on_physical_tensor,
[](const TensorPtr& x) {return x;},
inputs);
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> BackwardGraph::InternalGraph::infer_attrs(
const SmallVector<LogicalTensorDesc>& inputs) const {
using TensorAttr = LogicalTensorDesc;
ThinHashMap<size_t, TensorAttr> node2attr;
auto&& input_nodes = this->inputs;
auto&& output_nodes = this->outputs;
mgb_assert(inputs.size() == input_nodes.size());
for (size_t i = 0; i < inputs.size(); ++ i) {
node2attr[input_nodes[i]] = inputs[i];
}
for (auto &&i : constants) {
auto* value = i.second->try_get_value();
mgb_assert(value);
node2attr[i.first] = TensorAttr{
i.second->layout(), i.second->comp_node(),
value->proxy_to_default_cpu()};
}
bool validated = true;
for (size_t i = 0; i < exprs.size(); ++ i) {
auto&& [expr_op, expr_inps, expr_oups] = exprs[i];
SmallVector<TensorAttr> expr_input_descs;
for (auto &&inp : expr_inps) {
expr_input_descs.push_back(node2attr.at(inp));
}

auto [expr_output_descs, expr_validated] = OpDef::infer_output_attrs_fallible(
*expr_op, expr_input_descs);
validated = validated && expr_validated;

mgb_assert(expr_output_descs.size() == expr_oups.size());
for (size_t i = 0; i < expr_output_descs.size(); ++ i) {
node2attr[expr_oups[i]] = expr_output_descs[i];
}
}

SmallVector<TensorAttr> ret;
for (auto &&i : output_nodes) {
ret.push_back(node2attr.at(i));
}
return {ret, validated};
}

std::string BackwardGraph::InternalGraph::repr() {
std::ostringstream buf;
buf << "(";
for (size_t i = 0; i < inputs.size(); ++i) {
if (i > 0) buf << ", ";
buf << "%" << inputs[i];
}
buf << ") => {\n";
auto fmt_const = [](size_t i, TensorPtr& t) {
if (t->shape().ndim == 1 && t->shape()[0] == 1) {
auto&& v = t->get_value();
if (v.dtype() == dtype::Float32{}) {
return std::to_string(*v.ptr<dt_float32>());
} else if (v.dtype() == dtype::Int32{}) {
return std::to_string(*v.ptr<int32_t>());
}
}
return std::string("%c") + std::to_string(i);
};
std::unordered_map<size_t, std::string> const_reps;
for (auto&& [i, t] : constants) {
const_reps.emplace(i, fmt_const(i, t));
}
for (auto& [op, ins, outs] : exprs) {
buf << " ";
if (outs.size()) {
for (size_t i = 0; i < outs.size(); ++i) {
if (i > 0) buf << ", ";
buf << "%" << outs[i];
}
buf << " = ";
}
if (auto* p = op->try_cast_final<OprAttr>()) {
buf << p->type;
} else {
buf << op->dyn_typeinfo()->name;
}
for (size_t i : ins) {
buf << " ";
auto&& it = const_reps.find(i);
if (it != const_reps.end()) {
buf << it->second;
} else {
buf << "%" << i;
}
}
buf << "\n";
}
buf << " ";
if (outputs.size()) {
for (size_t i = 0; i < outputs.size(); ++i) {
if (i > 0) buf << ", ";
buf << "%" << outputs[i];
}
} else {
buf << "()";
}
buf << "\n}\n";
return buf.str();
}

MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardGraph);

namespace {
SmallVector<TensorPtr> backward_impl(
const OpDef& backward_graph,
const SmallVector<TensorPtr>& tensors) {
return backward_graph.cast_final_safe<BackwardGraph>()
.graph().apply(tensors);
}

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_tensor_attrs(
const OpDef& backward_graph,
const SmallVector<LogicalTensorDesc> inputs) {
return backward_graph.cast_final_safe<BackwardGraph>()
.graph().infer_attrs(inputs);
}

std::vector<std::pair<const char*, std::string>> props(
const OpDef& backward_graph) {
return {};
}

OP_TRAIT_REG(BackwardGraph, BackwardGraph)
.apply_on_physical_tensor(backward_impl)
.infer_output_attrs_fallible(infer_tensor_attrs)
.props(props)
.fallback();
} // anonymous namespace

} // namespace imperative
} // namespace mgb



+ 3
- 34
imperative/src/impl/proxy_graph.cpp View File

@@ -669,8 +669,7 @@ ProxyGraph::make_backward_graph(
auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo());

BackwardGraphResult result;
auto&& backward = BackwardGraph::make();
auto&& igraph = backward->cast_final_safe<BackwardGraph>().graph();
auto&& igraph = result.backward;

size_t nr_backward_graph_inputs = 0;
auto gen_expr = [this, &var2idx, &igraph, &push, &fwd,
@@ -682,7 +681,7 @@ ProxyGraph::make_backward_graph(
++ nr_backward_graph_inputs;
push(op->output(0));
} else {
std::vector<size_t> inputs, outputs;
SmallVector<size_t> inputs, outputs;
for (auto &&i : op->input()) {
if (i->owner_opr() == fwd) {
if (var2idx.find(i) == var2idx.end()) {
@@ -695,7 +694,7 @@ ProxyGraph::make_backward_graph(
for (auto &&i : op->usable_output()) {
outputs.push_back(push(i));
}
igraph.exprs.emplace_back(OpDef::make_from_op_node(op), inputs, outputs);
igraph.exprs.push_back({OpDef::make_from_op_node(op), inputs, outputs});
}
};

@@ -770,36 +769,6 @@ ProxyGraph::make_backward_graph(
write_inputs(outputs);
write_inputs(output_grads);
mgb_assert(igraph.inputs.size() == nr_backward_graph_inputs);

auto treat_as_single = [](auto&& igraph) {
if (igraph.exprs.size() != 1)
return false;
auto&& expr = igraph.exprs[0];
auto&& expr_inputs = std::get<1>(expr);
if (expr_inputs.size() != igraph.inputs.size()) {
return false;
}
for (size_t i = 0; i < expr_inputs.size(); ++ i) {
if (igraph.inputs[i] != expr_inputs[i]) {
return false;
}
}
auto&& expr_outputs = std::get<2>(expr);
if (expr_outputs.size() != igraph.outputs.size()) {
return false;
}
for (size_t i = 0; i < expr_outputs.size(); ++ i) {
if (igraph.outputs[i] != expr_outputs[i]) {
return false;
}
}
return true;
};
if (treat_as_single(igraph)) {
result.backward = std::get<0>(igraph.exprs[0]);
} else {
result.backward = backward;
}
return result;
}



+ 1
- 1
imperative/src/impl/proxy_graph.h View File

@@ -65,7 +65,7 @@ private:
class InputPlaceholder;
struct ProxyGraphInst;
struct GradGraph;
struct CurOprGuard;
class CurOprGuard;

void reset();



+ 1
- 1
imperative/src/impl/proxy_graph/common.h View File

@@ -15,7 +15,7 @@ namespace mgb::imperative::proxy_graph {
// e.g. friend class mgb::imperative::proxy_graph::ProxyGraph
struct ProxyGraph {
struct InputPlaceholder;
struct MiniGraph;
class MiniGraph;
};

} // namespace mgb::imperative::proxy_graph

+ 1
- 24
imperative/src/impl/proxy_graph_detail.cpp View File

@@ -75,30 +75,7 @@ apply_on_physical_tensor(const OpDef& def,
auto output_descs = infer_output_attrs(def, inputs);
SmallVector<TensorPtr> outputs(output_descs.size(), {});
for (size_t i = 0; i < outputs.size(); i++) {
auto& output = outputs[i];
auto& output_desc = output_descs[i];
if (def.same_type<Elemwise>()) {
for (size_t j = 0; j < inputs.size(); j++) {
// TODO: reindex inputs to support inplace exprs like 'y = x op x'.
auto& input = inputs[j];
// Because we pass inputs by value, if input and input->blob() are all unique,
// their ownerships are on the stack, thus we can reuse them safely.
// @see: interpreter::intl::ChannelImpl::process_one_task
if (input.unique() && input->blob().unique() && input->blob()->storage().unique() &&
input->layout().dtype == output_desc.layout.dtype &&
input->layout().eq_layout(output_desc.layout) &&
input->comp_node() == output_desc.comp_node) {
static std::atomic_llong inplace_count = 0;
mgb_log_debug("do inplace for elemwise, layout: %s, count: %lld",
output_desc.layout.to_string().c_str(), ++inplace_count);
output = Tensor::make(input->blob(), input->layout(), input->offset());
break;
}
}
}
if (!output) {
output = Tensor::make(output_desc.layout, output_desc.comp_node);
}
outputs[i] = Tensor::make(output_descs[i].layout, output_descs[i].comp_node);
}
exec(def, inputs, outputs);
auto async_error = ProxyGraph::get_async_error();


+ 4
- 4
imperative/src/include/megbrain/imperative/backward_graph_opt.h View File

@@ -14,10 +14,10 @@
namespace mgb::imperative {

struct OptimizedBackwardGraphResult {
std::shared_ptr<OpDef> precomp;
std::shared_ptr<OpDef> backward;
std::vector<bool> save_for_backward;
std::vector<bool> input_has_grad;
Subgraph precomp;
Subgraph backward;
SmallVector<bool> save_for_backward;
SmallVector<bool> input_has_grad;

OptimizedBackwardGraphResult(const BackwardGraphResult& bgraph);
};


+ 53
- 3
imperative/src/include/megbrain/imperative/op_def.h View File

@@ -26,10 +26,60 @@ enum DispatchMode {
KERNEL = 1
};

using SharedOp = std::shared_ptr<OpDef>;

template <typename T>
struct Expr {
std::shared_ptr<OpDef> op;
SmallVector<T> inputs;
SmallVector<T> outputs;
};

struct Subgraph {
SmallVector<size_t> inputs;
SmallVector<std::pair<size_t, TensorPtr>> constants;
SmallVector<size_t> outputs;
SmallVector<Expr<size_t>> exprs;

template <typename T, typename F, typename C>
SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const {
std::unordered_map<size_t, T> idx2var;
mgb_assert(inputs.size() == input_vars.size(), "input size mismatch");
for (size_t i = 0; i < inputs.size(); ++i) {
idx2var[inputs[i]] = input_vars[i];
}
for (auto&& [idx, val]: constants) {
idx2var[idx] = c(val);
}
for (auto& expr: exprs) {
SmallVector<T> expr_inputs;
for (auto idx: expr.inputs) {
expr_inputs.push_back(idx2var[idx]);
}
SmallVector<T> expr_outputs = f(expr.op, std::move(expr_inputs));
mgb_assert(expr_outputs.size() == expr.outputs.size(), "output size mismatch");
for (size_t i = 0; i < expr_outputs.size(); ++i) {
idx2var[expr.outputs[i]] = expr_outputs[i];
}
}
SmallVector<T> output_vars;
for (auto idx: outputs) {
output_vars.push_back(idx2var[idx]);
}
return output_vars;
}

bool empty() const {
return outputs.size() == 0;
}

std::string repr() const;
};

struct BackwardGraphResult {
std::shared_ptr<OpDef> backward;
std::vector<bool> save_for_backward;
std::vector<bool> input_has_grad;
Subgraph backward;
SmallVector<bool> save_for_backward;
SmallVector<bool> input_has_grad;
};

class OpDef : public Hashable,


+ 0
- 86
imperative/src/include/megbrain/imperative/ops/backward_graph.h View File

@@ -15,92 +15,6 @@

namespace mgb {
namespace imperative {

// a special OpDef used for taking gradient on physical tensor
struct BackwardGraph final : public OpDefImplBase<BackwardGraph> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
struct InternalGraph {
// op, inputs, outputs
using Expr = std::tuple<std::shared_ptr<OpDef>,
std::vector<size_t>, std::vector<size_t>>;
std::vector<Expr> exprs;

// index array of input nodes
std::vector<size_t> inputs;

// index array of output nodes
std::vector<size_t> outputs;

// pair of (node index, correspending constant)
std::vector<std::pair<size_t, TensorPtr>> constants;

SmallVector<TensorPtr>
apply(const SmallVector<TensorPtr>& inputs) const;

std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_attrs(
const SmallVector<LogicalTensorDesc>& inputs) const;

template <typename T, typename F, typename C>
SmallVector<T> interpret(F&& f, C&& c, const SmallVector<T>& inputs) const {
ThinHashMap<size_t, T> node2tensor;
auto&& input_nodes = this->inputs;
mgb_assert(inputs.size() == input_nodes.size());
for (size_t i = 0; i < inputs.size(); ++ i) {
node2tensor[input_nodes[i]] = inputs[i];
}
for (auto &&i : constants) {
node2tensor[i.first] = c(i.second);
}
for (size_t i = 0; i < exprs.size(); ++ i) {
auto&& expr = exprs[i];
SmallVector<T> inputs;
for (auto &&in : std::get<1>(expr)) {
inputs.push_back(node2tensor.at(in));
}
auto&& outputs = f(*std::get<0>(expr), std::move(inputs));
auto&& output_nodes = std::get<2>(expr);
mgb_assert(outputs.size() == output_nodes.size());
for (size_t i = 0; i < outputs.size(); ++ i) {
node2tensor[output_nodes[i]] = std::move(outputs[i]);
}
}
SmallVector<T> ret;
for (auto &&i : outputs) {
ret.push_back(node2tensor.at(i));
}
return ret;
}

std::string repr();
};

const InternalGraph& graph() const {
return m_graph;
}

InternalGraph& graph() {
return m_graph;
}

bool is_same_st(const Hashable& rhs) const override {
if (!rhs.same_type<BackwardGraph>()) {
return false;
}
auto& other = rhs.cast_final_safe<BackwardGraph>();
if (this == &other) {
return true;
}
// FIXME
return false;
}

std::string repr() {return m_graph.repr();}

private:
InternalGraph m_graph;
};

} // namespace imperative
} // namespace mgb



+ 1
- 1
imperative/src/include/megbrain/imperative/ops/utility.h View File

@@ -29,7 +29,7 @@ struct GenericPyOp final : OpDefImplBase<GenericPyOp> {
}

bool is_same_st(const Hashable& rhs) const override {
return obj.equal(static_cast<const GenericPyOp&>(rhs).obj);
return obj.equal(rhs.cast_final<GenericPyOp>().obj);
}

MGB_DYN_TYPE_OBJ_FINAL_DECL;


+ 29
- 5
imperative/src/test/backward_graph.cpp View File

@@ -75,6 +75,10 @@ T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, cons
return ret;
}

SmallVector<TensorPtr> apply_shared_on_physical_tensor(std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs) {
return OpDef::apply_on_physical_tensor(*def, inputs);
}

TEST(TestImperative, BackwardGraphBasic) {
HostTensorGenerator<> gen;
SmallVector<HostTensorND> hvs;
@@ -114,7 +118,11 @@ TEST(TestImperative, BackwardGraphBasic) {
}
}
inputs.clear();
auto input_grads = OpDef::apply_on_physical_tensor(*(result.backward), backward_graph_inputs);
auto input_grads = result.backward.apply(
backward_graph_inputs,
apply_shared_on_physical_tensor,
[&](auto&& x){ return x; }
);
mgb_assert(input_grads.size() == input_has_grad.size());
for (size_t i = 0; i < input_has_grad.size(); ++ i) {
mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i]));
@@ -164,7 +172,11 @@ TEST(TestImperative, BackwardGraphIdentity) {
}
}
inputs.clear();
auto input_grads = OpDef::apply_on_physical_tensor(*(result.backward), backward_graph_inputs);
auto input_grads = result.backward.apply(
backward_graph_inputs,
apply_shared_on_physical_tensor,
[&](auto&& x){ return x; }
);
mgb_assert(input_grads.size() == input_has_grad.size());
for (size_t i = 0; i < input_has_grad.size(); ++ i) {
mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i]));
@@ -224,9 +236,17 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0];

auto backward_graph_inputs = prepare_backward_graph_inputs<SmallVector<TensorPtr>>(bg, {a_tn, b_tn}, {c_tn}, {dc_tn});
auto grads = expand_grads(bg, OpDef::apply_on_physical_tensor(*bg.backward, backward_graph_inputs));
auto grads = expand_grads(bg, bg.backward.apply(
backward_graph_inputs,
apply_shared_on_physical_tensor,
[&](auto&& x){ return x; }
));

auto precomp = OpDef::apply_on_physical_tensor(*obg.precomp, {a_tn, b_tn, c_tn});
auto precomp = obg.precomp.apply(
SmallVector<TensorPtr>{a_tn, b_tn, c_tn},
apply_shared_on_physical_tensor,
[&](auto&& x){ return x; }
);
ASSERT_EQ(precomp.size(), 2);
ASSERT_EQ(precomp[0]->shape().ndim, 1);
ASSERT_LE(precomp[0]->shape()[0], 2);
@@ -234,7 +254,11 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) {
ASSERT_LE(precomp[1]->shape()[0], 2);

auto backward_inputs = prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn});
auto grads2 = expand_grads(obg, OpDef::apply_on_physical_tensor(*obg.backward, backward_inputs));
auto grads2 = expand_grads(obg, obg.backward.apply(
backward_inputs,
apply_shared_on_physical_tensor,
[&](auto&& x){ return x; }
));

ASSERT_EQ(grads2.size(), 2);
MGB_ASSERT_TENSOR_EQ(grads[0]->get_value(), grads2[0]->get_value());


Loading…
Cancel
Save