From 241b35a697fcfa9d90fe5e846632c7643ef92ba8 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 24 May 2021 17:09:19 +0800 Subject: [PATCH] refactor(ops): remove BackwardGraph op GitOrigin-RevId: eda20e57606daad69790f6abbc7cd7fba2ba934c --- .../python/megengine/core/tensor/megbrain_graph.py | 16 --- imperative/python/megengine/jit/tracing.py | 25 ++-- imperative/python/src/grad.cpp | 6 +- imperative/python/src/imperative_rt.cpp | 30 ++--- imperative/python/src/ops.cpp | 37 ------ imperative/python/src/tensor.cpp | 3 - imperative/python/src/tensor.h | 52 +++++--- imperative/python/src/trace.cpp | 2 +- imperative/src/impl/backward_graph_opt.cpp | 12 +- .../src/impl/interpreter/interpreter_impl.cpp | 3 +- imperative/src/impl/op_def.cpp | 64 ++++++++++ imperative/src/impl/ops/backward_graph.cpp | 141 --------------------- imperative/src/impl/proxy_graph.cpp | 37 +----- imperative/src/impl/proxy_graph.h | 2 +- imperative/src/impl/proxy_graph/common.h | 2 +- imperative/src/impl/proxy_graph_detail.cpp | 25 +--- .../megbrain/imperative/backward_graph_opt.h | 8 +- .../src/include/megbrain/imperative/op_def.h | 56 +++++++- .../megbrain/imperative/ops/backward_graph.h | 86 ------------- .../src/include/megbrain/imperative/ops/utility.h | 2 +- imperative/src/test/backward_graph.cpp | 34 ++++- 21 files changed, 221 insertions(+), 422 deletions(-) diff --git a/imperative/python/megengine/core/tensor/megbrain_graph.py b/imperative/python/megengine/core/tensor/megbrain_graph.py index e11e0f29..3887c167 100644 --- a/imperative/python/megengine/core/tensor/megbrain_graph.py +++ b/imperative/python/megengine/core/tensor/megbrain_graph.py @@ -18,7 +18,6 @@ import numpy as np from .. import _imperative_rt from .._imperative_rt import GraphOptimizeOptions from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode -from .._imperative_rt.ops import BackwardGraph from .._wrap import device as as_device from ..ops.builtin import OpDef from .core import TensorBase @@ -481,21 +480,6 @@ def apply_normal_varnode(op: OpDef, *args: VarNode): return _wrap(outputs) -def apply_backward_varnode(op: BackwardGraph, *args: VarNode): - assert args - graph = args[0].graph - outputs = op.interpret( - op, - lambda op, args: apply_normal_varnode(op, *args), - graph._make_const_for_backward, - args, - ) - return outputs - - -set_cpp_apply_backward_varnode(apply_backward_varnode) - - def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None): outputs = _imperative_rt.input_callback( callback, as_device(device).to_c(), dtype, shape, _unwrap(args), graph=graph diff --git a/imperative/python/megengine/jit/tracing.py b/imperative/python/megengine/jit/tracing.py index a8c222f8..d17d6e71 100644 --- a/imperative/python/megengine/jit/tracing.py +++ b/imperative/python/megengine/jit/tracing.py @@ -32,7 +32,7 @@ from ..core._imperative_rt.ops import ( ) from ..core._trace_option import set_symbolic_shape from ..core._wrap import device as as_device -from ..core.ops.builtin import BackwardGraph, BatchNorm, OpDef +from ..core.ops.builtin import BatchNorm, OpDef from ..core.ops.special import Const from ..core.tensor import megbrain_graph as G from ..core.tensor.utils import setscalar @@ -587,10 +587,7 @@ class trace: ivars.append(info.varnode) - if isinstance(op, BackwardGraph): - ovars = G.apply_backward_varnode(op, *ivars) - else: - ovars = G.apply_normal_varnode(op, *ivars) + ovars = G.apply_normal_varnode(op, *ivars) if require_links and len(ovars) > 0: io_links = (ovars[0],) @@ -805,14 +802,11 @@ class trace: name=info.name, ) ivars.append(h2v[h]) - if isinstance(op, BackwardGraph): - ovars = G.apply_backward_varnode(op, *ivars) - else: - if isinstance(op, BatchNorm): - assert ( - op.fwd_mode == BatchNorm.FwdMode.INFERENCE - ), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?" - ovars = G.apply_normal_varnode(op, *ivars) + if isinstance(op, BatchNorm): + assert ( + op.fwd_mode == BatchNorm.FwdMode.INFERENCE + ), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?" + ovars = G.apply_normal_varnode(op, *ivars) AutoNaming.record_opnode(ovars[0].op) @@ -1088,10 +1082,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): ivars[0] = opnode.outputs[0] active_trace._lazy_eval_links = (ivars[0],) - if isinstance(op, BackwardGraph): - ovars = G.apply_backward_varnode(op, *ivars) - else: - ovars = G.apply_normal_varnode(op, *ivars) + ovars = G.apply_normal_varnode(op, *ivars) outputs = [RawTensor(o) for o in ovars] if require_links: diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp index cffdbc57..97e0fcc4 100644 --- a/imperative/python/src/grad.cpp +++ b/imperative/python/src/grad.cpp @@ -75,9 +75,9 @@ std::shared_ptr make_backward_graph( input_requires_grad[i] = python::input_requires_grad(ctx, i); } std::shared_ptr ret; - auto bg = proxy_graph_detail::make_backward_graph( + auto bg = OpDef::make_backward_graph( *ctx.op, inputs, input_requires_grad, output_has_grad); - if (bg.backward) { + if (!bg.backward.empty()) { ret = std::make_shared(bg); } backward_graph_cache.emplace(key, ret); @@ -112,7 +112,7 @@ struct BackwardGraphWithClosure { size_t count = std::count_if(save_for_backward.begin(), save_for_backward.end(), ranges::identity{}); - if (backward_graph->precomp) { + if (!backward_graph->precomp.empty()) { auto&& irng = ranges::span(ctx.args, ctx.nargs); auto&& orng = views::transform(outputs, [](auto&& i){return i.get();}); auto precomp = apply(backward_graph->precomp, views::concat(irng, orng)); diff --git a/imperative/python/src/imperative_rt.cpp b/imperative/python/src/imperative_rt.cpp index ae679ecb..2fed9188 100644 --- a/imperative/python/src/imperative_rt.cpp +++ b/imperative/python/src/imperative_rt.cpp @@ -30,26 +30,14 @@ using namespace imperative; using namespace interpreter; -namespace { - -std::optional, std::vector, std::vector>> -make_backward_graph( - const OpDef& opdef, std::vector inputs, - std::vector input_requires_grad, - std::vector output_has_grad) { - auto res = OpDef::make_backward_graph(opdef, - SmallVector(inputs.begin(), inputs.end()), - SmallVector(input_requires_grad.begin(), input_requires_grad.end()), - SmallVector(output_has_grad.begin(), output_has_grad.end())); - if (res.backward) { - return std::optional, std::vector, std::vector>>{ - std::in_place, res.backward, res.save_for_backward, res.input_has_grad}; - } else { - return {}; - } -} -} // namespace - void init_imperative_rt(py::module m) { - m.def("make_backward_graph", &make_backward_graph); + auto make_backward_graph = []( + const OpDef& def, + const SmallVector& inputs, + const SmallVector& input_requires_grad, + const SmallVector& output_has_grad){ + auto result = OpDef::make_backward_graph(def, inputs, input_requires_grad, output_has_grad); + return std::make_tuple("backward_graph", result.save_for_backward, result.input_has_grad); + }; + m.def("make_backward_graph", make_backward_graph); } diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp index c9d50775..e151fadb 100644 --- a/imperative/python/src/ops.cpp +++ b/imperative/python/src/ops.cpp @@ -367,42 +367,6 @@ void _init_py_op_def(py::module m) { } /*********** begin of hand-write opdefs **************/ - -PyOpDefBegin(BackwardGraph) // {{ -// }; -PyOpDefEnd(BackwardGraph) - -void _init_py_backward_graph(py::module m) { - using py_op = PyOp(BackwardGraph); - auto& py_type = PyOpType(BackwardGraph); - py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; - py_type.tp_name = "megengine.core._imperative_rt.ops.BackwardGraph"; - py_type.tp_basicsize = sizeof(PyOp(BackwardGraph)); - py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; - py_type.tp_doc = "BackwardGraph"; - py_type.tp_base = &PyOpType(OpDef); - py_type.tp_dealloc = py_dealloc_generic; - py_type.tp_new = py_new_generic; - mgb_assert(PyType_Ready(&py_type) >= 0); - // FIXME: rewrite interpret function in cpython instead wrap directly by pybind11::cppfunction - auto interpret = py::cpp_function( - [](OpDef& self, py::object pyf, py::object pyc, - const mgb::SmallVector& inputs) { - auto f = [pyf](OpDef& op, const mgb::SmallVector& inputs) { - return py::cast>(pyf(op.shared_from_this(), inputs)); - }; - auto c = [pyc](const TensorPtr& tensor) { - return pyc(tensor->dev_tensor()); - }; - return self.cast_final_safe().graph().interpret(f, c, inputs); - }); - mgb_assert(PyDict_SetItemString( - py_type.tp_dict, "interpret", interpret.release().ptr()) >= 0); - PyType_Modified(&py_type); - m.add_object("BackwardGraph", reinterpret_cast(&py_type)); - mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(BackwardGraph::typeinfo(), &py_type).second); -} - struct PyOpBase : PyOpDef { static PyTypeObject py_type; @@ -496,7 +460,6 @@ FOR_EACH_BIT_COMBINED_ENUM_PARAM(BIT_COMBINED_ENUM_CASTER_IMPL) void init_ops(py::module m) { _init_py_op_def(m); - _init_py_backward_graph(m); _init_py_op_base(m); INIT_ALL_OP(m) diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp index 67be8ee1..7c6c9ddd 100644 --- a/imperative/python/src/tensor.cpp +++ b/imperative/python/src/tensor.cpp @@ -156,9 +156,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje ctx.args = &tensors[0]; ctx.nargs = nargs; ctx.pytype = pytype; - if (ctx.op->same_type()) { - ctx.backward = true; - } if (py::isinstance(py::handle(args[0]))){ SmallVector vinputs(nargs); diff --git a/imperative/python/src/tensor.h b/imperative/python/src/tensor.h index 89b50575..be6fd183 100644 --- a/imperative/python/src/tensor.h +++ b/imperative/python/src/tensor.h @@ -248,31 +248,53 @@ apply_result_t apply(std::shared_ptr op, Args&&... args) { return apply(ctx); } -template -auto apply(std::shared_ptr op, T&& tensors) - -> std::enable_if_t, - apply_result_t> { +inline auto apply(std::shared_ptr op, Tensor*const* args, size_t nargs) { ApplyContext ctx; ctx.op = std::move(op); - ctx.nargs = tensors.size(); - Tensor* args[ctx.nargs]; + ctx.nargs = nargs; ctx.args = args; - for (size_t i = 0; i < ctx.nargs; ++i) { - args[i] = resolve_arrow(tensors[i]); + for (size_t i = 0; i < nargs; ++i) { ctx.flags |= args[i]->m_flags; } return apply(ctx); } -inline auto apply(std::shared_ptr op, Tensor*const* args, size_t nargs) { - ApplyContext ctx; - ctx.op = std::move(op); - ctx.nargs = nargs; - ctx.args = args; +template +auto apply(std::shared_ptr op, T&& tensors) + -> std::enable_if_t, + apply_result_t> { + size_t nargs = tensors.size(); + Tensor* args[nargs]; for (size_t i = 0; i < nargs; ++i) { - ctx.flags |= args[i]->m_flags; + args[i] = resolve_arrow(tensors[i]); } - return apply(ctx); + return apply(op, args, nargs); +} + +inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) { + SmallVector> inputs; + for (size_t i = 0; i < nargs; ++i) { + inputs.push_back(args[i]->shared_from_this()); + } + auto apply_functor = [](std::shared_ptr op, SmallVector> inputs) { + return apply(op, inputs); + }; + auto const_functor = [](imperative::TensorPtr value) { + return std::make_shared(interpreter_for_py->put(value->dev_tensor())); + }; + return graph.apply(inputs, apply_functor, const_functor); +} + +template +auto apply(Subgraph graph, T&& tensors) + -> std::enable_if_t, + apply_result_t> { + size_t nargs = tensors.size(); + Tensor* args[nargs]; + for (size_t i = 0; i < nargs; ++i) { + args[i] = resolve_arrow(tensors[i]); + } + return apply(graph, args, nargs); } void init_tensor(pybind11::module); diff --git a/imperative/python/src/trace.cpp b/imperative/python/src/trace.cpp index 30ddb78b..16fbbf5a 100644 --- a/imperative/python/src/trace.cpp +++ b/imperative/python/src/trace.cpp @@ -22,7 +22,7 @@ apply_result_t apply_trace(ApplyContext& ctx) { apply_result_t outputs; if (ctx.backward) { - // call megbrain_graph.py apply(BackwardGraph, *args) + // reach here when compiled=True auto args = py::tuple(ctx.nargs + 1); args[0] = py::cast(ctx.op); for (size_t i = 0; i < ctx.nargs; i++) { diff --git a/imperative/src/impl/backward_graph_opt.cpp b/imperative/src/impl/backward_graph_opt.cpp index 9e59a8aa..49dd7673 100644 --- a/imperative/src/impl/backward_graph_opt.cpp +++ b/imperative/src/impl/backward_graph_opt.cpp @@ -18,24 +18,22 @@ using namespace imperative; OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src) : input_has_grad(src.input_has_grad) { - if (!src.backward->same_type()) { + if (src.backward.exprs.size() <= 1) { // backward graph only contains a single op backward = src.backward; save_for_backward = src.save_for_backward; return; } save_for_backward.resize(src.save_for_backward.size(), false); - precomp.reset(new BackwardGraph); - backward.reset(new BackwardGraph); - auto&& graph = src.backward->cast_final_safe().graph(); + auto&& graph = src.backward; auto&& mask = src.save_for_backward; size_t input_size = src.input_has_grad.size(); size_t output_size = (mask.size() - input_size) / 2; mgb_assert(input_size + output_size * 2 == mask.size()); - auto& fgraph = precomp->cast_final().graph(); - auto& bgraph = backward->cast_final().graph(); + auto& fgraph = precomp; + auto& bgraph = backward; // optimization: move ops (e.g. GetVarShape) to forward to // reduce memory footprint @@ -113,6 +111,6 @@ OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphRe } if (!fgraph.outputs.size()) { - precomp.reset(); + precomp = {}; } } diff --git a/imperative/src/impl/interpreter/interpreter_impl.cpp b/imperative/src/impl/interpreter/interpreter_impl.cpp index 1d8c98fb..081d0691 100644 --- a/imperative/src/impl/interpreter/interpreter_impl.cpp +++ b/imperative/src/impl/interpreter/interpreter_impl.cpp @@ -911,8 +911,7 @@ auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle { op_type == RemoteSend::typeinfo() || op_type == CollectiveComm::typeinfo() || op_type == opr::InputCallback::typeinfo() || - op_type == opr::OutputCallback::typeinfo() || - op_type == BackwardGraph::typeinfo()) { + op_type == opr::OutputCallback::typeinfo()) { return m_commands.end(); } } else if constexpr (std::is_same_v) { diff --git a/imperative/src/impl/op_def.cpp b/imperative/src/impl/op_def.cpp index 96d2812d..7603299a 100644 --- a/imperative/src/impl/op_def.cpp +++ b/imperative/src/impl/op_def.cpp @@ -10,6 +10,9 @@ */ #include "megbrain/imperative/op_def.h" + +#include + #include "megbrain/imperative/ops/opr_attr.h" #include "./op_trait.h" @@ -117,6 +120,67 @@ const std::string OpDef::make_name() const { return m_scope + "." + trait()->make_name(*this); } +std::string Subgraph::repr() const { + std::ostringstream buf; + buf << "("; + for (size_t i = 0; i < inputs.size(); ++i) { + if (i > 0) buf << ", "; + buf << "%" << inputs[i]; + } + buf << ") => {\n"; + auto fmt_const = [](size_t i, const TensorPtr& t) { + if (t->shape().ndim == 1 && t->shape()[0] == 1) { + auto&& v = t->get_value(); + if (v.dtype() == dtype::Float32{}) { + return std::to_string(*v.ptr()); + } else if (v.dtype() == dtype::Int32{}) { + return std::to_string(*v.ptr()); + } + } + return std::string("%c") + std::to_string(i); + }; + std::unordered_map const_reps; + for (auto&& [i, t] : constants) { + const_reps.emplace(i, fmt_const(i, t)); + } + for (auto& [op, ins, outs] : exprs) { + buf << " "; + if (outs.size()) { + for (size_t i = 0; i < outs.size(); ++i) { + if (i > 0) buf << ", "; + buf << "%" << outs[i]; + } + buf << " = "; + } + if (auto* p = op->try_cast_final()) { + buf << p->type; + } else { + buf << op->dyn_typeinfo()->name; + } + for (size_t i : ins) { + buf << " "; + auto&& it = const_reps.find(i); + if (it != const_reps.end()) { + buf << it->second; + } else { + buf << "%" << i; + } + } + buf << "\n"; + } + buf << " "; + if (outputs.size()) { + for (size_t i = 0; i < outputs.size(); ++i) { + if (i > 0) buf << ", "; + buf << "%" << outputs[i]; + } + } else { + buf << "()"; + } + buf << "\n}\n"; + return buf.str(); +} + } // namespace imperative } // namespace mgb diff --git a/imperative/src/impl/ops/backward_graph.cpp b/imperative/src/impl/ops/backward_graph.cpp index a9e73ec7..c43b1ed4 100644 --- a/imperative/src/impl/ops/backward_graph.cpp +++ b/imperative/src/impl/ops/backward_graph.cpp @@ -19,147 +19,6 @@ namespace mgb { namespace imperative { -SmallVector -BackwardGraph::InternalGraph::apply( - const SmallVector& inputs) const { - return interpret( - &OpDef::apply_on_physical_tensor, - [](const TensorPtr& x) {return x;}, - inputs); -} - -std::tuple, bool> BackwardGraph::InternalGraph::infer_attrs( - const SmallVector& inputs) const { - using TensorAttr = LogicalTensorDesc; - ThinHashMap node2attr; - auto&& input_nodes = this->inputs; - auto&& output_nodes = this->outputs; - mgb_assert(inputs.size() == input_nodes.size()); - for (size_t i = 0; i < inputs.size(); ++ i) { - node2attr[input_nodes[i]] = inputs[i]; - } - for (auto &&i : constants) { - auto* value = i.second->try_get_value(); - mgb_assert(value); - node2attr[i.first] = TensorAttr{ - i.second->layout(), i.second->comp_node(), - value->proxy_to_default_cpu()}; - } - bool validated = true; - for (size_t i = 0; i < exprs.size(); ++ i) { - auto&& [expr_op, expr_inps, expr_oups] = exprs[i]; - SmallVector expr_input_descs; - for (auto &&inp : expr_inps) { - expr_input_descs.push_back(node2attr.at(inp)); - } - - auto [expr_output_descs, expr_validated] = OpDef::infer_output_attrs_fallible( - *expr_op, expr_input_descs); - validated = validated && expr_validated; - - mgb_assert(expr_output_descs.size() == expr_oups.size()); - for (size_t i = 0; i < expr_output_descs.size(); ++ i) { - node2attr[expr_oups[i]] = expr_output_descs[i]; - } - } - - SmallVector ret; - for (auto &&i : output_nodes) { - ret.push_back(node2attr.at(i)); - } - return {ret, validated}; -} - -std::string BackwardGraph::InternalGraph::repr() { - std::ostringstream buf; - buf << "("; - for (size_t i = 0; i < inputs.size(); ++i) { - if (i > 0) buf << ", "; - buf << "%" << inputs[i]; - } - buf << ") => {\n"; - auto fmt_const = [](size_t i, TensorPtr& t) { - if (t->shape().ndim == 1 && t->shape()[0] == 1) { - auto&& v = t->get_value(); - if (v.dtype() == dtype::Float32{}) { - return std::to_string(*v.ptr()); - } else if (v.dtype() == dtype::Int32{}) { - return std::to_string(*v.ptr()); - } - } - return std::string("%c") + std::to_string(i); - }; - std::unordered_map const_reps; - for (auto&& [i, t] : constants) { - const_reps.emplace(i, fmt_const(i, t)); - } - for (auto& [op, ins, outs] : exprs) { - buf << " "; - if (outs.size()) { - for (size_t i = 0; i < outs.size(); ++i) { - if (i > 0) buf << ", "; - buf << "%" << outs[i]; - } - buf << " = "; - } - if (auto* p = op->try_cast_final()) { - buf << p->type; - } else { - buf << op->dyn_typeinfo()->name; - } - for (size_t i : ins) { - buf << " "; - auto&& it = const_reps.find(i); - if (it != const_reps.end()) { - buf << it->second; - } else { - buf << "%" << i; - } - } - buf << "\n"; - } - buf << " "; - if (outputs.size()) { - for (size_t i = 0; i < outputs.size(); ++i) { - if (i > 0) buf << ", "; - buf << "%" << outputs[i]; - } - } else { - buf << "()"; - } - buf << "\n}\n"; - return buf.str(); -} - -MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardGraph); - -namespace { -SmallVector backward_impl( - const OpDef& backward_graph, - const SmallVector& tensors) { - return backward_graph.cast_final_safe() - .graph().apply(tensors); -} - -std::tuple, bool> infer_tensor_attrs( - const OpDef& backward_graph, - const SmallVector inputs) { - return backward_graph.cast_final_safe() - .graph().infer_attrs(inputs); -} - -std::vector> props( - const OpDef& backward_graph) { - return {}; -} - -OP_TRAIT_REG(BackwardGraph, BackwardGraph) - .apply_on_physical_tensor(backward_impl) - .infer_output_attrs_fallible(infer_tensor_attrs) - .props(props) - .fallback(); -} // anonymous namespace - } // namespace imperative } // namespace mgb diff --git a/imperative/src/impl/proxy_graph.cpp b/imperative/src/impl/proxy_graph.cpp index 0b5c96fb..1bcd6591 100644 --- a/imperative/src/impl/proxy_graph.cpp +++ b/imperative/src/impl/proxy_graph.cpp @@ -669,8 +669,7 @@ ProxyGraph::make_backward_graph( auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo()); BackwardGraphResult result; - auto&& backward = BackwardGraph::make(); - auto&& igraph = backward->cast_final_safe().graph(); + auto&& igraph = result.backward; size_t nr_backward_graph_inputs = 0; auto gen_expr = [this, &var2idx, &igraph, &push, &fwd, @@ -682,7 +681,7 @@ ProxyGraph::make_backward_graph( ++ nr_backward_graph_inputs; push(op->output(0)); } else { - std::vector inputs, outputs; + SmallVector inputs, outputs; for (auto &&i : op->input()) { if (i->owner_opr() == fwd) { if (var2idx.find(i) == var2idx.end()) { @@ -695,7 +694,7 @@ ProxyGraph::make_backward_graph( for (auto &&i : op->usable_output()) { outputs.push_back(push(i)); } - igraph.exprs.emplace_back(OpDef::make_from_op_node(op), inputs, outputs); + igraph.exprs.push_back({OpDef::make_from_op_node(op), inputs, outputs}); } }; @@ -770,36 +769,6 @@ ProxyGraph::make_backward_graph( write_inputs(outputs); write_inputs(output_grads); mgb_assert(igraph.inputs.size() == nr_backward_graph_inputs); - - auto treat_as_single = [](auto&& igraph) { - if (igraph.exprs.size() != 1) - return false; - auto&& expr = igraph.exprs[0]; - auto&& expr_inputs = std::get<1>(expr); - if (expr_inputs.size() != igraph.inputs.size()) { - return false; - } - for (size_t i = 0; i < expr_inputs.size(); ++ i) { - if (igraph.inputs[i] != expr_inputs[i]) { - return false; - } - } - auto&& expr_outputs = std::get<2>(expr); - if (expr_outputs.size() != igraph.outputs.size()) { - return false; - } - for (size_t i = 0; i < expr_outputs.size(); ++ i) { - if (igraph.outputs[i] != expr_outputs[i]) { - return false; - } - } - return true; - }; - if (treat_as_single(igraph)) { - result.backward = std::get<0>(igraph.exprs[0]); - } else { - result.backward = backward; - } return result; } diff --git a/imperative/src/impl/proxy_graph.h b/imperative/src/impl/proxy_graph.h index e0927cbc..779949d6 100644 --- a/imperative/src/impl/proxy_graph.h +++ b/imperative/src/impl/proxy_graph.h @@ -65,7 +65,7 @@ private: class InputPlaceholder; struct ProxyGraphInst; struct GradGraph; - struct CurOprGuard; + class CurOprGuard; void reset(); diff --git a/imperative/src/impl/proxy_graph/common.h b/imperative/src/impl/proxy_graph/common.h index 39f4cb88..ebfad8a1 100644 --- a/imperative/src/impl/proxy_graph/common.h +++ b/imperative/src/impl/proxy_graph/common.h @@ -15,7 +15,7 @@ namespace mgb::imperative::proxy_graph { // e.g. friend class mgb::imperative::proxy_graph::ProxyGraph struct ProxyGraph { struct InputPlaceholder; - struct MiniGraph; + class MiniGraph; }; } // namespace mgb::imperative::proxy_graph diff --git a/imperative/src/impl/proxy_graph_detail.cpp b/imperative/src/impl/proxy_graph_detail.cpp index d68ce881..d5d0365c 100644 --- a/imperative/src/impl/proxy_graph_detail.cpp +++ b/imperative/src/impl/proxy_graph_detail.cpp @@ -75,30 +75,7 @@ apply_on_physical_tensor(const OpDef& def, auto output_descs = infer_output_attrs(def, inputs); SmallVector outputs(output_descs.size(), {}); for (size_t i = 0; i < outputs.size(); i++) { - auto& output = outputs[i]; - auto& output_desc = output_descs[i]; - if (def.same_type()) { - for (size_t j = 0; j < inputs.size(); j++) { - // TODO: reindex inputs to support inplace exprs like 'y = x op x'. - auto& input = inputs[j]; - // Because we pass inputs by value, if input and input->blob() are all unique, - // their ownerships are on the stack, thus we can reuse them safely. - // @see: interpreter::intl::ChannelImpl::process_one_task - if (input.unique() && input->blob().unique() && input->blob()->storage().unique() && - input->layout().dtype == output_desc.layout.dtype && - input->layout().eq_layout(output_desc.layout) && - input->comp_node() == output_desc.comp_node) { - static std::atomic_llong inplace_count = 0; - mgb_log_debug("do inplace for elemwise, layout: %s, count: %lld", - output_desc.layout.to_string().c_str(), ++inplace_count); - output = Tensor::make(input->blob(), input->layout(), input->offset()); - break; - } - } - } - if (!output) { - output = Tensor::make(output_desc.layout, output_desc.comp_node); - } + outputs[i] = Tensor::make(output_descs[i].layout, output_descs[i].comp_node); } exec(def, inputs, outputs); auto async_error = ProxyGraph::get_async_error(); diff --git a/imperative/src/include/megbrain/imperative/backward_graph_opt.h b/imperative/src/include/megbrain/imperative/backward_graph_opt.h index 7b249441..ef6e6461 100644 --- a/imperative/src/include/megbrain/imperative/backward_graph_opt.h +++ b/imperative/src/include/megbrain/imperative/backward_graph_opt.h @@ -14,10 +14,10 @@ namespace mgb::imperative { struct OptimizedBackwardGraphResult { - std::shared_ptr precomp; - std::shared_ptr backward; - std::vector save_for_backward; - std::vector input_has_grad; + Subgraph precomp; + Subgraph backward; + SmallVector save_for_backward; + SmallVector input_has_grad; OptimizedBackwardGraphResult(const BackwardGraphResult& bgraph); }; diff --git a/imperative/src/include/megbrain/imperative/op_def.h b/imperative/src/include/megbrain/imperative/op_def.h index 61e28634..59bb96df 100644 --- a/imperative/src/include/megbrain/imperative/op_def.h +++ b/imperative/src/include/megbrain/imperative/op_def.h @@ -26,10 +26,60 @@ enum DispatchMode { KERNEL = 1 }; +using SharedOp = std::shared_ptr; + +template +struct Expr { + std::shared_ptr op; + SmallVector inputs; + SmallVector outputs; +}; + +struct Subgraph { + SmallVector inputs; + SmallVector> constants; + SmallVector outputs; + SmallVector> exprs; + + template + SmallVector apply(SmallVector input_vars, F&& f, C&& c) const { + std::unordered_map idx2var; + mgb_assert(inputs.size() == input_vars.size(), "input size mismatch"); + for (size_t i = 0; i < inputs.size(); ++i) { + idx2var[inputs[i]] = input_vars[i]; + } + for (auto&& [idx, val]: constants) { + idx2var[idx] = c(val); + } + for (auto& expr: exprs) { + SmallVector expr_inputs; + for (auto idx: expr.inputs) { + expr_inputs.push_back(idx2var[idx]); + } + SmallVector expr_outputs = f(expr.op, std::move(expr_inputs)); + mgb_assert(expr_outputs.size() == expr.outputs.size(), "output size mismatch"); + for (size_t i = 0; i < expr_outputs.size(); ++i) { + idx2var[expr.outputs[i]] = expr_outputs[i]; + } + } + SmallVector output_vars; + for (auto idx: outputs) { + output_vars.push_back(idx2var[idx]); + } + return output_vars; + } + + bool empty() const { + return outputs.size() == 0; + } + + std::string repr() const; +}; + struct BackwardGraphResult { - std::shared_ptr backward; - std::vector save_for_backward; - std::vector input_has_grad; + Subgraph backward; + SmallVector save_for_backward; + SmallVector input_has_grad; }; class OpDef : public Hashable, diff --git a/imperative/src/include/megbrain/imperative/ops/backward_graph.h b/imperative/src/include/megbrain/imperative/ops/backward_graph.h index c831e7ed..3738808a 100644 --- a/imperative/src/include/megbrain/imperative/ops/backward_graph.h +++ b/imperative/src/include/megbrain/imperative/ops/backward_graph.h @@ -15,92 +15,6 @@ namespace mgb { namespace imperative { - -// a special OpDef used for taking gradient on physical tensor -struct BackwardGraph final : public OpDefImplBase { - MGB_DYN_TYPE_OBJ_FINAL_DECL; -public: - struct InternalGraph { - // op, inputs, outputs - using Expr = std::tuple, - std::vector, std::vector>; - std::vector exprs; - - // index array of input nodes - std::vector inputs; - - // index array of output nodes - std::vector outputs; - - // pair of (node index, correspending constant) - std::vector> constants; - - SmallVector - apply(const SmallVector& inputs) const; - - std::tuple, bool> infer_attrs( - const SmallVector& inputs) const; - - template - SmallVector interpret(F&& f, C&& c, const SmallVector& inputs) const { - ThinHashMap node2tensor; - auto&& input_nodes = this->inputs; - mgb_assert(inputs.size() == input_nodes.size()); - for (size_t i = 0; i < inputs.size(); ++ i) { - node2tensor[input_nodes[i]] = inputs[i]; - } - for (auto &&i : constants) { - node2tensor[i.first] = c(i.second); - } - for (size_t i = 0; i < exprs.size(); ++ i) { - auto&& expr = exprs[i]; - SmallVector inputs; - for (auto &&in : std::get<1>(expr)) { - inputs.push_back(node2tensor.at(in)); - } - auto&& outputs = f(*std::get<0>(expr), std::move(inputs)); - auto&& output_nodes = std::get<2>(expr); - mgb_assert(outputs.size() == output_nodes.size()); - for (size_t i = 0; i < outputs.size(); ++ i) { - node2tensor[output_nodes[i]] = std::move(outputs[i]); - } - } - SmallVector ret; - for (auto &&i : outputs) { - ret.push_back(node2tensor.at(i)); - } - return ret; - } - - std::string repr(); - }; - - const InternalGraph& graph() const { - return m_graph; - } - - InternalGraph& graph() { - return m_graph; - } - - bool is_same_st(const Hashable& rhs) const override { - if (!rhs.same_type()) { - return false; - } - auto& other = rhs.cast_final_safe(); - if (this == &other) { - return true; - } - // FIXME - return false; - } - - std::string repr() {return m_graph.repr();} - -private: - InternalGraph m_graph; -}; - } // namespace imperative } // namespace mgb diff --git a/imperative/src/include/megbrain/imperative/ops/utility.h b/imperative/src/include/megbrain/imperative/ops/utility.h index ba85f366..ab58283c 100644 --- a/imperative/src/include/megbrain/imperative/ops/utility.h +++ b/imperative/src/include/megbrain/imperative/ops/utility.h @@ -29,7 +29,7 @@ struct GenericPyOp final : OpDefImplBase { } bool is_same_st(const Hashable& rhs) const override { - return obj.equal(static_cast(rhs).obj); + return obj.equal(rhs.cast_final().obj); } MGB_DYN_TYPE_OBJ_FINAL_DECL; diff --git a/imperative/src/test/backward_graph.cpp b/imperative/src/test/backward_graph.cpp index 9e83ed85..e4ab27d0 100644 --- a/imperative/src/test/backward_graph.cpp +++ b/imperative/src/test/backward_graph.cpp @@ -75,6 +75,10 @@ T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, cons return ret; } +SmallVector apply_shared_on_physical_tensor(std::shared_ptr def, SmallVector inputs) { + return OpDef::apply_on_physical_tensor(*def, inputs); +} + TEST(TestImperative, BackwardGraphBasic) { HostTensorGenerator<> gen; SmallVector hvs; @@ -114,7 +118,11 @@ TEST(TestImperative, BackwardGraphBasic) { } } inputs.clear(); - auto input_grads = OpDef::apply_on_physical_tensor(*(result.backward), backward_graph_inputs); + auto input_grads = result.backward.apply( + backward_graph_inputs, + apply_shared_on_physical_tensor, + [&](auto&& x){ return x; } + ); mgb_assert(input_grads.size() == input_has_grad.size()); for (size_t i = 0; i < input_has_grad.size(); ++ i) { mgb_assert(input_has_grad[i] == static_cast(input_grads[i])); @@ -164,7 +172,11 @@ TEST(TestImperative, BackwardGraphIdentity) { } } inputs.clear(); - auto input_grads = OpDef::apply_on_physical_tensor(*(result.backward), backward_graph_inputs); + auto input_grads = result.backward.apply( + backward_graph_inputs, + apply_shared_on_physical_tensor, + [&](auto&& x){ return x; } + ); mgb_assert(input_grads.size() == input_has_grad.size()); for (size_t i = 0; i < input_has_grad.size(); ++ i) { mgb_assert(input_has_grad[i] == static_cast(input_grads[i])); @@ -224,9 +236,17 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0]; auto backward_graph_inputs = prepare_backward_graph_inputs>(bg, {a_tn, b_tn}, {c_tn}, {dc_tn}); - auto grads = expand_grads(bg, OpDef::apply_on_physical_tensor(*bg.backward, backward_graph_inputs)); + auto grads = expand_grads(bg, bg.backward.apply( + backward_graph_inputs, + apply_shared_on_physical_tensor, + [&](auto&& x){ return x; } + )); - auto precomp = OpDef::apply_on_physical_tensor(*obg.precomp, {a_tn, b_tn, c_tn}); + auto precomp = obg.precomp.apply( + SmallVector{a_tn, b_tn, c_tn}, + apply_shared_on_physical_tensor, + [&](auto&& x){ return x; } + ); ASSERT_EQ(precomp.size(), 2); ASSERT_EQ(precomp[0]->shape().ndim, 1); ASSERT_LE(precomp[0]->shape()[0], 2); @@ -234,7 +254,11 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { ASSERT_LE(precomp[1]->shape()[0], 2); auto backward_inputs = prepare_optimized_backward_inputs>(obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn}); - auto grads2 = expand_grads(obg, OpDef::apply_on_physical_tensor(*obg.backward, backward_inputs)); + auto grads2 = expand_grads(obg, obg.backward.apply( + backward_inputs, + apply_shared_on_physical_tensor, + [&](auto&& x){ return x; } + )); ASSERT_EQ(grads2.size(), 2); MGB_ASSERT_TENSOR_EQ(grads[0]->get_value(), grads2[0]->get_value());