@@ -18,7 +18,6 @@ import numpy as np | |||
from .. import _imperative_rt | |||
from .._imperative_rt import GraphOptimizeOptions | |||
from .._imperative_rt.core2 import apply, set_cpp_apply_backward_varnode | |||
from .._imperative_rt.ops import BackwardGraph | |||
from .._wrap import device as as_device | |||
from ..ops.builtin import OpDef | |||
from .core import TensorBase | |||
@@ -481,21 +480,6 @@ def apply_normal_varnode(op: OpDef, *args: VarNode): | |||
return _wrap(outputs) | |||
def apply_backward_varnode(op: BackwardGraph, *args: VarNode): | |||
assert args | |||
graph = args[0].graph | |||
outputs = op.interpret( | |||
op, | |||
lambda op, args: apply_normal_varnode(op, *args), | |||
graph._make_const_for_backward, | |||
args, | |||
) | |||
return outputs | |||
set_cpp_apply_backward_varnode(apply_backward_varnode) | |||
def input_callback(callback, *args, device=None, dtype=None, shape=None, graph=None): | |||
outputs = _imperative_rt.input_callback( | |||
callback, as_device(device).to_c(), dtype, shape, _unwrap(args), graph=graph | |||
@@ -32,7 +32,7 @@ from ..core._imperative_rt.ops import ( | |||
) | |||
from ..core._trace_option import set_symbolic_shape | |||
from ..core._wrap import device as as_device | |||
from ..core.ops.builtin import BackwardGraph, BatchNorm, OpDef | |||
from ..core.ops.builtin import BatchNorm, OpDef | |||
from ..core.ops.special import Const | |||
from ..core.tensor import megbrain_graph as G | |||
from ..core.tensor.utils import setscalar | |||
@@ -587,10 +587,7 @@ class trace: | |||
ivars.append(info.varnode) | |||
if isinstance(op, BackwardGraph): | |||
ovars = G.apply_backward_varnode(op, *ivars) | |||
else: | |||
ovars = G.apply_normal_varnode(op, *ivars) | |||
ovars = G.apply_normal_varnode(op, *ivars) | |||
if require_links and len(ovars) > 0: | |||
io_links = (ovars[0],) | |||
@@ -805,14 +802,11 @@ class trace: | |||
name=info.name, | |||
) | |||
ivars.append(h2v[h]) | |||
if isinstance(op, BackwardGraph): | |||
ovars = G.apply_backward_varnode(op, *ivars) | |||
else: | |||
if isinstance(op, BatchNorm): | |||
assert ( | |||
op.fwd_mode == BatchNorm.FwdMode.INFERENCE | |||
), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?" | |||
ovars = G.apply_normal_varnode(op, *ivars) | |||
if isinstance(op, BatchNorm): | |||
assert ( | |||
op.fwd_mode == BatchNorm.FwdMode.INFERENCE | |||
), "can not dump BatchNorm in training mode, maybe you forget to do model.eval()?" | |||
ovars = G.apply_normal_varnode(op, *ivars) | |||
AutoNaming.record_opnode(ovars[0].op) | |||
@@ -1088,10 +1082,7 @@ def apply_symbolic_mode(op: OpDef, *args: RawTensor): | |||
ivars[0] = opnode.outputs[0] | |||
active_trace._lazy_eval_links = (ivars[0],) | |||
if isinstance(op, BackwardGraph): | |||
ovars = G.apply_backward_varnode(op, *ivars) | |||
else: | |||
ovars = G.apply_normal_varnode(op, *ivars) | |||
ovars = G.apply_normal_varnode(op, *ivars) | |||
outputs = [RawTensor(o) for o in ovars] | |||
if require_links: | |||
@@ -75,9 +75,9 @@ std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph( | |||
input_requires_grad[i] = python::input_requires_grad(ctx, i); | |||
} | |||
std::shared_ptr<OptimizedBackwardGraphResult> ret; | |||
auto bg = proxy_graph_detail::make_backward_graph( | |||
auto bg = OpDef::make_backward_graph( | |||
*ctx.op, inputs, input_requires_grad, output_has_grad); | |||
if (bg.backward) { | |||
if (!bg.backward.empty()) { | |||
ret = std::make_shared<OptimizedBackwardGraphResult>(bg); | |||
} | |||
backward_graph_cache.emplace(key, ret); | |||
@@ -112,7 +112,7 @@ struct BackwardGraphWithClosure { | |||
size_t count = std::count_if(save_for_backward.begin(), | |||
save_for_backward.end(), | |||
ranges::identity{}); | |||
if (backward_graph->precomp) { | |||
if (!backward_graph->precomp.empty()) { | |||
auto&& irng = ranges::span(ctx.args, ctx.nargs); | |||
auto&& orng = views::transform(outputs, [](auto&& i){return i.get();}); | |||
auto precomp = apply(backward_graph->precomp, views::concat(irng, orng)); | |||
@@ -30,26 +30,14 @@ using namespace imperative; | |||
using namespace interpreter; | |||
namespace { | |||
std::optional<std::tuple<std::shared_ptr<OpDef>, std::vector<bool>, std::vector<bool>>> | |||
make_backward_graph( | |||
const OpDef& opdef, std::vector<LogicalTensorDesc> inputs, | |||
std::vector<bool> input_requires_grad, | |||
std::vector<bool> output_has_grad) { | |||
auto res = OpDef::make_backward_graph(opdef, | |||
SmallVector<LogicalTensorDesc>(inputs.begin(), inputs.end()), | |||
SmallVector<bool>(input_requires_grad.begin(), input_requires_grad.end()), | |||
SmallVector<bool>(output_has_grad.begin(), output_has_grad.end())); | |||
if (res.backward) { | |||
return std::optional<std::tuple<std::shared_ptr<OpDef>, std::vector<bool>, std::vector<bool>>>{ | |||
std::in_place, res.backward, res.save_for_backward, res.input_has_grad}; | |||
} else { | |||
return {}; | |||
} | |||
} | |||
} // namespace | |||
void init_imperative_rt(py::module m) { | |||
m.def("make_backward_graph", &make_backward_graph); | |||
auto make_backward_graph = []( | |||
const OpDef& def, | |||
const SmallVector<LogicalTensorDesc>& inputs, | |||
const SmallVector<bool>& input_requires_grad, | |||
const SmallVector<bool>& output_has_grad){ | |||
auto result = OpDef::make_backward_graph(def, inputs, input_requires_grad, output_has_grad); | |||
return std::make_tuple("backward_graph", result.save_for_backward, result.input_has_grad); | |||
}; | |||
m.def("make_backward_graph", make_backward_graph); | |||
} |
@@ -367,42 +367,6 @@ void _init_py_op_def(py::module m) { | |||
} | |||
/*********** begin of hand-write opdefs **************/ | |||
PyOpDefBegin(BackwardGraph) // {{ | |||
// }; | |||
PyOpDefEnd(BackwardGraph) | |||
void _init_py_backward_graph(py::module m) { | |||
using py_op = PyOp(BackwardGraph); | |||
auto& py_type = PyOpType(BackwardGraph); | |||
py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; | |||
py_type.tp_name = "megengine.core._imperative_rt.ops.BackwardGraph"; | |||
py_type.tp_basicsize = sizeof(PyOp(BackwardGraph)); | |||
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; | |||
py_type.tp_doc = "BackwardGraph"; | |||
py_type.tp_base = &PyOpType(OpDef); | |||
py_type.tp_dealloc = py_dealloc_generic<py_op>; | |||
py_type.tp_new = py_new_generic<py_op>; | |||
mgb_assert(PyType_Ready(&py_type) >= 0); | |||
// FIXME: rewrite interpret function in cpython instead wrap directly by pybind11::cppfunction | |||
auto interpret = py::cpp_function( | |||
[](OpDef& self, py::object pyf, py::object pyc, | |||
const mgb::SmallVector<py::object>& inputs) { | |||
auto f = [pyf](OpDef& op, const mgb::SmallVector<py::object>& inputs) { | |||
return py::cast<mgb::SmallVector<py::object>>(pyf(op.shared_from_this(), inputs)); | |||
}; | |||
auto c = [pyc](const TensorPtr& tensor) { | |||
return pyc(tensor->dev_tensor()); | |||
}; | |||
return self.cast_final_safe<BackwardGraph>().graph().interpret<py::object>(f, c, inputs); | |||
}); | |||
mgb_assert(PyDict_SetItemString( | |||
py_type.tp_dict, "interpret", interpret.release().ptr()) >= 0); | |||
PyType_Modified(&py_type); | |||
m.add_object("BackwardGraph", reinterpret_cast<PyObject*>(&py_type)); | |||
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(BackwardGraph::typeinfo(), &py_type).second); | |||
} | |||
struct PyOpBase : PyOpDef { | |||
static PyTypeObject py_type; | |||
@@ -496,7 +460,6 @@ FOR_EACH_BIT_COMBINED_ENUM_PARAM(BIT_COMBINED_ENUM_CASTER_IMPL) | |||
void init_ops(py::module m) { | |||
_init_py_op_def(m); | |||
_init_py_backward_graph(m); | |||
_init_py_op_base(m); | |||
INIT_ALL_OP(m) | |||
@@ -156,9 +156,6 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje | |||
ctx.args = &tensors[0]; | |||
ctx.nargs = nargs; | |||
ctx.pytype = pytype; | |||
if (ctx.op->same_type<BackwardGraph>()) { | |||
ctx.backward = true; | |||
} | |||
if (py::isinstance<PySymbolVar>(py::handle(args[0]))){ | |||
SmallVector<cg::VarNode*> vinputs(nargs); | |||
@@ -248,31 +248,53 @@ apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) { | |||
return apply(ctx); | |||
} | |||
template <typename T> | |||
auto apply(std::shared_ptr<OpDef> op, T&& tensors) | |||
-> std::enable_if_t<std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>, | |||
apply_result_t> { | |||
inline auto apply(std::shared_ptr<OpDef> op, Tensor*const* args, size_t nargs) { | |||
ApplyContext ctx; | |||
ctx.op = std::move(op); | |||
ctx.nargs = tensors.size(); | |||
Tensor* args[ctx.nargs]; | |||
ctx.nargs = nargs; | |||
ctx.args = args; | |||
for (size_t i = 0; i < ctx.nargs; ++i) { | |||
args[i] = resolve_arrow(tensors[i]); | |||
for (size_t i = 0; i < nargs; ++i) { | |||
ctx.flags |= args[i]->m_flags; | |||
} | |||
return apply(ctx); | |||
} | |||
inline auto apply(std::shared_ptr<OpDef> op, Tensor*const* args, size_t nargs) { | |||
ApplyContext ctx; | |||
ctx.op = std::move(op); | |||
ctx.nargs = nargs; | |||
ctx.args = args; | |||
template <typename T> | |||
auto apply(std::shared_ptr<OpDef> op, T&& tensors) | |||
-> std::enable_if_t<std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>, | |||
apply_result_t> { | |||
size_t nargs = tensors.size(); | |||
Tensor* args[nargs]; | |||
for (size_t i = 0; i < nargs; ++i) { | |||
ctx.flags |= args[i]->m_flags; | |||
args[i] = resolve_arrow(tensors[i]); | |||
} | |||
return apply(ctx); | |||
return apply(op, args, nargs); | |||
} | |||
inline auto apply(Subgraph graph, Tensor*const* args, size_t nargs) { | |||
SmallVector<std::shared_ptr<Tensor>> inputs; | |||
for (size_t i = 0; i < nargs; ++i) { | |||
inputs.push_back(args[i]->shared_from_this()); | |||
} | |||
auto apply_functor = [](std::shared_ptr<OpDef> op, SmallVector<std::shared_ptr<Tensor>> inputs) { | |||
return apply(op, inputs); | |||
}; | |||
auto const_functor = [](imperative::TensorPtr value) { | |||
return std::make_shared<Tensor>(interpreter_for_py->put(value->dev_tensor())); | |||
}; | |||
return graph.apply(inputs, apply_functor, const_functor); | |||
} | |||
template <typename T> | |||
auto apply(Subgraph graph, T&& tensors) | |||
-> std::enable_if_t<std::is_same_v<decltype(tensors[0]), Tensor*>, | |||
apply_result_t> { | |||
size_t nargs = tensors.size(); | |||
Tensor* args[nargs]; | |||
for (size_t i = 0; i < nargs; ++i) { | |||
args[i] = resolve_arrow(tensors[i]); | |||
} | |||
return apply(graph, args, nargs); | |||
} | |||
void init_tensor(pybind11::module); | |||
@@ -22,7 +22,7 @@ apply_result_t apply_trace(ApplyContext& ctx) { | |||
apply_result_t outputs; | |||
if (ctx.backward) { | |||
// call megbrain_graph.py apply(BackwardGraph, *args) | |||
// reach here when compiled=True | |||
auto args = py::tuple(ctx.nargs + 1); | |||
args[0] = py::cast(ctx.op); | |||
for (size_t i = 0; i < ctx.nargs; i++) { | |||
@@ -18,24 +18,22 @@ using namespace imperative; | |||
OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphResult& src) | |||
: input_has_grad(src.input_has_grad) { | |||
if (!src.backward->same_type<BackwardGraph>()) { | |||
if (src.backward.exprs.size() <= 1) { | |||
// backward graph only contains a single op | |||
backward = src.backward; | |||
save_for_backward = src.save_for_backward; | |||
return; | |||
} | |||
save_for_backward.resize(src.save_for_backward.size(), false); | |||
precomp.reset(new BackwardGraph); | |||
backward.reset(new BackwardGraph); | |||
auto&& graph = src.backward->cast_final_safe<BackwardGraph>().graph(); | |||
auto&& graph = src.backward; | |||
auto&& mask = src.save_for_backward; | |||
size_t input_size = src.input_has_grad.size(); | |||
size_t output_size = (mask.size() - input_size) / 2; | |||
mgb_assert(input_size + output_size * 2 == mask.size()); | |||
auto& fgraph = precomp->cast_final<BackwardGraph>().graph(); | |||
auto& bgraph = backward->cast_final<BackwardGraph>().graph(); | |||
auto& fgraph = precomp; | |||
auto& bgraph = backward; | |||
// optimization: move ops (e.g. GetVarShape) to forward to | |||
// reduce memory footprint | |||
@@ -113,6 +111,6 @@ OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphRe | |||
} | |||
if (!fgraph.outputs.size()) { | |||
precomp.reset(); | |||
precomp = {}; | |||
} | |||
} |
@@ -911,8 +911,7 @@ auto ChannelImpl::CommandBuffer::flush_pos_for(const Command& cmd) -> Handle { | |||
op_type == RemoteSend::typeinfo() || | |||
op_type == CollectiveComm::typeinfo() || | |||
op_type == opr::InputCallback::typeinfo() || | |||
op_type == opr::OutputCallback::typeinfo() || | |||
op_type == BackwardGraph::typeinfo()) { | |||
op_type == opr::OutputCallback::typeinfo()) { | |||
return m_commands.end(); | |||
} | |||
} else if constexpr (std::is_same_v<T, GetValue>) { | |||
@@ -10,6 +10,9 @@ | |||
*/ | |||
#include "megbrain/imperative/op_def.h" | |||
#include <sstream> | |||
#include "megbrain/imperative/ops/opr_attr.h" | |||
#include "./op_trait.h" | |||
@@ -117,6 +120,67 @@ const std::string OpDef::make_name() const { | |||
return m_scope + "." + trait()->make_name(*this); | |||
} | |||
std::string Subgraph::repr() const { | |||
std::ostringstream buf; | |||
buf << "("; | |||
for (size_t i = 0; i < inputs.size(); ++i) { | |||
if (i > 0) buf << ", "; | |||
buf << "%" << inputs[i]; | |||
} | |||
buf << ") => {\n"; | |||
auto fmt_const = [](size_t i, const TensorPtr& t) { | |||
if (t->shape().ndim == 1 && t->shape()[0] == 1) { | |||
auto&& v = t->get_value(); | |||
if (v.dtype() == dtype::Float32{}) { | |||
return std::to_string(*v.ptr<dt_float32>()); | |||
} else if (v.dtype() == dtype::Int32{}) { | |||
return std::to_string(*v.ptr<int32_t>()); | |||
} | |||
} | |||
return std::string("%c") + std::to_string(i); | |||
}; | |||
std::unordered_map<size_t, std::string> const_reps; | |||
for (auto&& [i, t] : constants) { | |||
const_reps.emplace(i, fmt_const(i, t)); | |||
} | |||
for (auto& [op, ins, outs] : exprs) { | |||
buf << " "; | |||
if (outs.size()) { | |||
for (size_t i = 0; i < outs.size(); ++i) { | |||
if (i > 0) buf << ", "; | |||
buf << "%" << outs[i]; | |||
} | |||
buf << " = "; | |||
} | |||
if (auto* p = op->try_cast_final<OprAttr>()) { | |||
buf << p->type; | |||
} else { | |||
buf << op->dyn_typeinfo()->name; | |||
} | |||
for (size_t i : ins) { | |||
buf << " "; | |||
auto&& it = const_reps.find(i); | |||
if (it != const_reps.end()) { | |||
buf << it->second; | |||
} else { | |||
buf << "%" << i; | |||
} | |||
} | |||
buf << "\n"; | |||
} | |||
buf << " "; | |||
if (outputs.size()) { | |||
for (size_t i = 0; i < outputs.size(); ++i) { | |||
if (i > 0) buf << ", "; | |||
buf << "%" << outputs[i]; | |||
} | |||
} else { | |||
buf << "()"; | |||
} | |||
buf << "\n}\n"; | |||
return buf.str(); | |||
} | |||
} // namespace imperative | |||
} // namespace mgb | |||
@@ -19,147 +19,6 @@ | |||
namespace mgb { | |||
namespace imperative { | |||
SmallVector<TensorPtr> | |||
BackwardGraph::InternalGraph::apply( | |||
const SmallVector<TensorPtr>& inputs) const { | |||
return interpret<TensorPtr>( | |||
&OpDef::apply_on_physical_tensor, | |||
[](const TensorPtr& x) {return x;}, | |||
inputs); | |||
} | |||
std::tuple<SmallVector<LogicalTensorDesc>, bool> BackwardGraph::InternalGraph::infer_attrs( | |||
const SmallVector<LogicalTensorDesc>& inputs) const { | |||
using TensorAttr = LogicalTensorDesc; | |||
ThinHashMap<size_t, TensorAttr> node2attr; | |||
auto&& input_nodes = this->inputs; | |||
auto&& output_nodes = this->outputs; | |||
mgb_assert(inputs.size() == input_nodes.size()); | |||
for (size_t i = 0; i < inputs.size(); ++ i) { | |||
node2attr[input_nodes[i]] = inputs[i]; | |||
} | |||
for (auto &&i : constants) { | |||
auto* value = i.second->try_get_value(); | |||
mgb_assert(value); | |||
node2attr[i.first] = TensorAttr{ | |||
i.second->layout(), i.second->comp_node(), | |||
value->proxy_to_default_cpu()}; | |||
} | |||
bool validated = true; | |||
for (size_t i = 0; i < exprs.size(); ++ i) { | |||
auto&& [expr_op, expr_inps, expr_oups] = exprs[i]; | |||
SmallVector<TensorAttr> expr_input_descs; | |||
for (auto &&inp : expr_inps) { | |||
expr_input_descs.push_back(node2attr.at(inp)); | |||
} | |||
auto [expr_output_descs, expr_validated] = OpDef::infer_output_attrs_fallible( | |||
*expr_op, expr_input_descs); | |||
validated = validated && expr_validated; | |||
mgb_assert(expr_output_descs.size() == expr_oups.size()); | |||
for (size_t i = 0; i < expr_output_descs.size(); ++ i) { | |||
node2attr[expr_oups[i]] = expr_output_descs[i]; | |||
} | |||
} | |||
SmallVector<TensorAttr> ret; | |||
for (auto &&i : output_nodes) { | |||
ret.push_back(node2attr.at(i)); | |||
} | |||
return {ret, validated}; | |||
} | |||
std::string BackwardGraph::InternalGraph::repr() { | |||
std::ostringstream buf; | |||
buf << "("; | |||
for (size_t i = 0; i < inputs.size(); ++i) { | |||
if (i > 0) buf << ", "; | |||
buf << "%" << inputs[i]; | |||
} | |||
buf << ") => {\n"; | |||
auto fmt_const = [](size_t i, TensorPtr& t) { | |||
if (t->shape().ndim == 1 && t->shape()[0] == 1) { | |||
auto&& v = t->get_value(); | |||
if (v.dtype() == dtype::Float32{}) { | |||
return std::to_string(*v.ptr<dt_float32>()); | |||
} else if (v.dtype() == dtype::Int32{}) { | |||
return std::to_string(*v.ptr<int32_t>()); | |||
} | |||
} | |||
return std::string("%c") + std::to_string(i); | |||
}; | |||
std::unordered_map<size_t, std::string> const_reps; | |||
for (auto&& [i, t] : constants) { | |||
const_reps.emplace(i, fmt_const(i, t)); | |||
} | |||
for (auto& [op, ins, outs] : exprs) { | |||
buf << " "; | |||
if (outs.size()) { | |||
for (size_t i = 0; i < outs.size(); ++i) { | |||
if (i > 0) buf << ", "; | |||
buf << "%" << outs[i]; | |||
} | |||
buf << " = "; | |||
} | |||
if (auto* p = op->try_cast_final<OprAttr>()) { | |||
buf << p->type; | |||
} else { | |||
buf << op->dyn_typeinfo()->name; | |||
} | |||
for (size_t i : ins) { | |||
buf << " "; | |||
auto&& it = const_reps.find(i); | |||
if (it != const_reps.end()) { | |||
buf << it->second; | |||
} else { | |||
buf << "%" << i; | |||
} | |||
} | |||
buf << "\n"; | |||
} | |||
buf << " "; | |||
if (outputs.size()) { | |||
for (size_t i = 0; i < outputs.size(); ++i) { | |||
if (i > 0) buf << ", "; | |||
buf << "%" << outputs[i]; | |||
} | |||
} else { | |||
buf << "()"; | |||
} | |||
buf << "\n}\n"; | |||
return buf.str(); | |||
} | |||
MGB_DYN_TYPE_OBJ_FINAL_IMPL(BackwardGraph); | |||
namespace { | |||
SmallVector<TensorPtr> backward_impl( | |||
const OpDef& backward_graph, | |||
const SmallVector<TensorPtr>& tensors) { | |||
return backward_graph.cast_final_safe<BackwardGraph>() | |||
.graph().apply(tensors); | |||
} | |||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_tensor_attrs( | |||
const OpDef& backward_graph, | |||
const SmallVector<LogicalTensorDesc> inputs) { | |||
return backward_graph.cast_final_safe<BackwardGraph>() | |||
.graph().infer_attrs(inputs); | |||
} | |||
std::vector<std::pair<const char*, std::string>> props( | |||
const OpDef& backward_graph) { | |||
return {}; | |||
} | |||
OP_TRAIT_REG(BackwardGraph, BackwardGraph) | |||
.apply_on_physical_tensor(backward_impl) | |||
.infer_output_attrs_fallible(infer_tensor_attrs) | |||
.props(props) | |||
.fallback(); | |||
} // anonymous namespace | |||
} // namespace imperative | |||
} // namespace mgb | |||
@@ -669,8 +669,7 @@ ProxyGraph::make_backward_graph( | |||
auto* gfunc = cg::lookup_grad_func(fwd->dyn_typeinfo()); | |||
BackwardGraphResult result; | |||
auto&& backward = BackwardGraph::make(); | |||
auto&& igraph = backward->cast_final_safe<BackwardGraph>().graph(); | |||
auto&& igraph = result.backward; | |||
size_t nr_backward_graph_inputs = 0; | |||
auto gen_expr = [this, &var2idx, &igraph, &push, &fwd, | |||
@@ -682,7 +681,7 @@ ProxyGraph::make_backward_graph( | |||
++ nr_backward_graph_inputs; | |||
push(op->output(0)); | |||
} else { | |||
std::vector<size_t> inputs, outputs; | |||
SmallVector<size_t> inputs, outputs; | |||
for (auto &&i : op->input()) { | |||
if (i->owner_opr() == fwd) { | |||
if (var2idx.find(i) == var2idx.end()) { | |||
@@ -695,7 +694,7 @@ ProxyGraph::make_backward_graph( | |||
for (auto &&i : op->usable_output()) { | |||
outputs.push_back(push(i)); | |||
} | |||
igraph.exprs.emplace_back(OpDef::make_from_op_node(op), inputs, outputs); | |||
igraph.exprs.push_back({OpDef::make_from_op_node(op), inputs, outputs}); | |||
} | |||
}; | |||
@@ -770,36 +769,6 @@ ProxyGraph::make_backward_graph( | |||
write_inputs(outputs); | |||
write_inputs(output_grads); | |||
mgb_assert(igraph.inputs.size() == nr_backward_graph_inputs); | |||
auto treat_as_single = [](auto&& igraph) { | |||
if (igraph.exprs.size() != 1) | |||
return false; | |||
auto&& expr = igraph.exprs[0]; | |||
auto&& expr_inputs = std::get<1>(expr); | |||
if (expr_inputs.size() != igraph.inputs.size()) { | |||
return false; | |||
} | |||
for (size_t i = 0; i < expr_inputs.size(); ++ i) { | |||
if (igraph.inputs[i] != expr_inputs[i]) { | |||
return false; | |||
} | |||
} | |||
auto&& expr_outputs = std::get<2>(expr); | |||
if (expr_outputs.size() != igraph.outputs.size()) { | |||
return false; | |||
} | |||
for (size_t i = 0; i < expr_outputs.size(); ++ i) { | |||
if (igraph.outputs[i] != expr_outputs[i]) { | |||
return false; | |||
} | |||
} | |||
return true; | |||
}; | |||
if (treat_as_single(igraph)) { | |||
result.backward = std::get<0>(igraph.exprs[0]); | |||
} else { | |||
result.backward = backward; | |||
} | |||
return result; | |||
} | |||
@@ -65,7 +65,7 @@ private: | |||
class InputPlaceholder; | |||
struct ProxyGraphInst; | |||
struct GradGraph; | |||
struct CurOprGuard; | |||
class CurOprGuard; | |||
void reset(); | |||
@@ -15,7 +15,7 @@ namespace mgb::imperative::proxy_graph { | |||
// e.g. friend class mgb::imperative::proxy_graph::ProxyGraph | |||
struct ProxyGraph { | |||
struct InputPlaceholder; | |||
struct MiniGraph; | |||
class MiniGraph; | |||
}; | |||
} // namespace mgb::imperative::proxy_graph |
@@ -75,30 +75,7 @@ apply_on_physical_tensor(const OpDef& def, | |||
auto output_descs = infer_output_attrs(def, inputs); | |||
SmallVector<TensorPtr> outputs(output_descs.size(), {}); | |||
for (size_t i = 0; i < outputs.size(); i++) { | |||
auto& output = outputs[i]; | |||
auto& output_desc = output_descs[i]; | |||
if (def.same_type<Elemwise>()) { | |||
for (size_t j = 0; j < inputs.size(); j++) { | |||
// TODO: reindex inputs to support inplace exprs like 'y = x op x'. | |||
auto& input = inputs[j]; | |||
// Because we pass inputs by value, if input and input->blob() are all unique, | |||
// their ownerships are on the stack, thus we can reuse them safely. | |||
// @see: interpreter::intl::ChannelImpl::process_one_task | |||
if (input.unique() && input->blob().unique() && input->blob()->storage().unique() && | |||
input->layout().dtype == output_desc.layout.dtype && | |||
input->layout().eq_layout(output_desc.layout) && | |||
input->comp_node() == output_desc.comp_node) { | |||
static std::atomic_llong inplace_count = 0; | |||
mgb_log_debug("do inplace for elemwise, layout: %s, count: %lld", | |||
output_desc.layout.to_string().c_str(), ++inplace_count); | |||
output = Tensor::make(input->blob(), input->layout(), input->offset()); | |||
break; | |||
} | |||
} | |||
} | |||
if (!output) { | |||
output = Tensor::make(output_desc.layout, output_desc.comp_node); | |||
} | |||
outputs[i] = Tensor::make(output_descs[i].layout, output_descs[i].comp_node); | |||
} | |||
exec(def, inputs, outputs); | |||
auto async_error = ProxyGraph::get_async_error(); | |||
@@ -14,10 +14,10 @@ | |||
namespace mgb::imperative { | |||
struct OptimizedBackwardGraphResult { | |||
std::shared_ptr<OpDef> precomp; | |||
std::shared_ptr<OpDef> backward; | |||
std::vector<bool> save_for_backward; | |||
std::vector<bool> input_has_grad; | |||
Subgraph precomp; | |||
Subgraph backward; | |||
SmallVector<bool> save_for_backward; | |||
SmallVector<bool> input_has_grad; | |||
OptimizedBackwardGraphResult(const BackwardGraphResult& bgraph); | |||
}; | |||
@@ -26,10 +26,60 @@ enum DispatchMode { | |||
KERNEL = 1 | |||
}; | |||
using SharedOp = std::shared_ptr<OpDef>; | |||
template <typename T> | |||
struct Expr { | |||
std::shared_ptr<OpDef> op; | |||
SmallVector<T> inputs; | |||
SmallVector<T> outputs; | |||
}; | |||
struct Subgraph { | |||
SmallVector<size_t> inputs; | |||
SmallVector<std::pair<size_t, TensorPtr>> constants; | |||
SmallVector<size_t> outputs; | |||
SmallVector<Expr<size_t>> exprs; | |||
template <typename T, typename F, typename C> | |||
SmallVector<T> apply(SmallVector<T> input_vars, F&& f, C&& c) const { | |||
std::unordered_map<size_t, T> idx2var; | |||
mgb_assert(inputs.size() == input_vars.size(), "input size mismatch"); | |||
for (size_t i = 0; i < inputs.size(); ++i) { | |||
idx2var[inputs[i]] = input_vars[i]; | |||
} | |||
for (auto&& [idx, val]: constants) { | |||
idx2var[idx] = c(val); | |||
} | |||
for (auto& expr: exprs) { | |||
SmallVector<T> expr_inputs; | |||
for (auto idx: expr.inputs) { | |||
expr_inputs.push_back(idx2var[idx]); | |||
} | |||
SmallVector<T> expr_outputs = f(expr.op, std::move(expr_inputs)); | |||
mgb_assert(expr_outputs.size() == expr.outputs.size(), "output size mismatch"); | |||
for (size_t i = 0; i < expr_outputs.size(); ++i) { | |||
idx2var[expr.outputs[i]] = expr_outputs[i]; | |||
} | |||
} | |||
SmallVector<T> output_vars; | |||
for (auto idx: outputs) { | |||
output_vars.push_back(idx2var[idx]); | |||
} | |||
return output_vars; | |||
} | |||
bool empty() const { | |||
return outputs.size() == 0; | |||
} | |||
std::string repr() const; | |||
}; | |||
struct BackwardGraphResult { | |||
std::shared_ptr<OpDef> backward; | |||
std::vector<bool> save_for_backward; | |||
std::vector<bool> input_has_grad; | |||
Subgraph backward; | |||
SmallVector<bool> save_for_backward; | |||
SmallVector<bool> input_has_grad; | |||
}; | |||
class OpDef : public Hashable, | |||
@@ -15,92 +15,6 @@ | |||
namespace mgb { | |||
namespace imperative { | |||
// a special OpDef used for taking gradient on physical tensor | |||
struct BackwardGraph final : public OpDefImplBase<BackwardGraph> { | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
public: | |||
struct InternalGraph { | |||
// op, inputs, outputs | |||
using Expr = std::tuple<std::shared_ptr<OpDef>, | |||
std::vector<size_t>, std::vector<size_t>>; | |||
std::vector<Expr> exprs; | |||
// index array of input nodes | |||
std::vector<size_t> inputs; | |||
// index array of output nodes | |||
std::vector<size_t> outputs; | |||
// pair of (node index, correspending constant) | |||
std::vector<std::pair<size_t, TensorPtr>> constants; | |||
SmallVector<TensorPtr> | |||
apply(const SmallVector<TensorPtr>& inputs) const; | |||
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_attrs( | |||
const SmallVector<LogicalTensorDesc>& inputs) const; | |||
template <typename T, typename F, typename C> | |||
SmallVector<T> interpret(F&& f, C&& c, const SmallVector<T>& inputs) const { | |||
ThinHashMap<size_t, T> node2tensor; | |||
auto&& input_nodes = this->inputs; | |||
mgb_assert(inputs.size() == input_nodes.size()); | |||
for (size_t i = 0; i < inputs.size(); ++ i) { | |||
node2tensor[input_nodes[i]] = inputs[i]; | |||
} | |||
for (auto &&i : constants) { | |||
node2tensor[i.first] = c(i.second); | |||
} | |||
for (size_t i = 0; i < exprs.size(); ++ i) { | |||
auto&& expr = exprs[i]; | |||
SmallVector<T> inputs; | |||
for (auto &&in : std::get<1>(expr)) { | |||
inputs.push_back(node2tensor.at(in)); | |||
} | |||
auto&& outputs = f(*std::get<0>(expr), std::move(inputs)); | |||
auto&& output_nodes = std::get<2>(expr); | |||
mgb_assert(outputs.size() == output_nodes.size()); | |||
for (size_t i = 0; i < outputs.size(); ++ i) { | |||
node2tensor[output_nodes[i]] = std::move(outputs[i]); | |||
} | |||
} | |||
SmallVector<T> ret; | |||
for (auto &&i : outputs) { | |||
ret.push_back(node2tensor.at(i)); | |||
} | |||
return ret; | |||
} | |||
std::string repr(); | |||
}; | |||
const InternalGraph& graph() const { | |||
return m_graph; | |||
} | |||
InternalGraph& graph() { | |||
return m_graph; | |||
} | |||
bool is_same_st(const Hashable& rhs) const override { | |||
if (!rhs.same_type<BackwardGraph>()) { | |||
return false; | |||
} | |||
auto& other = rhs.cast_final_safe<BackwardGraph>(); | |||
if (this == &other) { | |||
return true; | |||
} | |||
// FIXME | |||
return false; | |||
} | |||
std::string repr() {return m_graph.repr();} | |||
private: | |||
InternalGraph m_graph; | |||
}; | |||
} // namespace imperative | |||
} // namespace mgb | |||
@@ -29,7 +29,7 @@ struct GenericPyOp final : OpDefImplBase<GenericPyOp> { | |||
} | |||
bool is_same_st(const Hashable& rhs) const override { | |||
return obj.equal(static_cast<const GenericPyOp&>(rhs).obj); | |||
return obj.equal(rhs.cast_final<GenericPyOp>().obj); | |||
} | |||
MGB_DYN_TYPE_OBJ_FINAL_DECL; | |||
@@ -75,6 +75,10 @@ T prepare_optimized_backward_inputs(const OptimizedBackwardGraphResult& bg, cons | |||
return ret; | |||
} | |||
SmallVector<TensorPtr> apply_shared_on_physical_tensor(std::shared_ptr<OpDef> def, SmallVector<TensorPtr> inputs) { | |||
return OpDef::apply_on_physical_tensor(*def, inputs); | |||
} | |||
TEST(TestImperative, BackwardGraphBasic) { | |||
HostTensorGenerator<> gen; | |||
SmallVector<HostTensorND> hvs; | |||
@@ -114,7 +118,11 @@ TEST(TestImperative, BackwardGraphBasic) { | |||
} | |||
} | |||
inputs.clear(); | |||
auto input_grads = OpDef::apply_on_physical_tensor(*(result.backward), backward_graph_inputs); | |||
auto input_grads = result.backward.apply( | |||
backward_graph_inputs, | |||
apply_shared_on_physical_tensor, | |||
[&](auto&& x){ return x; } | |||
); | |||
mgb_assert(input_grads.size() == input_has_grad.size()); | |||
for (size_t i = 0; i < input_has_grad.size(); ++ i) { | |||
mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i])); | |||
@@ -164,7 +172,11 @@ TEST(TestImperative, BackwardGraphIdentity) { | |||
} | |||
} | |||
inputs.clear(); | |||
auto input_grads = OpDef::apply_on_physical_tensor(*(result.backward), backward_graph_inputs); | |||
auto input_grads = result.backward.apply( | |||
backward_graph_inputs, | |||
apply_shared_on_physical_tensor, | |||
[&](auto&& x){ return x; } | |||
); | |||
mgb_assert(input_grads.size() == input_has_grad.size()); | |||
for (size_t i = 0; i < input_has_grad.size(); ++ i) { | |||
mgb_assert(input_has_grad[i] == static_cast<bool>(input_grads[i])); | |||
@@ -224,9 +236,17 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { | |||
auto c_tn = OpDef::apply_on_physical_tensor(*op, {a_tn, b_tn})[0]; | |||
auto backward_graph_inputs = prepare_backward_graph_inputs<SmallVector<TensorPtr>>(bg, {a_tn, b_tn}, {c_tn}, {dc_tn}); | |||
auto grads = expand_grads(bg, OpDef::apply_on_physical_tensor(*bg.backward, backward_graph_inputs)); | |||
auto grads = expand_grads(bg, bg.backward.apply( | |||
backward_graph_inputs, | |||
apply_shared_on_physical_tensor, | |||
[&](auto&& x){ return x; } | |||
)); | |||
auto precomp = OpDef::apply_on_physical_tensor(*obg.precomp, {a_tn, b_tn, c_tn}); | |||
auto precomp = obg.precomp.apply( | |||
SmallVector<TensorPtr>{a_tn, b_tn, c_tn}, | |||
apply_shared_on_physical_tensor, | |||
[&](auto&& x){ return x; } | |||
); | |||
ASSERT_EQ(precomp.size(), 2); | |||
ASSERT_EQ(precomp[0]->shape().ndim, 1); | |||
ASSERT_LE(precomp[0]->shape()[0], 2); | |||
@@ -234,7 +254,11 @@ TEST(TestImperative, OptimizedBackwardGraphBasic) { | |||
ASSERT_LE(precomp[1]->shape()[0], 2); | |||
auto backward_inputs = prepare_optimized_backward_inputs<SmallVector<TensorPtr>>(obg, precomp, {a_tn, b_tn}, {c_tn}, {dc_tn}); | |||
auto grads2 = expand_grads(obg, OpDef::apply_on_physical_tensor(*obg.backward, backward_inputs)); | |||
auto grads2 = expand_grads(obg, obg.backward.apply( | |||
backward_inputs, | |||
apply_shared_on_physical_tensor, | |||
[&](auto&& x){ return x; } | |||
)); | |||
ASSERT_EQ(grads2.size(), 2); | |||
MGB_ASSERT_TENSOR_EQ(grads[0]->get_value(), grads2[0]->get_value()); | |||