@@ -18,7 +18,6 @@ import numpy as np | |||||
from .. import _imperative_rt | from .. import _imperative_rt | ||||
from .._imperative_rt import GraphOptimizeOptions | from .._imperative_rt import GraphOptimizeOptions | ||||
from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode | from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode | ||||
from .._imperative_rt.ops import BackwardGraph | |||||
from .._wrap import device as as_device | from .._wrap import device as as_device | ||||
from ..ops.builtin import OpDef | from ..ops.builtin import OpDef | ||||
from .core import TensorBase | from .core import TensorBase | ||||
@@ -481,21 +480,6 @@ def apply_normal_varnode(op: OpDef, *args: VarNode): | |||||
return _wrap(outputs) | return _wrap(outputs) | ||||
def apply_backward_varnode(op: BackwardGraph, *args: VarNode): | |||||
assert args | |||||
graph = args[0].graph | |||||
outputs = op.interpret( | |||||
op, | |||||
lambda op, args: apply_normal_varnode(op, *args), | |||||
graph._make_const_for_backward, | |||||
args, | |||||
) | |||||
return outputs | |||||
set_cpp_apply_backward_varnode(apply_backward_varnode) | |||||
def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None): | def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None): | ||||
outputs = _imperative_rt.input_callback( | outputs = _imperative_rt.input_callback( | ||||
callback, as_device(device).to_c(), dtype, shape, _unwrap(args), graph=graph | callback, as_device(device).to_c(), dtype, shape, _unwrap(args), graph=graph | ||||
@@ -32,7 +32,7 @@ from ..core._imperative_rt.ops import ( | |||||
) | ) | ||||
from ..core._trace_option import set_symbolic_shape | from ..core._trace_option import set_symbolic_shape | ||||
from ..core._wrap import device as as_device | from ..core._wrap import device as as_device | ||||
from ..core.ops.builtin import BackwardGraph, BatchNorm, OpDef | |||||
from ..core.ops.builtin import BatchNorm, OpDef | |||||
from ..core.ops.special import Const | from ..core.ops.special import Const | ||||
from ..core.tensor import megbrain_graph as G | from ..core.tensor import megbrain_graph as G | ||||
from ..core.tensor.utils import setscalar | from ..core.tensor.utils import setscalar | ||||
@@ -587,10 +587,7 @@ class trace: | |||||
ivars.append(info.varnode) | ivars.append(info.varnode) | ||||
if isinstance(op, BackwardGraph): | |||||
ovars = G.apply_backward_varnode(op, *ivars) | |||||
else: | |||||
ovars = G.apply_normal_varnode(op, *ivars) | |||||
ovars = G.apply_normal_varnode(op, *ivars) | |||||
if require_links and len(ovars) > 0: | if require_links and len(ovars) > 0: | ||||
io_links = (ovars[0],) | io_links = (ovars[0],) | ||||
@@ -805,14 +802,11 @@ class trace: | |||||
name=info.name, | name=info.name, | ||||
) | ) | ||||
ivars.append(h2v[h]) | ivars.append(h2v[h]) | ||||
if isinstance(op, BackwardGraph): | |||||
ovars = G.apply_backward_varnode(op, *ivars) | |||||
else: | |||||
if isinstance(op, BatchNorm): | |||||
assert ( | |||||
op.fwd_mode == BatchNorm.FwdMode.INFERENCE | |||||
), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?" | |||||
ovars = G.apply_normal_varnode(op, *ivars) | |||||
if isinstance(op, BatchNorm): | |||||
assert ( | |||||
op.fwd_mode == BatchNorm.FwdMode.INFERENCE | |||||
), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?" | |||||
ovars = G.apply_normal_varnode(op, *ivars) | |||||
AutoNaming.record_opnode(ovars[0].op) | AutoNaming.record_opnode(ovars[0].op) | ||||
@@ -1088,10 +1082,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): | |||||
ivars[0] = opnode.outputs[0] | ivars[0] = opnode.outputs[0] | ||||
active_trace._lazy_eval_links = (ivars[0],) | active_trace._lazy_eval_links = (ivars[0],) | ||||
if isinstance(op, BackwardGraph): | |||||
ovars = G.apply_backward_varnode(op, *ivars) | |||||
else: | |||||
ovars = G.apply_normal_varnode(op, *ivars) | |||||
ovars = G.apply_normal_varnode(op, *ivars) | |||||
outputs = [RawTensor(o) for o in ovars] | outputs = [RawTensor(o) for o in ovars] | ||||
if require_links: | if require_links: | ||||
@@ -75,9 +75,9 @@ std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph( | |||||
input_requires_grad[i] = python::input_requires_grad(ctx, i); | input_requires_grad[i] = python::input_requires_grad(ctx, i); | ||||
} | } | ||||
std::shared_ptr<OptimizedBackwardGraphResult> ret; | std::shared_ptr<OptimizedBackwardGraphResult> ret; | ||||
auto bg = proxy_graph_detail::make_backward_graph( | |||||
auto bg = OpDef::make_backward_graph( | |||||
*ctx.op, inputs, input_requires_grad, output_has_grad); | *ctx.op, inputs, input_requires_grad, output_has_grad); | ||||
if (bg.backward) { | |||||
if (!bg.backward.empty()) { | |||||
ret = std::make_shared<OptimizedBackwardGraphResult>(bg); | ret = std::make_shared<OptimizedBackwardGraphResult>(bg); | ||||
} | } | ||||
backward_graph_cache.emplace(key, ret); | backward_graph_cache.emplace(key, ret); | ||||
@@ -112,7 +112,7 @@ struct BackwardGraphWithClosure { | |||||
size_t count = std::count_if(save_for_backward.begin(), | size_t count = std::count_if(save_for_backward.begin(), | ||||
save_for_backward.end(), | save_for_backward.end(), | ||||
ranges::identity{}); | ranges::identity{}); | ||||
if (backward_graph->precomp) { | |||||
if (!backward_graph->precomp.empty()) { | |||||
auto&& irng = ranges::span(ctx.args, ctx.nargs); | auto&& irng = ranges::span(ctx.args, ctx.nargs); | ||||
auto&& orng = views::transform(outputs, [](auto&& i){return i.get();}); | auto&& orng = views::transform(outputs, [](auto&& i){return i.get();}); | ||||
auto precomp = apply(backward_graph->precomp, views::concat(irng, orng)); | auto precomp = apply(backward_graph->precomp, views::concat(irng, orng)); | ||||
@@ -30,26 +30,14 @@ using namespace imperative; | |||||
using namespace interpreter; | using namespace interpreter; | ||||
namespace { | |||||
std::optional<std::tuple<std::shared_ptr<OpDef>, std::vector<bool>, std::vector<bool>>> | |||||
make_backward_graph( | |||||
const OpDef& opdef, std::vector<LogicalTensorDesc> inputs, | |||||
std::vector<bool> input_requires_grad, | |||||
std::vector<bool> output_has_grad) { | |||||
auto res = OpDef::make_backward_graph(opdef, | |||||
SmallVector<LogicalTensorDesc>(inputs.begin(), inputs.end()), | |||||
SmallVector<bool>(input_requires_grad.begin(), input_requires_grad.end()), | |||||
SmallVector<bool>(output_has_grad.begin(), output_has_grad.end())); | |||||
if (res.backward) { | |||||
return std::optional<std::tuple<std::shared_ptr<OpDef>, std::vector<bool>, std::vector<bool>>>{ | |||||
std::in_place, res.backward, res.save_for_backward, res.input_has_grad}; | |||||
} else { | |||||
return {}; | |||||
} | |||||
} | |||||
} // namespace | |||||
void init_imperative_rt(py::module m) { | void init_imperative_rt(py::module m) { | ||||
m.def("make_backward_graph", &make_backward_graph); | |||||
auto make_backward_graph = []( | |||||
const OpDef& def, | |||||
const SmallVector<LogicalTensorDesc>& inputs, | |||||
const SmallVector<bool>& input_requires_grad, | |||||
const SmallVector<bool>& output_has_grad){ | |||||
auto result = OpDef::make_backward_graph(def, inputs, input_requires_grad, output_has_grad); | |||||
return std::make_tuple("backward_graph", result.save_for_backward, result.input_has_grad); | |||||
}; | |||||
m.def("make_backward_graph", make_backward_graph); | |||||
} | } |
@@ -367,42 +367,6 @@ void _init_py_op_def(py::module m) { | |||||
} | } | ||||
/*********** begin of hand-write opdefs **************/ | /*********** begin of hand-write opdefs **************/ | ||||
PyOpDefBegin(BackwardGraph) // {{ | |||||
// }; | |||||
PyOpDefEnd(BackwardGraph) | |||||
void _init_py_backward_graph(py::module m) { | |||||
using py_op = PyOp(BackwardGraph); | |||||
auto& py_type = PyOpType(BackwardGraph); | |||||
py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||||
py_type.tp_name = "megengine.core._imperative_rt.ops.BackwardGraph"; | |||||
py_type.tp_basicsize = sizeof(PyOp(BackwardGraph)); | |||||
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||||
py_type.tp_doc = "BackwardGraph"; | |||||
py_type.tp_base = &PyOpType(OpDef); | |||||
py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||||
py_type.tp_new = py_new_generic<py_op>; | |||||
mgb_assert(PyType_Ready(&py_type) >= 0); | |||||
// FIXME: rewrite interpret function in cpython instead wrap directly by pybind11::cppfunction | |||||
auto interpret = py::cpp_function( | |||||
[](OpDef& self, py::object pyf, py::object pyc, | |||||
const mgb::SmallVector<py::object>& inputs) { | |||||
auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) { | |||||
return py::cast<mgb::SmallVector<py::object>>(pyf(op.shared_from_this(), inputs)); | |||||
}; | |||||
auto c = [pyc](const TensorPtr& tensor) { | |||||
return pyc(tensor->dev_tensor()); | |||||
}; | |||||
return self.cast_final_safe<BackwardGraph>().graph().interpret<py::object>(f, c, inputs); | |||||
}); | |||||
mgb_assert(PyDict_SetItemString( | |||||
py_type.tp_dict, "interpret", interpret.release().ptr()) >= 0); | |||||
PyType_Modified(&py_type); | |||||
m.add_object("BackwardGraph", reinterpret_cast<PyObject*>(&py_type)); | |||||
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(BackwardGraph::typeinfo(), &py_type).second); | |||||
} | |||||
struct PyOpBase : PyOpDef { | struct PyOpBase : PyOpDef { | ||||
static PyTypeObject py_type; | static PyTypeObject py_type; | ||||
@@ -496,7 +460,6 @@ FOR_EACH_BIT_COMBINED_ENUM_PARAM(BIT_COMBINED_ENUM_CASTER_IMPL) | |||||
void init_ops(py::module m) { | void init_ops(py::module m) { | ||||
_init_py_op_def(m); | _init_py_op_def(m); | ||||
_init_py_backward_graph(m); | |||||
_init_py_op_base(m); | _init_py_op_base(m); | ||||
INIT_ALL_OP(m) | INIT_ALL_OP(m) | ||||
@@ -156,9 +156,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje | |||||
ctx.args = &tensors[0]; | ctx.args = &tensors[0]; | ||||
ctx.nargs = nargs; | ctx.nargs = nargs; | ||||
ctx.pytype = pytype; | ctx.pytype = pytype; | ||||
if (ctx.op->same_type<BackwardGraph>()) { | |||||
ctx.backward = true; | |||||
} | |||||
if (py::isinstance<PySymbolVar>(py::handle(args[0]))){ | if (py::isinstance<PySymbolVar>(py::handle(args[0]))){ | ||||
SmallVector<cg::VarNode*> vinputs(nargs); | SmallVector<cg::VarNode*> vinputs(nargs); | ||||
@@ -248,31 +248,53 @@ apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) { | |||||
return apply(ctx); | return apply(ctx); | ||||
} | } | ||||
template <typename T> | |||||
auto apply(std::shared_ptr<OpDef> op, T&& tensors) | |||||
-> std::enable_if_t<std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>, | |||||
apply_result_t> { | |||||
inline auto apply(std::shared_ptr<OpDef> op, Tensor*const* args, size_t nargs) { | |||||
ApplyContext ctx; | ApplyContext ctx; | ||||
ctx.op = std::move(op); | ctx.op = std::move(op); | ||||
ctx.nargs = tensors.size(); | |||||
Tensor* args[ctx.nargs]; | |||||
ctx.nargs = nargs; | |||||
ctx.args = args; | ctx.args = args; | ||||
for (size_t i = 0; i < ctx.nargs; ++i) { | |||||
args[i] = resolve_arrow(tensors[i]); | |||||
for (size_t i = 0; i < nargs; ++i) { | |||||
ctx.flags |= args[i]->m_flags; | ctx.flags |= args[i]->m_flags; | ||||
} | } | ||||
return apply(ctx); | return apply(ctx); | ||||
} | } | ||||
inline auto apply(std::shared_ptr<OpDef> op, Tensor*const* args, size_t nargs) { | |||||
ApplyContext ctx; | |||||
ctx.op = std::move(op); | |||||
ctx.nargs = nargs; | |||||
ctx.args = args; | |||||
template <typename T> | |||||
auto apply(std::shared_ptr<OpDef> op, T&& tensors) | |||||
-> std::enable_if_t<std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>, | |||||
apply_result_t> { | |||||
size_t nargs = tensors.size(); | |||||
Tensor* args[nargs]; | |||||
for (size_t i = 0; i < nargs; ++i) { | for (size_t i = 0; i < nargs; ++i) { | ||||
ctx.flags |= args[i]->m_flags; | |||||
args[i] = resolve_arrow(tensors[i]); | |||||
} | } | ||||
return apply(ctx); | |||||
return apply(op, args, nargs); | |||||
} | |||||
inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) { | |||||
SmallVector<std::shared_ptr<Tensor>> inputs; | |||||
for (size_t i = 0; i < nargs; ++i) { | |||||
inputs.push_back(args[i]->shared_from_this()); | |||||
} | |||||
auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<std::shared_ptr<Tensor>> inputs) { | |||||
return apply(op, inputs); | |||||
}; | |||||
auto const_functor = [](imperative::TensorPtr value) { | |||||
return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor())); | |||||
}; | |||||
return graph.apply(inputs, apply_functor, const_functor); | |||||
} | |||||
template <typename T> | |||||
auto apply(Subgraph graph, T&& tensors) | |||||
-> std::enable_if_t<std::is_same_v<decltype(tensors[0]), Tensor*>, | |||||
apply_result_t> { | |||||
size_t nargs = tensors.size(); | |||||
Tensor* args[nargs]; | |||||
for (size_t i = 0; i < nargs; ++i) { | |||||
args[i] = resolve_arrow(tensors[i]); | |||||
} | |||||
return apply(graph, args, nargs); | |||||
} | } | ||||
void init_tensor(pybind11::module); | void init_tensor(pybind11::module); | ||||
@@ -22,7 +22,7 @@ apply_result_t apply_trace(ApplyContext& ctx) { | |||||
apply_result_t outputs; | apply_result_t outputs; | ||||
if (ctx.backward) { | if (ctx.backward) { | ||||
// call megbrain_graph.py apply(BackwardGraph, *args) | |||||
// reach here when compiled=True | |||||
auto args = py::tuple(ctx.nargs + 1); | auto args = py::tuple(ctx.nargs + 1); | ||||
args[0] = py::cast(ctx.op); | args[0] = py::cast(ctx.op); | ||||
for (size_t i = 0; i < ctx.nargs; i++) { | for (size_t i = 0; i < ctx.nargs; i++) { | ||||
@@ -18,24 +18,22 @@ using namespace imperative; | |||||
OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src) | OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src) | ||||
: input_has_grad(src.input_has_grad) { | : input_has_grad(src.input_has_grad) { | ||||
if (!src.backward->same_type<BackwardGraph>()) { | |||||
if (src.backward.exprs.size() <= 1) { | |||||
// backward graph only contains a single op | // backward graph only contains a single op | ||||
backward = src.backward; | backward = src.backward; | ||||
save_for_backward = src.save_for_backward; | save_for_backward = src.save_for_backward; | ||||
return; | return; | ||||
} | } | ||||
save_for_backward.resize(src.save_for_backward.size(), false); | save_for_backward.resize(src.save_for_backward.size(), false); | ||||
precomp.reset(new BackwardGraph); | |||||
backward.reset(new BackwardGraph); | |||||
auto&& graph = src.backward->cast_final_safe<BackwardGraph>().graph(); | |||||
auto&& graph = src.backward; | |||||
auto&& mask = src.save_for_backward; | auto&& mask = src.save_for_backward; | ||||
size_t input_size = src.input_has_grad.size(); | size_t input_size = src.input_has_grad.size(); | ||||
size_t output_size = (mask.size() - input_size) / 2; | size_t output_size = (mask.size() - input_size) / 2; | ||||
mgb_assert(input_size + output_size * 2 == mask.size()); | mgb_assert(input_size + output_size * 2 == mask.size()); | ||||
auto& fgraph = precomp->cast_final<BackwardGraph>().graph(); | |||||
auto& bgraph = backward->cast_final<BackwardGraph>().graph(); | |||||
auto& fgraph = precomp; | |||||
auto& bgraph = backward; | |||||
// optimization: move ops (e.g. GetVarShape) to forward to | // optimization: move ops (e.g. GetVarShape) to forward to | ||||
// reduce memory footprint | // reduce memory footprint | ||||
@@ -113,6 +111,6 @@ OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphRe | |||||
} | } | ||||
if (!fgraph.outputs.size()) { | if (!fgraph.outputs.size()) { | ||||
precomp.reset(); | |||||
precomp = {}; | |||||
} | } | ||||
} | } |
@@ -911,8 +911,7 @@ auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle { | |||||
op_type == RemoteSend::typeinfo() || | op_type == RemoteSend::typeinfo() || | ||||
op_type == CollectiveComm::typeinfo() || | op_type == CollectiveComm::typeinfo() || | ||||
op_type == opr::InputCallback::typeinfo() || | op_type == opr::InputCallback::typeinfo() || | ||||
op_type == opr::OutputCallback::typeinfo() || | |||||
op_type == BackwardGraph::typeinfo()) { | |||||
op_type == opr::OutputCallback::typeinfo()) { | |||||
return m_commands.end(); | return m_commands.end(); | ||||
} | } | ||||
} else if constexpr (std::is_same_v<T, GetValue>) { | } else if constexpr (std::is_same_v<T, GetValue>) { | ||||
@@ -10,6 +10,9 @@ | |||||
*/ | */ | ||||
#include "megbrain/imperative/op_def.h" | #include "megbrain/imperative/op_def.h" | ||||
#include <sstream> | |||||
#include "megbrain/imperative/ops/opr_attr.h" | #include "megbrain/imperative/ops/opr_attr.h" | ||||
#include "./op_trait.h" | #include "./op_trait.h" | ||||
@@ -117,6 +120,67 @@ const std::string OpDef::make_name() const { | |||||
return m_scope + "." + trait()->make_name(*this); | return m_scope + "." + trait()->make_name(*this); | ||||
} | } | ||||
std::string Subgraph::repr() const { | |||||
std::ostringstream buf; | |||||
buf << "("; | |||||
for (size_t i = 0; i < inputs.size(); ++i) { | |||||
if (i > 0) buf << ", "; | |||||
buf << "%" << inputs[i]; | |||||
} | |||||
buf << ") => {\n"; | |||||
auto fmt_const = [](size_t i, const TensorPtr& t) { | |||||
if (t->shape().ndim == 1 && t->shape()[0] == 1) { | |||||
auto&& v = t->get_value(); | |||||
if (v.dtype() == dtype::Float32{}) { | |||||
return std::to_string(*v.ptr<dt_float32>()); | |||||
} else if (v.dtype() == dtype::Int32{}) { | |||||
return std::to_string(*v.ptr<int32_t>()); | |||||
} | |||||
} | |||||
return std::string("%c") + std::to_string(i); | |||||
}; | |||||
std::unordered_map<size_t, std::string> const_reps; | |||||
for (auto&& [i, t] : constants) { | |||||
const_reps.emplace(i, fmt_const(i, t)); | |||||
} | |||||
for (auto& [op, ins, outs] : exprs) { | |||||
buf << " "; | |||||
if (outs.size()) { | |||||
for (size_t i = 0; i < outs.size(); ++i) { | |||||
if (i > 0) buf << ", "; | |||||
buf << "%" << outs[i]; | |||||
} | |||||
buf << " = "; | |||||
} | |||||
if (auto* p = op->try_cast_final<OprAttr>()) { | |||||
buf << p->type; | |||||
} else { | |||||
buf << op->dyn_typeinfo()->name; | |||||
} | |||||
for (size_t i : ins) { | |||||
buf << " "; | |||||
auto&& it = const_reps.find(i); | |||||
if (it != const_reps.end()) { | |||||
buf << it->second; | |||||
} else { | |||||
buf << "%" << i; | |||||
} | |||||
} | |||||
buf << "\n"; | |||||
} | |||||
buf << " "; | |||||
if (outputs.size()) { | |||||
for (size_t i = 0; i < outputs.size(); ++i) { | |||||
if (i > 0) buf << ", "; | |||||
buf << "%" << outputs[i]; | |||||
} | |||||
} else { | |||||
buf << "()"; | |||||
} | |||||
buf << "\n}\n"; | |||||
return buf.str(); | |||||
} | |||||
} // namespace imperative | } // namespace imperative | ||||
} // namespace mgb | } // namespace mgb | ||||
@@ -19,147 +19,6 @@ | |||||
namespace mgb { | namespace mgb { | ||||
namespace imperative { | namespace imperative { | ||||
SmallVector<TensorPtr> | |||||
BackwardGraph::InternalGraph::apply( | |||||
const SmallVector<TensorPtr>& inputs) const { | |||||
return interpret<TensorPtr>( | |||||
&OpDef::apply_on_physical_tensor, | |||||
[](const TensorPtr& x) {return x;}, | |||||
inputs); | |||||
} | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> BackwardGraph::InternalGraph::infer_attrs( | |||||
const SmallVector<LogicalTensorDesc>& inputs) const { | |||||
using TensorAttr = LogicalTensorDesc; | |||||
ThinHashMap<size_t, TensorAttr> node2attr; | |||||
auto&& input_nodes = this->inputs; | |||||
auto&& output_nodes = this->outputs; | |||||
mgb_assert(inputs.size() == input_nodes.size()); | |||||
for (size_t i = 0; i < inputs.size(); ++ i) { | |||||
node2attr[input_nodes[i]] = inputs[i]; | |||||
} | |||||
for (auto &&i : constants) { | |||||
auto* value = i.second->try_get_value(); | |||||
mgb_assert(value); | |||||
node2attr[i.first] = TensorAttr{ | |||||
i.second->layout(), i.second->comp_node(), | |||||
value->proxy_to_default_cpu()}; | |||||
} | |||||
bool validated = true; | |||||
for (size_t i = 0; i < exprs.size(); ++ i) { | |||||
auto&& [expr_op, expr_inps, expr_oups] = exprs[i]; | |||||
SmallVector<TensorAttr> expr_input_descs; | |||||
for (auto &&inp : expr_inps) { | |||||
expr_input_descs.push_back(node2attr.at(inp)); | |||||
} | |||||
auto [expr_output_descs, expr_validated] = OpDef::infer_output_attrs_fallible( | |||||
*expr_op, expr_input_descs); | |||||
validated = validated && expr_validated; | |||||
mgb_assert(expr_output_descs.size() == expr_oups.size()); | |||||
for (size_t i = 0; i < expr_output_descs.size(); ++ i) { | |||||
node2attr[expr_oups[i]] = expr_output_descs[i]; | |||||
} | |||||
} | |||||
SmallVector<TensorAttr> ret; | |||||
for (auto &&i : output_nodes) { | |||||
ret.push_back(node2attr.at(i)); | |||||
} | |||||
return {ret, validated}; | |||||
} | |||||
std::string BackwardGraph::InternalGraph::repr() { | |||||
std::ostringstream buf; | |||||
buf << "("; | |||||
for (size_t i = 0; i < inputs.size(); ++i) { | |||||
if (i > 0) buf << ", "; | |||||
buf << "%" << inputs[i]; | |||||
} | |||||
buf << ") => {\n"; | |||||
auto fmt_const = [](size_t i, TensorPtr& t) { | |||||
if (t->shape().ndim == 1 && t->shape()[0] == 1) { | |||||
auto&& v = t->get_value(); | |||||
if (v.dtype() == dtype::Float32{}) { | |||||
return std::to_string(*v.ptr<dt_float32>()); | |||||
} else if (v.dtype() == dtype::Int32{}) { | |||||
return std::to_string(*v.ptr<int32_t>()); | |||||
} | |||||
} | |||||
return std::string("%c") + std::to_string(i); | |||||
}; | |||||
std::unordered_map<size_t, std::string> const_reps; | |||||
for (auto&& [i, t] : constants) { | |||||
const_reps.emplace(i, fmt_const(i, t)); | |||||
} | |||||
for (auto& [op, ins, outs] : exprs) { | |||||
buf << " "; | |||||
if (outs.size()) { | |||||
for (size_t i = 0; i < outs.size(); ++i) { | |||||
if (i > 0) buf << ", "; | |||||
buf << "%" << outs[i]; | |||||
} | |||||
buf << " = "; | |||||
} | |||||
if (auto* p = op->try_cast_final<OprAttr>()) { | |||||
buf << p->type; | |||||
} else { | |||||
buf << op->dyn_typeinfo()->name; | |||||
} | |||||
for (size_t i : ins) { | |||||
buf << " "; | |||||
auto&& it = const_reps.find(i); | |||||
if (it != const_reps.end()) { | |||||
buf << it->second; | |||||
} else { | |||||
buf << "%" << i; | |||||
} | |||||
} | |||||
buf << "\n"; | |||||
} | |||||
buf << " "; | |||||
if (outputs.size()) { | |||||
for (size_t i = 0; i < outputs.size(); ++i) { | |||||
if (i > 0) buf << ", "; | |||||
buf << "%" << outputs[i]; | |||||
} | |||||
} else { | |||||
buf << "()"; | |||||
} | |||||
buf << "\n}\n"; | |||||
return buf.str(); | |||||
} | |||||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardGraph); | |||||
namespace { | |||||
SmallVector<TensorPtr> backward_impl( | |||||
const OpDef& backward_graph, | |||||
const SmallVector<TensorPtr>& tensors) { | |||||
return backward_graph.cast_final_safe<BackwardGraph>() | |||||
.graph().apply(tensors); | |||||
} | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_tensor_attrs( | |||||
const OpDef& backward_graph, | |||||
const SmallVector<LogicalTensorDesc> inputs) { | |||||
return backward_graph.cast_final_safe<BackwardGraph>() | |||||
.graph().infer_attrs(inputs); | |||||
} | |||||
std::vector<std::pair<const char*, std::string>> props( | |||||
const OpDef& backward_graph) { | |||||
return {}; | |||||
} | |||||
OP_TRAIT_REG(BackwardGraph, BackwardGraph) | |||||
.apply_on_physical_tensor(backward_impl) | |||||
.infer_output_attrs_fallible(infer_tensor_attrs) | |||||
.props(props) | |||||
.fallback(); | |||||
} // anonymous namespace | |||||
} // namespace imperative | } // namespace imperative | ||||
} // namespace mgb | } // namespace mgb | ||||
@@ -669,8 +669,7 @@ ProxyGraph::make_backward_graph( | |||||
auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo()); | auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo()); | ||||
BackwardGraphResult result; | BackwardGraphResult result; | ||||
auto&& backward = BackwardGraph::make(); | |||||
auto&& igraph = backward->cast_final_safe<BackwardGraph>().graph(); | |||||
auto&& igraph = result.backward; | |||||
size_t nr_backward_graph_inputs = 0; | size_t nr_backward_graph_inputs = 0; | ||||
auto gen_expr = [this, &var2idx, &igraph, &push, &fwd, | auto gen_expr = [this, &var2idx, &igraph, &push, &fwd, | ||||
@@ -682,7 +681,7 @@ ProxyGraph::make_backward_graph( | |||||
++ nr_backward_graph_inputs; | ++ nr_backward_graph_inputs; | ||||
push(op->output(0)); | push(op->output(0)); | ||||
} else { | } else { | ||||
std::vector<size_t> inputs, outputs; | |||||
SmallVector<size_t> inputs, outputs; | |||||
for (auto &&i : op->input()) { | for (auto &&i : op->input()) { | ||||
if (i->owner_opr() == fwd) { | if (i->owner_opr() == fwd) { | ||||
if (var2idx.find(i) == var2idx.end()) { | if (var2idx.find(i) == var2idx.end()) { | ||||
@@ -695,7 +694,7 @@ ProxyGraph::make_backward_graph( | |||||
for (auto &&i : op->usable_output()) { | for (auto &&i : op->usable_output()) { | ||||
outputs.push_back(push(i)); | outputs.push_back(push(i)); | ||||
} | } | ||||
igraph.exprs.emplace_back(OpDef::make_from_op_node(op), inputs, outputs); | |||||
igraph.exprs.push_back({OpDef::make_from_op_node(op), inputs, outputs}); | |||||
} | } | ||||
}; | }; | ||||
@@ -770,36 +769,6 @@ ProxyGraph::make_backward_graph( | |||||
write_inputs(outputs); | write_inputs(outputs); | ||||
write_inputs(output_grads); | write_inputs(output_grads); | ||||
mgb_assert(igraph.inputs.size() == nr_backward_graph_inputs); | mgb_assert(igraph.inputs.size() == nr_backward_graph_inputs); | ||||
auto treat_as_single = [](auto&& igraph) { | |||||
if (igraph.exprs.size() != 1) | |||||
return false; | |||||
auto&& expr = igraph.exprs[0]; | |||||
auto&& expr_inputs = std::get<1>(expr); | |||||
if (expr_inputs.size() != igraph.inputs.size()) { | |||||
return false; | |||||
} | |||||
for (size_t i = 0; i < expr_inputs.size(); ++ i) { | |||||
if (igraph.inputs[i] != expr_inputs[i]) { | |||||
return false; | |||||
} | |||||
} | |||||
auto&& expr_outputs = std::get<2>(expr); | |||||
if (expr_outputs.size() != igraph.outputs.size()) { | |||||
return false; | |||||
} | |||||
for (size_t i = 0; i < expr_outputs.size(); ++ i) { | |||||
if (igraph.outputs[i] != expr_outputs[i]) { | |||||
return false; | |||||
} | |||||
} | |||||
return true; | |||||
}; | |||||
if (treat_as_single(igraph)) { | |||||
result.backward = std::get<0>(igraph.exprs[0]); | |||||
} else { | |||||
result.backward = backward; | |||||
} | |||||
return result; | return result; | ||||
} | } | ||||
@@ -65,7 +65,7 @@ private: | |||||
class InputPlaceholder; | class InputPlaceholder; | ||||
struct ProxyGraphInst; | struct ProxyGraphInst; | ||||
struct GradGraph; | struct GradGraph; | ||||
struct CurOprGuard; | |||||
class CurOprGuard; | |||||
void reset(); | void reset(); | ||||
@@ -15,7 +15,7 @@ namespace mgb::imperative::proxy_graph { | |||||
// e.g. friend class mgb::imperative::proxy_graph::ProxyGraph | // e.g. friend class mgb::imperative::proxy_graph::ProxyGraph | ||||
struct ProxyGraph { | struct ProxyGraph { | ||||
struct InputPlaceholder; | struct InputPlaceholder; | ||||
struct MiniGraph; | |||||
class MiniGraph; | |||||
}; | }; | ||||
} // namespace mgb::imperative::proxy_graph | } // namespace mgb::imperative::proxy_graph |
@@ -75,30 +75,7 @@ apply_on_physical_tensor(const OpDef& def, | |||||
auto output_descs = infer_output_attrs(def, inputs); | auto output_descs = infer_output_attrs(def, inputs); | ||||
SmallVector<TensorPtr> outputs(output_descs.size(), {}); | SmallVector<TensorPtr> outputs(output_descs.size(), {}); | ||||
for (size_t i = 0; i < outputs.size(); i++) { | for (size_t i = 0; i < outputs.size(); i++) { | ||||
auto& output = outputs[i]; | |||||
auto& output_desc = output_descs[i]; | |||||
if (def.same_type<Elemwise>()) { | |||||
for (size_t j = 0; j < inputs.size(); j++) { | |||||
// TODO: reindex inputs to support inplace exprs like 'y = x op x'. | |||||
auto& input = inputs[j]; | |||||
// Because we pass inputs by value, if input and input->blob() are all unique, | |||||
// their ownerships are on the stack, thus we can reuse them safely. | |||||
// @see: interpreter::intl::ChannelImpl::process_one_task | |||||
if (input.unique() && input->blob().unique() && input->blob()->storage().unique() && | |||||
input->layout().dtype == output_desc.layout.dtype && | |||||
input->layout().eq_layout(output_desc.layout) && | |||||
input->comp_node() == output_desc.comp_node) { | |||||
static std::atomic_llong inplace_count = 0; | |||||
mgb_log_debug("do inplace for elemwise, layout: %s, count: %lld", | |||||
output_desc.layout.to_string().c_str(), ++inplace_count); | |||||
output = Tensor::make(input->blob(), input->layout(), input->offset()); | |||||
break; | |||||
} | |||||
} | |||||
} | |||||
if (!output) { | |||||
output = Tensor::make(output_desc.layout, output_desc.comp_node); | |||||
} | |||||
outputs[i] = Tensor::make(output_descs[i].layout, output_descs[i].comp_node); | |||||
} | } | ||||
exec(def, inputs, outputs); | exec(def, inputs, outputs); | ||||
auto async_error = ProxyGraph::get_async_error(); | auto async_error = ProxyGraph::get_async_error(); | ||||
@@ -14,10 +14,10 @@ | |||||
namespace mgb::imperative { | namespace mgb::imperative { | ||||
struct OptimizedBackwardGraphResult { | struct OptimizedBackwardGraphResult { | ||||
std::shared_ptr<OpDef> precomp; | |||||
std::shared_ptr<OpDef> backward; | |||||
std::vector<bool> save_for_backward; | |||||
std::vector<bool> input_has_grad; | |||||
Subgraph precomp; | |||||
Subgraph backward; | |||||
SmallVector<bool> save_for_backward; | |||||
SmallVector<bool> input_has_grad; | |||||
OptimizedBackwardGraphResult(const BackwardGraphResult& bgraph); | OptimizedBackwardGraphResult(const BackwardGraphResult& bgraph); | ||||
}; | }; | ||||
@@ -26,10 +26,60 @@ enum DispatchMode { | |||||
KERNEL = 1 | KERNEL = 1 | ||||
}; | }; | ||||
using SharedOp = std::shared_ptr<OpDef>; | |||||
template <typename T> | |||||
struct Expr { | |||||
std::shared_ptr<OpDef> op; | |||||
SmallVector<T> inputs; | |||||
SmallVector<T> outputs; | |||||
}; | |||||
struct Subgraph { | |||||
SmallVector<size_t> inputs; | |||||
SmallVector<std::pair<size_t, TensorPtr>> constants; | |||||
SmallVector<size_t> outputs; | |||||
SmallVector<Expr<size_t>> exprs; | |||||
template <typename T, typename F, typename C> | |||||
SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const { | |||||
std::unordered_map<size_t, T> idx2var; | |||||
mgb_assert(inputs.size() == input_vars.size(), "input size mismatch"); | |||||
for (size_t i = 0; i < inputs.size(); ++i) { | |||||
idx2var[inputs[i]] = input_vars[i]; | |||||
} | |||||
for (auto&& [idx, val]: constants) { | |||||
idx2var[idx] = c(val); | |||||
} | |||||
for (auto& expr: exprs) { | |||||
SmallVector<T> expr_inputs; | |||||
for (auto idx: expr.inputs) { | |||||
expr_inputs.push_back(idx2var[idx]); | |||||
} | |||||
SmallVector<T> expr_outputs = f(expr.op, std::move(expr_inputs)); | |||||
mgb_assert(expr_outputs.size() == expr.outputs.size(), "output size mismatch"); | |||||
for (size_t i = 0; i < expr_outputs.size(); ++i) { | |||||
idx2var[expr.outputs[i]] = expr_outputs[i]; | |||||
} | |||||
} | |||||
SmallVector<T> output_vars; | |||||
for (auto idx: outputs) { | |||||
output_vars.push_back(idx2var[idx]); | |||||
} | |||||
return output_vars; | |||||
} | |||||
bool empty() const { | |||||
return outputs.size() == 0; | |||||
} | |||||
std::string repr() const; | |||||
}; | |||||
struct BackwardGraphResult { | struct BackwardGraphResult { | ||||
std::shared_ptr<OpDef> backward; | |||||
std::vector<bool> save_for_backward; | |||||
std::vector<bool> input_has_grad; | |||||
Subgraph backward; | |||||
SmallVector<bool> save_for_backward; | |||||
SmallVector<bool> input_has_grad; | |||||
}; | }; | ||||
class OpDef : public Hashable, | class OpDef : public Hashable, | ||||
@@ -15,92 +15,6 @@ | |||||
namespace mgb { | namespace mgb { | ||||
namespace imperative { | namespace imperative { | ||||
// a special OpDef used for taking gradient on physical tensor | |||||
struct BackwardGraph final : public OpDefImplBase<BackwardGraph> { | |||||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||||
public: | |||||
struct InternalGraph { | |||||
// op, inputs, outputs | |||||
using Expr = std::tuple<std::shared_ptr<OpDef>, | |||||
std::vector<size_t>, std::vector<size_t>>; | |||||
std::vector<Expr> exprs; | |||||
// index array of input nodes | |||||
std::vector<size_t> inputs; | |||||
// index array of output nodes | |||||
std::vector<size_t> outputs; | |||||
// pair of (node index, correspending constant) | |||||
std::vector<std::pair<size_t, TensorPtr>> constants; | |||||
SmallVector<TensorPtr> | |||||
apply(const SmallVector<TensorPtr>& inputs) const; | |||||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_attrs( | |||||
const SmallVector<LogicalTensorDesc>& inputs) const; | |||||
template <typename T, typename F, typename C> | |||||
SmallVector<T> interpret(F&& f, C&& c, const SmallVector<T>& inputs) const { | |||||
ThinHashMap<size_t, T> node2tensor; | |||||
auto&& input_nodes = this->inputs; | |||||
mgb_assert(inputs.size() == input_nodes.size()); | |||||
for (size_t i = 0; i < inputs.size(); ++ i) { | |||||
node2tensor[input_nodes[i]] = inputs[i]; | |||||
} | |||||
for (auto &&i : constants) { | |||||
node2tensor[i.first] = c(i.second); | |||||
} | |||||
for (size_t i = 0; i < exprs.size(); ++ i) { | |||||
auto&& expr = exprs[i]; | |||||
SmallVector<T> inputs; | |||||
for (auto &&in : std::get<1>(expr)) { | |||||
inputs.push_back(node2tensor.at(in)); | |||||
} | |||||
auto&& outputs = f(*std::get<0>(expr), std::move(inputs)); | |||||
auto&& output_nodes = std::get<2>(expr); | |||||
mgb_assert(outputs.size() == output_nodes.size()); | |||||
for (size_t i = 0; i < outputs.size(); ++ i) { | |||||
node2tensor[output_nodes[i]] = std::move(outputs[i]); | |||||
} | |||||
} | |||||
SmallVector<T> ret; | |||||
for (auto &&i : outputs) { | |||||
ret.push_back(node2tensor.at(i)); | |||||
} | |||||
return ret; | |||||
} | |||||
std::string repr(); | |||||
}; | |||||
const InternalGraph& graph() const { | |||||
return m_graph; | |||||
} | |||||
InternalGraph& graph() { | |||||
return m_graph; | |||||
} | |||||
bool is_same_st(const Hashable& rhs) const override { | |||||
if (!rhs.same_type<BackwardGraph>()) { | |||||
return false; | |||||
} | |||||
auto& other = rhs.cast_final_safe<BackwardGraph>(); | |||||
if (this == &other) { | |||||
return true; | |||||
} | |||||
// FIXME | |||||
return false; | |||||
} | |||||
std::string repr() {return m_graph.repr();} | |||||
private: | |||||
InternalGraph m_graph; | |||||
}; | |||||
} // namespace imperative | } // namespace imperative | ||||
} // namespace mgb | } // namespace mgb | ||||
@@ -29,7 +29,7 @@ struct GenericPyOp final : OpDefImplBase<GenericPyOp> { | |||||
} | } | ||||
bool is_same_st(const Hashable& rhs) const override { | bool is_same_st(const Hashable& rhs) const override { | ||||
return obj.equal(static_cast<const GenericPyOp&>(rhs).obj); | |||||
return obj.equal(rhs.cast_final<GenericPyOp>().obj); | |||||
} | } | ||||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | MGB_DYN_TYPE_OBJ_FINAL_DECL; | ||||
@@ -75,6 +75,10 @@ T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, cons | |||||
return ret; | return ret; | ||||
} | } | ||||
SmallVector<TensorPtr> apply_shared_on_physical_tensor(std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs) { | |||||
return OpDef::apply_on_physical_tensor(*def, inputs); | |||||
} | |||||
TEST(TestImperative, BackwardGraphBasic) { | TEST(TestImperative, BackwardGraphBasic) { | ||||
HostTensorGenerator<> gen; | HostTensorGenerator<> gen; | ||||
SmallVector<HostTensorND> hvs; | SmallVector<HostTensorND> hvs; | ||||
@@ -114,7 +118,11 @@ TEST(TestImperative, BackwardGraphBasic) { | |||||
} | } | ||||
} | } | ||||
inputs.clear(); | inputs.clear(); | ||||
auto input_grads = OpDef::apply_on_physical_tensor(*(result.backward), backward_graph_inputs); | |||||
auto input_grads = result.backward.apply( | |||||
backward_graph_inputs, | |||||
apply_shared_on_physical_tensor, | |||||
[&](auto&& x){ return x; } | |||||
); | |||||
mgb_assert(input_grads.size() == input_has_grad.size()); | mgb_assert(input_grads.size() == input_has_grad.size()); | ||||
for (size_t i = 0; i < input_has_grad.size(); ++ i) { | for (size_t i = 0; i < input_has_grad.size(); ++ i) { | ||||
mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i])); | mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i])); | ||||
@@ -164,7 +172,11 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||||
} | } | ||||
} | } | ||||
inputs.clear(); | inputs.clear(); | ||||
auto input_grads = OpDef::apply_on_physical_tensor(*(result.backward), backward_graph_inputs); | |||||
auto input_grads = result.backward.apply( | |||||
backward_graph_inputs, | |||||
apply_shared_on_physical_tensor, | |||||
[&](auto&& x){ return x; } | |||||
); | |||||
mgb_assert(input_grads.size() == input_has_grad.size()); | mgb_assert(input_grads.size() == input_has_grad.size()); | ||||
for (size_t i = 0; i < input_has_grad.size(); ++ i) { | for (size_t i = 0; i < input_has_grad.size(); ++ i) { | ||||
mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i])); | mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i])); | ||||
@@ -224,9 +236,17 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { | |||||
auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0]; | auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0]; | ||||
auto backward_graph_inputs = prepare_backward_graph_inputs<SmallVector<TensorPtr>>(bg, {a_tn, b_tn}, {c_tn}, {dc_tn}); | auto backward_graph_inputs = prepare_backward_graph_inputs<SmallVector<TensorPtr>>(bg, {a_tn, b_tn}, {c_tn}, {dc_tn}); | ||||
auto grads = expand_grads(bg, OpDef::apply_on_physical_tensor(*bg.backward, backward_graph_inputs)); | |||||
auto grads = expand_grads(bg, bg.backward.apply( | |||||
backward_graph_inputs, | |||||
apply_shared_on_physical_tensor, | |||||
[&](auto&& x){ return x; } | |||||
)); | |||||
auto precomp = OpDef::apply_on_physical_tensor(*obg.precomp, {a_tn, b_tn, c_tn}); | |||||
auto precomp = obg.precomp.apply( | |||||
SmallVector<TensorPtr>{a_tn, b_tn, c_tn}, | |||||
apply_shared_on_physical_tensor, | |||||
[&](auto&& x){ return x; } | |||||
); | |||||
ASSERT_EQ(precomp.size(), 2); | ASSERT_EQ(precomp.size(), 2); | ||||
ASSERT_EQ(precomp[0]->shape().ndim, 1); | ASSERT_EQ(precomp[0]->shape().ndim, 1); | ||||
ASSERT_LE(precomp[0]->shape()[0], 2); | ASSERT_LE(precomp[0]->shape()[0], 2); | ||||
@@ -234,7 +254,11 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { | |||||
ASSERT_LE(precomp[1]->shape()[0], 2); | ASSERT_LE(precomp[1]->shape()[0], 2); | ||||
auto backward_inputs = prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn}); | auto backward_inputs = prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn}); | ||||
auto grads2 = expand_grads(obg, OpDef::apply_on_physical_tensor(*obg.backward, backward_inputs)); | |||||
auto grads2 = expand_grads(obg, obg.backward.apply( | |||||
backward_inputs, | |||||
apply_shared_on_physical_tensor, | |||||
[&](auto&& x){ return x; } | |||||
)); | |||||
ASSERT_EQ(grads2.size(), 2); | ASSERT_EQ(grads2.size(), 2); | ||||
MGB_ASSERT_TENSOR_EQ(grads[0]->get_value(), grads2[0]->get_value()); | MGB_ASSERT_TENSOR_EQ(grads[0]->get_value(), grads2[0]->get_value()); | ||||