@@ -12,6 +12,7 @@
#include "./grad.h"
#include "megbrain/imperative/proxy_graph_detail.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/utils/mempool.h"
#include "range/v3/all.hpp"
@@ -21,6 +22,9 @@ namespace views = ranges::views;
namespace mgb::imperative::python {

using scoped_disable = ApplyContext::scoped_disable;
using Flags = Tensor::Flags;

namespace {

struct GradSlotWeakPtr {
@@ -78,6 +82,21 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph(
    return result;
}
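// Carries state shared by backward implementations during a backward pass;
// wrap_tensor() re-wraps a C++ Tensor into a Python TensorWrapper, using the
// caller-supplied subtype (pytype) when one was recorded.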
struct BackwardContext {
    PyTypeObject* pytype = nullptr;

    auto wrap_tensor(std::shared_ptr<Tensor> t) {
        if (pytype) {
            return TensorWrapper::make(pytype, std::move(t));
        }
        return TensorWrapper::make(std::move(t));
    }

    auto wrap_tensor(Tensor* t) {
        return wrap_tensor(t->shared_from_this());
    }
};

struct BackwardGraphWithClosure {
    std::shared_ptr<BackwardGraphResult> backward_graph;
    SmallVector<std::shared_ptr<Tensor>> closure;
@@ -119,7 +138,7 @@ struct BackwardGraphWithClosure {
    }

    template <typename T, typename R>
    void operator()(T&& grads, R&& receiver) {
    void operator()(BackwardContext&, T&& grads, R&& receiver) {
        Tensor* args[closure.size() + grads.size()];
        size_t nargs = 0;
        for (auto&& t : closure) {
@@ -143,7 +162,7 @@ struct BackwardGraphWithClosure {
        ApplyContext ctx;
        ctx.op = backward_graph->backward;
        ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0;
        ctx.flags = is_tracing ? Flags::TRACE : 0;
        ctx.nargs = nargs;
        ctx.args = args;
        for (size_t i = 0; i < nargs; ++i) {
@@ -174,6 +193,47 @@ struct BackwardGraphWithClosure {
    }
};
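// Backward implementation backed by a Python callable (the closure returned by
// a custom grad rule): wraps the output grads into Python tensors, calls the
// function, and hands the returned input grads to the receiver.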
struct PythonBackward {
    py::object pyfunc;
    size_t input_size;

    PythonBackward(py::object f, size_t nin)
            : pyfunc(f), input_size(nin) {}

    template <typename T, typename R>
    void operator()(BackwardContext& ctx, T&& grads, R&& receiver) {
        auto args = py::tuple(grads.size());
        for (size_t i = 0; i < grads.size(); ++i) {
            auto&& g = grads[i];
            args[i] = g ? ctx.wrap_tensor(g) : py::none();
        }
        auto input_grads = py::reinterpret_steal<py::object>(PyObject_Call(pyfunc.ptr(), args.ptr(), nullptr));
        if (input_grads.is_none()) return;
        if (auto* tw = TensorWrapper::try_cast(input_grads.ptr())) {
            if (input_size != 1) {
                throw py::value_error("custom grad rule returned wrong number of grads");
            }
            receiver(0, tw->m_tensor);
            return;
        }
        if (py::len(input_grads) != input_size) {
            throw py::value_error("custom grad rule returned wrong number of grads");
        }
        for (auto [i, g] : views::enumerate(input_grads)) {
            if (g.is_none()) continue;
            auto* tw = TensorWrapper::try_cast(g.ptr());
            if (!tw) {
                throw py::type_error("custom grad rule returned non-tensor");
            }
            receiver(i, tw->m_tensor);
        }
    }

    static constexpr bool input_has_grad(size_t) {return true;}
    static constexpr bool output_requires_grad(size_t) {return true;}
    static constexpr bool output_captured(size_t) {return true;}
};

} // namespace

struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> {
@@ -210,7 +270,7 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
    // same length as inputs (of forward op)
    SmallVector<GradSlotProducerPtr> dsts;
    // encapsulates actual function to compute gradient
    std::variant<std::monostate, BackwardGraphWithClosure> backward;
    std::variant<std::monostate, BackwardGraphWithClosure, PythonBackward> backward;
    // a flag used during backward
    bool in_ref_keeper = false;
@@ -268,6 +328,30 @@ apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_gra
    return outputs;
}
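// Grad rule for ops defined in Python (GenericPyOp): calls the op's Python-side
// _grad_rule with gradient recording disabled; the rule returns the forward
// outputs plus a backward callable, which is stored as a PythonBackward.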
apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
    auto* op = ctx.op->try_cast_final<GenericPyOp>();
    py::tuple pyin(ctx.nargs);
    for (size_t i = 0; i < ctx.nargs; ++i) {
        pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
    }
    auto grad_rule = py::getattr(op->obj, "_grad_rule");
    auto pyret = (scoped_disable(Flags::GRAD),
                  py::reinterpret_steal<py::object>(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr))); // comma expression
    auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret);
    ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs);
    if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) {
        return {tw->m_tensor};
    }
    apply_result_t ret;
    ret.reserve(py::len(outputs));
    for (auto&& i : outputs) {
        auto* tw = TensorWrapper::try_cast(i.ptr());
        mgb_assert(tw);
        ret.push_back(tw->m_tensor);
    }
    return ret;
}

} // namespace

apply_result_t apply_grad(ApplyContext& ctx) {
@@ -290,21 +374,23 @@ apply_result_t apply_grad(ApplyContext& ctx) {
                // cleanup stale grad info
                // under what condition?
                tensor->m_grad_info = {};
                tensor->m_flags &= ~Tensor::Flags::GRAD;
                tensor->m_flags &= ~Flags::GRAD;
            }
        } else {
            tensor->m_flags &= ~Tensor::Flags::GRAD;
            tensor->m_flags &= ~Flags::GRAD;
        }
    }

    ctx.flags &= ~Tensor::Flags::GRAD;
    ctx.flags &= ~Flags::GRAD;

    if (!grad_key) {
        return apply(ctx);
    }

    GradFnHelper grad_fn_holder;
    auto outputs = backward_graph_grad_rule(ctx, grad_fn_holder);
    auto outputs = ctx.op->same_type<GenericPyOp>() ?
            python_grad_rule(ctx, grad_fn_holder) :
            backward_graph_grad_rule(ctx, grad_fn_holder);

    auto& grad_fn = grad_fn_holder.grad_fn;
    if (!grad_fn) {
@@ -341,7 +427,7 @@ apply_result_t apply_grad(ApplyContext& ctx) {
                grad_info.grad_fn = grad_fn;
                grad_info.idx = i;
                grad_info.insert_after(grad_key->free_vars_head);
                outputs[i]->m_flags |= Tensor::Flags::GRAD;
                outputs[i]->m_flags |= Flags::GRAD;
            }
        }
    }
@@ -357,7 +443,7 @@ void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) {
    if (nargs != 2) {
        throw py::type_error("expect 2 arguments");
    }
    auto* tw = TensorWrapper::cast_safe(args[0]);
    auto* tw = TensorWrapper::try_cast(args[0]);
    if (!tw) {
        throw py::type_error("argument 1 must be Tensor");
    }
@@ -390,14 +476,15 @@ void GradKey::attach(Tensor* tensor, pybind11::object callback) {
        grad_fn->key = shared_from_this();
        grad_fn->slots.resize(1);
        tensor->m_grad_info.insert_after(free_vars_head);
        tensor->m_flags |= Tensor::Flags::GRAD;
        tensor->m_flags |= Flags::GRAD;
    }
    tensor->m_grad_info.grad_fn->slots[0].callback = std::move(callback);
}

void accum_grad(std::shared_ptr<Tensor>& grad, std::shared_ptr<Tensor>&& delta) {
template<typename T>
void accum_grad(std::shared_ptr<Tensor>& grad, T&& delta) {
    if (!grad) {
        grad = std::forward<decltype(delta)>(delta);
        grad = std::forward<T>(delta);
        return;
    }
    static ApplyContext ctx;
@@ -409,7 +496,7 @@ void accum_grad(std::shared_ptr<Tensor>& grad, std::shared_ptr<Tensor>&& delta)
    ctx.args = args;
    ctx.flags = grad->m_flags | delta->m_flags;
    if (is_tracing) {
        ctx.flags |= Tensor::Flags::TRACE;
        ctx.flags |= Flags::TRACE;
    }
    grad = apply(ctx)[0];
}
@@ -440,6 +527,7 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
        }
    }

    BackwardContext bctx{pytype};
    std::vector<std::shared_ptr<GradFn>> ref_keeper;
    ref_keeper.reserve(tape.size());

    // back-propagation in reverse order
@@ -456,7 +544,7 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
                mgb_assert(0);
            } else {
                auto&& grads = views::transform(grad_fn->slots, [](auto&& slot) {return slot.grad.get();});
                backward(std::forward<decltype(grads)>(grads), grad_receiver);
                backward(bctx, std::forward<decltype(grads)>(grads), grad_receiver);
            }
        }, grad_fn->backward);
@@ -14,6 +14,7 @@
#include "megbrain/imperative.h"
#include "megbrain/imperative/ops/backward_graph.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/autogen.h"

#include <Python.h>
@@ -245,6 +246,35 @@ void _init_py_backward_graph(py::module m) {
    mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(BackwardGraph::typeinfo(), &py_type).second);
}
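// Base Python type for op definitions written in Python; tp_new only allocates
// the underlying PyOp(OpDef) storage, and instances are later wrapped into
// GenericPyOp by the OpDef type caster below.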
struct PyOpBase : PyOpDef {
    static PyTypeObject py_type;

    static PyObject* tp_new(PyTypeObject* type, PyObject*, PyObject*) {
        auto* obj = type->tp_alloc(type, 0);
        if (obj) {
            auto* self = reinterpret_cast<PyOpBase*>(obj);
            new(&self->op) decltype(self->op);
        }
        return obj;
    }
};
PyTypeObject PyOpBase::py_type;

void _init_py_op_base(py::module m) {
    using py_op = PyOpBase;
    auto& py_type = PyOpBase::py_type;
    py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
    py_type.tp_name = "megengine.core._imperative_rt.ops.PyOpBase";
    py_type.tp_basicsize = sizeof(py_op);
    py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
    py_type.tp_doc = "PyOpBase";
    py_type.tp_base = &PyOpType(OpDef);
    py_type.tp_dealloc = py_dealloc_generic<py_op>;
    py_type.tp_new = py_op::tp_new;
    mgb_assert(PyType_Ready(&py_type) >= 0);
    m.add_object("PyOpBase", reinterpret_cast<PyObject*>(&py_type));
}

/*********** end of hand-write opdefs **************/

// auto generated opdefs
@@ -260,9 +290,16 @@ bool type_caster<OpDef>::load(handle src, bool convert) {
        return false;
    }
    value = reinterpret_cast<PyOp(OpDef)*>(obj)->op;
    if (!value) {
        // opdef only defined in Python
        value = std::make_shared<GenericPyOp>(reinterpret_borrow<object>(src));
    }
    return true;
}

handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) {
    if (auto* pyop = op.try_cast_final<GenericPyOp>()) {
        return object(pyop->obj).release();
    }
    PyTypeObject* pytype;
    auto& c2p = PyOp(OpDef)::ctype2pytype;
    auto&& iter = c2p.find(op.dyn_typeinfo());
@@ -283,5 +320,6 @@ handle type_caster<OpDef>::cast(const OpDef& op, return_value_policy, handle) {
void init_ops(py::module m) {
    _init_py_op_def(m);
    _init_py_backward_graph(m);
    _init_py_op_base(m);
    INIT_ALL_OP(m)
}
@@ -11,6 +11,7 @@
#include "megbrain/dtype.h"
#include "megbrain/common.h"
#include "megbrain/imperative/ops/utility.h"

#include "./tensor.h"
#include "./grad.h"
@@ -22,10 +23,12 @@
#include <pybind11/numpy.h>
#include <pybind11/operators.h>
#include <range/v3/all.hpp>

#include <unordered_map>

namespace py = pybind11;
namespace views = ranges::views;

namespace mgb::imperative::python {
@@ -69,21 +72,45 @@ SET_UNSET_PROP(compiled)

bool skip_tracing = false;

Tensor::flags_t ApplyContext::global_disable = 0;

apply_result_t apply(ApplyContext& ctx) {
    // emulating scalar should be put to specific op's apply, e.g.,
    // elementwise, reduce, typecvt. Currently it's still handled at python
    // side. It could be moved to C++ side if it has an impact on performance
    if (ctx.flags & Tensor::Flags::SCALAR) {
    auto flags = ctx.flags & ~ApplyContext::global_disable;
    if (flags & Tensor::Flags::SCALAR) {
        // TODO: emulate scalar
    }

    if (ctx.flags & Tensor::Flags::GRAD) {
    if (flags & Tensor::Flags::GRAD) {
        return apply_grad(ctx);
    }

    if (ctx.flags & Tensor::Flags::TRACE) {
    if (flags & Tensor::Flags::TRACE) {
        return apply_trace(ctx);
    } else {
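        // Ops defined in Python (GenericPyOp) have no C++ kernel here: forward
        // them to the op object's Python-side _default_rule and unwrap the
        // returned tensor(s).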
        if (auto* op = ctx.op->try_cast_final<GenericPyOp>()) {
            py::tuple pyin(ctx.nargs);
            for (size_t i = 0; i < ctx.nargs; ++i) {
                pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
            }
            auto f = py::getattr(op->obj, "_default_rule");
            auto pyout = py::reinterpret_steal<py::object>(PyObject_Call(f.ptr(), pyin.ptr(), nullptr));
            if (auto* tw = TensorWrapper::try_cast(pyout.ptr())) {
                return {tw->m_tensor};
            }
            apply_result_t ret;
            ret.reserve(py::len(pyout));
            for (auto&& i : pyout) {
                auto* tw = TensorWrapper::try_cast(i.ptr());
                mgb_assert(tw);
                ret.push_back(tw->m_tensor);
            }
            return ret;
        }

        SmallVector<interpreter::Interpreter::Handle> handles(ctx.nargs);
        for (size_t i = 0; i < ctx.nargs; ++i) {
            handles[i] = ctx.args[i]->m_handle.get();
@@ -125,12 +152,13 @@ PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObje
        SmallVector<Tensor*, 64> tensors(nargs);
        ctx.args = &tensors[0];
        ctx.nargs = nargs;
        ctx.pytype = pytype;
        if (strstr(op->ob_type->tp_name, "BackwardGraph")) {
            ctx.backward = true;
        }

        for (size_t i = 0; i < nargs; ++i) {
            if (TensorWrapper* tw = TensorWrapper::cast_safe(args[i])) {
            if (TensorWrapper* tw = TensorWrapper::try_cast(args[i])) {
                auto* t = tensors[i] = tw->m_tensor.get();
                ctx.flags |= t->m_flags;
            } else {
@@ -166,7 +194,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
    if (nargs == 0) {
        throw py::type_error("too few arguments");
    }
    if (auto* t = cast_safe(tup[0].ptr())) {
    if (auto* t = try_cast(tup[0].ptr())) {
        if (nargs > 1) {
            throw py::type_error("expect 1 argument");
        }
@@ -211,7 +239,7 @@ TensorWrapper::TensorWrapper(PyObject* args, PyObject* kwargs) {
            auto ret = pyf(*tup);
            auto py_ret = py::reinterpret_borrow<py::list>(ret);
            if (auto* t = cast_safe(py_ret[0].ptr())) {
            if (auto* t = try_cast(py_ret[0].ptr())) {
                m_tensor = t->m_tensor;
            }
            return;
@@ -349,7 +377,7 @@ PyObject* TensorWrapper::varnode() {
}

void TensorWrapper::reset(PyObject* tensor) {
    TensorWrapper* t = TensorWrapper::cast_safe(tensor);
    TensorWrapper* t = TensorWrapper::try_cast(tensor);
    if (!t) {
        throw py::type_error("expect Tensor");
    }
@@ -446,7 +474,7 @@ uint8_t max_priority(SmallVector<PyArray_Descr*> types) {
    }
}

// Returns the data type with sufficient size to hold all types of
// category `cat` in the list `types`.
PyArray_Descr* promote_types(SmallVector<PyArray_Descr*> types, uint8_t cat) {
    // Return value: New reference
@@ -507,7 +535,7 @@ PyArray_Descr* _dtype_promotion(PyObject*const* args, size_t nargs) {
    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i];
        if (handle == Py_None) continue;
        TensorWrapper* tw = TensorWrapper::cast_safe(handle);
        TensorWrapper* tw = TensorWrapper::try_cast(handle);
        if (tw) {
            mgb::DType type = tw->m_tensor->dtype();
            auto&& descr = npy::dtype_mgb2np_descr(type);
@@ -562,7 +590,7 @@ CompNode _get_device(PyObject*const* args, size_t nargs) {
    CompNode cn;
    for (size_t i = 0; i < nargs; ++i) {
        PyObject* handle = is_tuple ? PyTuple_GetItem(tuple, i): args[i];
        TensorWrapper* tw = TensorWrapper::cast_safe(handle);
        TensorWrapper* tw = TensorWrapper::try_cast(handle);
        if (tw) {
            if (!valid) {
                cn = tw->m_tensor->comp_node();
@@ -124,7 +124,7 @@ struct TensorWrapper {
    friend wrap_t;

    inline static TensorWrapper* cast(PyObject* op) {return reinterpret_cast<wrap_t*>(op)->inst();}
    inline static TensorWrapper* cast_safe(PyObject* op) {
    inline static TensorWrapper* try_cast(PyObject* op) {
        if (!wrap_t::type().isinstance(op)) return nullptr;
        return cast(op);
    }
@@ -173,11 +173,26 @@ struct TensorWrapper {
PyObject* py_apply(PyObject* self, PyObject*const* args, size_t nargs/* , PyObject* kwnames */);

struct ApplyContext {
    static Tensor::flags_t global_disable;

    Tensor::flags_t flags;
    std::shared_ptr<OpDef> op;
    Tensor*const* args;
    size_t nargs;
    PyTypeObject* pytype = nullptr;
    bool backward = false;
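    // RAII guard: ORs the given flags into ApplyContext::global_disable for
    // the lifetime of the guard and restores the previous mask on destruction,
    // e.g. to suppress GRAD tracking while a Python grad rule runs.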
    class scoped_disable : NonCopyableObj {
        Tensor::flags_t saved_flags;

    public:
        scoped_disable(Tensor::flags_t flags) : saved_flags(ApplyContext::global_disable) {
            ApplyContext::global_disable |= flags;
        }
        ~scoped_disable() {
            ApplyContext::global_disable = saved_flags;
        }
    };
};

using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>;
@@ -85,7 +85,7 @@ apply_result_t apply_trace(ApplyContext& ctx) {
        // assumption: python function always returns PyList
        auto tup = py::reinterpret_borrow<py::list>(ret);
        for (auto i = 0; i < tup.size(); i++) {
            auto tw = TensorWrapper::cast_safe(tup[i].ptr());
            auto tw = TensorWrapper::try_cast(tup[i].ptr());
            outputs.emplace_back(tw->m_tensor);
        }
        return outputs;
@@ -0,0 +1,21 @@
/**
 * \file imperative/src/impl/ops/utility.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/ops/utility.h"
#include "megbrain/imperative/ops/opr_attr.h"
#include "megbrain/opr/utility.h"
#include "../op_trait.h"

namespace mgb::imperative {

MGB_DYN_TYPE_OBJ_FINAL_IMPL(GenericPyOp);

} // namespace mgb::imperative
@@ -0,0 +1,38 @@
/**
 * \file imperative/src/include/megbrain/imperative/ops/utility.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megbrain/imperative/op_def.h"
#include "megbrain/utils/hash.h"

#include <pybind11/pybind11.h>

namespace mgb::imperative {
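// OpDef wrapper around an op implemented as a Python object; hashing and
// equality delegate to the wrapped Python object, so identical Python ops
// compare equal on the C++ side.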
struct GenericPyOp final : OpDefImplBase<GenericPyOp> {
    pybind11::object obj;

    GenericPyOp(pybind11::object obj_) : obj(std::move(obj_)) {};

    size_t hash() const override {
        return pybind11::hash(obj);
    }

    bool is_same_st(const Hashable& rhs) const override {
        return obj.equal(static_cast<const GenericPyOp&>(rhs).obj);
    }

    MGB_DYN_TYPE_OBJ_FINAL_DECL;
};

} // namespace mgb::imperative