GitOrigin-RevId: 141ff0a24f
@@ -70,7 +70,7 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph(
     for (size_t i = 0; i < ctx.nargs; ++i) {
         inputs[i].comp_node = ctx.args[i]->comp_node();
         inputs[i].layout.dtype = ctx.args[i]->dtype();
-        input_requires_grad[i] = bool(ctx.args[i]->m_grad_info.grad_fn);
+        input_requires_grad[i] = python::input_requires_grad(ctx, i);
     }
     auto result = std::make_shared<BackwardGraphResult>(
             proxy_graph_detail::make_backward_graph(
@@ -82,21 +82,6 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph(
     return result;
 }
 
-struct BackwardContext {
-    PyTypeObject* pytype = nullptr;
-
-    auto wrap_tensor(std::shared_ptr<Tensor> t) {
-        if (pytype) {
-            return TensorWrapper::make(pytype, std::move(t));
-        }
-        return TensorWrapper::make(std::move(t));
-    }
-
-    auto wrap_tensor(Tensor* t) {
-        return wrap_tensor(t->shared_from_this());
-    }
-};
-
 struct BackwardGraphWithClosure {
     std::shared_ptr<BackwardGraphResult> backward_graph;
     SmallVector<std::shared_ptr<Tensor>> closure;
@@ -270,7 +255,7 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
     // same length as inputs (of forward op)
     SmallVector<GradSlotProducerPtr> dsts;
     // encapsulates actual function to compute gradient
-    std::variant<std::monostate, BackwardGraphWithClosure, PythonBackward> backward;
+    std::variant<std::monostate, BackwardGraphWithClosure, PythonBackward, CustomBackward> backward;
     // a flag used during backward
     bool in_ref_keeper = false;
@@ -335,8 +320,7 @@ apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
         pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
     }
     auto grad_rule = py::getattr(op->obj, "_grad_rule");
-    auto pyret = (scoped_disable(Flags::GRAD),
-        py::reinterpret_steal<py::object>(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr))); // comma expression
+    auto pyret = py::reinterpret_steal<py::object>(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr));
     auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret);
     ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs);
     if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) {
@@ -388,9 +372,25 @@ apply_result_t apply_grad(ApplyContext& ctx) {
     }
 
     GradFnHelper grad_fn_holder;
-    auto outputs = ctx.op->same_type<GenericPyOp>() ?
-            python_grad_rule(ctx, grad_fn_holder) :
-            backward_graph_grad_rule(ctx, grad_fn_holder);
+    auto outputs = [&]() {
+        auto _ = scoped_disable(Flags::GRAD);
+        if (ctx.op->same_type<GenericPyOp>()) {
+            return python_grad_rule(ctx, grad_fn_holder);
+        }
+        auto&& registry = grad_rule_registry();
+        auto&& it = registry.find(ctx.op->dyn_typeinfo());
+        if (it != registry.end()) {
+            auto&& maker = grad_fn_holder.emplace<CustomBackward>().maker(ctx);
+            try {
+                auto ret = it->second(ctx, maker);
+                maker.finalize();
+                return ret;
+            } catch (GradRuleFallback&) {
+                grad_fn_holder.emplace<std::monostate>();
+            }
+        }
+        return backward_graph_grad_rule(ctx, grad_fn_holder);
+    }();
 
     auto& grad_fn = grad_fn_holder.grad_fn;
     if (!grad_fn) {
@@ -407,7 +407,7 @@ apply_result_t apply_grad(ApplyContext& ctx) {
         mgb_assert(0);
     } else {
         for (size_t i = 0; i < ctx.nargs; ++i) {
-            if (backward.input_has_grad(i)) {
+            if (backward.input_has_grad(i) && input_requires_grad(ctx, i)) {
                 auto& input_grad_info = ctx.args[i]->m_grad_info;
                 grad_fn->dsts.emplace_back(input_grad_info);
                 // register as grad producer
@@ -487,18 +487,8 @@ void accum_grad(std::shared_ptr<Tensor>& grad, T&& delta) {
         grad = std::forward<T>(delta);
         return;
     }
-    static ApplyContext ctx;
-    if (!ctx.op) {
-        ctx.op = std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD));
-        ctx.nargs = 2;
-    }
-    Tensor* args[2] = {grad.get(), delta.get()};
-    ctx.args = args;
-    ctx.flags = grad->m_flags | delta->m_flags;
-    if (is_tracing) {
-        ctx.flags |= Flags::TRACE;
-    }
-    grad = apply(ctx)[0];
+    static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD));
+    grad = apply(op, grad, std::forward<T>(delta))[0];
 }
 
 void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
@@ -582,4 +572,9 @@ GradKey::~GradKey() {
     cleanup();
 }
 
+std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry() {
+    static std::unordered_map<Typeinfo*, GradRuleFn> registry;
+    return registry;
+}
+
 } // namespace mgb::imperative::python
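The lookup-then-fallback flow added to `apply_grad` is the core of this change: a rule found in `grad_rule_registry()` gets the first shot, and if it throws `GradRuleFallback` the partially built `CustomBackward` is discarded via `grad_fn_holder.emplace<std::monostate>()` before the generic backward graph takes over. A minimal self-contained sketch of the same control flow, using stand-in names (`Rule`, `RuleFallback`, `dispatch`) rather than the real MegEngine types:

```cpp
#include <exception>
#include <functional>
#include <iostream>
#include <typeindex>
#include <unordered_map>

struct RuleFallback : std::exception {};   // plays the role of GradRuleFallback
using Rule = std::function<int(int)>;      // plays the role of GradRuleFn

std::unordered_map<std::type_index, Rule>& registry() {
    // Function-local static, like grad_rule_registry(): registrations from
    // other translation units' static constructors see a ready-made map,
    // avoiding static-initialization-order problems.
    static std::unordered_map<std::type_index, Rule> r;
    return r;
}

struct AddOp {};

int dispatch(std::type_index key, int x) {
    auto it = registry().find(key);
    if (it != registry().end()) {
        try {
            return it->second(x);          // fast path: specialized rule
        } catch (RuleFallback&) {
            // rule declined this case; fall through to the generic path
        }
    }
    return x * 2;                          // stands in for backward_graph_grad_rule
}

int main() {
    registry().emplace(std::type_index(typeid(AddOp)), [](int x) -> int {
        if (x < 0) throw RuleFallback();   // decline inputs it cannot handle
        return x + 1;
    });
    std::cout << dispatch(typeid(AddOp), 3) << ' '     // 4: custom rule ran
              << dispatch(typeid(AddOp), -3) << '\n';  // -6: fallback ran
}
```

Note also that the `scoped_disable(Flags::GRAD)` guard now sits at the top of the lambda, so it covers the custom-rule and backward-graph paths as well as `python_grad_rule`, which previously carried its own copy in a comma expression.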
@@ -45,6 +45,117 @@ struct GradKeyWrapper {
     void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
 };
 
+struct BackwardContext {
+    PyTypeObject* pytype = nullptr;
+
+    auto wrap_tensor(std::shared_ptr<Tensor> t) {
+        if (pytype) {
+            return TensorWrapper::make(pytype, std::move(t));
+        }
+        return TensorWrapper::make(std::move(t));
+    }
+
+    auto wrap_tensor(Tensor* t) {
+        return wrap_tensor(t->shared_from_this());
+    }
+};
+
+struct CustomBackward {
+    using BackwardFn = std::function<apply_result_t(BackwardContext&, Tensor*const*, size_t)>;
+    BackwardFn m_backward;
+    SmallVector<bool, 8> m_input_has_grad;
+    struct OutputAttr {bool requires_grad = true, captured = true;};
+    SmallVector<OutputAttr> m_output_attrs;
+public:
+    template<typename T, typename R>
+    void operator()(BackwardContext& ctx, T&& grads, R&& receiver) {
+        size_t nargs = grads.size();
+        Tensor* args[nargs];
+        for (size_t i = 0; i < nargs; ++i) {
+            args[i] = grads[i];
+        }
+        auto ret = m_backward(ctx, args, nargs);
+        for (size_t i = 0; i < ret.size(); ++i) {
+            if (auto&& t = ret[i]) {
+                receiver(i, std::move(t));
+            }
+        }
+    }
+
+    bool input_has_grad(size_t i) {return m_input_has_grad[i];}
+    bool output_requires_grad(size_t i) {return m_output_attrs[i].requires_grad;}
+    bool output_captured(size_t i) {return m_output_attrs[i].captured;}
+
+    class Maker {
+        bool output_size_set = false, input_has_grad_initialized = false;
+        CustomBackward& target;
+        ApplyContext& ctx;
+
+        void init_input_has_grad() {
+            if (!input_has_grad_initialized) {
+                input_has_grad_initialized = true;
+                target.m_input_has_grad.resize(ctx.nargs, true);
+            }
+        }
+    public:
+        Maker(CustomBackward& target_, ApplyContext& ctx_) : target(target_), ctx(ctx_) {}
+
+        template<typename F>
+        Maker& backward(F&& f) {
+            mgb_assert(!target.m_backward);
+            target.m_backward = std::forward<F>(f);
+            return *this;
+        }
+        // mandatory
+        Maker& output_size(size_t sz) {
+            mgb_assert(!output_size_set);
+            output_size_set = true;
+            target.m_output_attrs.resize(sz);
+            return *this;
+        }
+        // optional, defaults to all true
+        Maker& input_has_grad(size_t i, bool v) {
+            init_input_has_grad();
+            target.m_input_has_grad.at(i) = v;
+            return *this;
+        }
+        // optional, defaults to all true
+        Maker& output_requires_grad(size_t i, bool v) {
+            target.m_output_attrs.at(i).requires_grad = v;
+            return *this;
+        }
+        // optional, defaults to all true
+        Maker& output_captured(size_t i, bool v) {
+            target.m_output_attrs.at(i).captured = v;
+            return *this;
+        }
+
+        void finalize() {
+            mgb_assert(output_size_set);
+            init_input_has_grad();
+        }
+    };
+
+    Maker maker(ApplyContext& ctx) {return {*this, ctx};}
+};
+
+using GradRuleFn = std::function<apply_result_t(ApplyContext&, CustomBackward::Maker&)>;
+
+std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry();
+
+inline bool input_requires_grad(const ApplyContext& ctx, size_t i) {
+    return bool(ctx.args[i]->m_grad_info.grad_fn);
+}
+
+struct GradRuleFallback : std::exception {};
+
+template<typename T>
+bool register_grad_rule(Typeinfo* typeinfo, T&& rule) {
+    return grad_rule_registry().emplace(typeinfo, std::forward<T>(rule)).second;
+}
+
 } // namespace mgb::imperative::python
 
 namespace pybind11::detail {
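`Maker` is a fluent builder: a grad rule must call `output_size()`, may adjust the per-input and per-output flags, installs the backward closure via `backward()`, and `apply_grad` then calls `finalize()` to validate the result. As an illustration only (not part of this patch), a rule for a hypothetical single-input `Negate` op could look like this; `negate` is an assumed helper that applies the op to a tensor:

```cpp
apply_result_t negate_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    mgb_assert(ctx.nargs == 1);
    maker.output_size(1)              // mandatory: forward produces one output
         .output_captured(0, false);  // d(-x) = -dy never reads the forward output
    maker.backward([](BackwardContext&, Tensor*const* grads, size_t ngrads) {
        mgb_assert(ngrads == 1);
        apply_result_t ret(1);
        if (grads[0]) {
            ret[0] = negate(grads[0]); // hypothetical helper, for illustration
        }
        return ret;
    });
    return apply(ctx);                // still run the forward op itself
}
```

Such a rule would be hooked up with `register_grad_rule(Negate::typeinfo(), negate_grad_rule)`, or from a static `Init` object as grad_override.cpp below does.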
@@ -0,0 +1,63 @@
+/**
+ * \file imperative/python/src/grad_override.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#include "./grad.h"
+#include "megbrain/imperative/ops/autogen.h"
+
+namespace mgb::imperative::python {
+namespace {
+
+std::shared_ptr<Tensor> get_shape(Tensor* x) {
+    static auto op = GetVarShape::make();
+    return python::apply(op, x)[0];
+}
+
+std::shared_ptr<Tensor> reduce_to(Tensor* x, Tensor* s) {
+    static auto op = Reduce::make();
+    return python::apply(op, x, s)[0];
+}
+
+apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
+    auto& op = ctx.op->cast_final_safe<Elemwise>();
+    if (op.mode == Elemwise::Mode::ADD) {
+        mgb_assert(ctx.nargs == 2);
+        std::array<std::shared_ptr<Tensor>, 2> input_shapes;
+        for (size_t i = 0; i < 2; ++i) {
+            if (input_requires_grad(ctx, i)) {
+                input_shapes[i] = get_shape(ctx.args[i]);
+            }
+        }
+        maker.output_size(1).output_captured(0, false);
+        maker.backward([shapes=std::move(input_shapes)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
+            mgb_assert(ngrads == 1);
+            Tensor* grad = grads[0];
+            apply_result_t ret(2);
+            for (size_t i = 0; i < 2; ++i) {
+                if (shapes[i]) {
+                    ret[i] = reduce_to(grad, shapes[i].get());
+                }
+            }
+            return ret;
+        });
+        return apply(ctx);
+    }
+    throw GradRuleFallback();
+}
+
+struct Init {
+    Init() {
+        auto& reg = grad_rule_registry();
+        reg.emplace(Elemwise::typeinfo(), elemwise_grad_rule);
+    }
+} _;
+
+} // namespace
+} // namespace mgb::imperative::python
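The rule captures input *shapes* rather than the input tensors themselves: for ADD, each input's gradient is just the incoming gradient reduced back onto that input's shape (broadcasting is the only shape change ADD can introduce), so the two small shape tensors suffice and the possibly large forward inputs can be freed early. `output_captured(0, false)` similarly records that the closure never reads the forward output. A tiny worked example of the `reduce_to` semantics on plain arrays (not MegEngine API):

```cpp
#include <cstdio>

// z = x + y with x broadcast from shape (1, 3) to (2, 3): the gradient for x
// is the incoming gradient dz summed back over the broadcast axis.
int main() {
    double dz[2][3] = {{1, 2, 3}, {4, 5, 6}};  // incoming grad, shape (2, 3)
    double dx[1][3] = {};                      // grad for x, shape (1, 3)
    for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 3; ++j)
            dx[0][j] += dz[i][j];              // sum over the broadcast axis
    std::printf("%g %g %g\n", dx[0][0], dx[0][1], dx[0][2]);  // prints: 5 7 9
}
```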
@@ -199,12 +199,59 @@ using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>;
 
 apply_result_t apply(ApplyContext& ctx);
 
-void init_tensor(pybind11::module);
+template <typename T>
+decltype(auto) resolve_arrow(T&& p) {
+    if constexpr (std::is_pointer_v<std::remove_reference_t<T>>) {
+        auto* ret = p;
+        return ret;
+    } else {
+        auto probe = [](auto&& p) -> decltype(p.operator->()) {};
+        if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) {
+            return resolve_arrow(p.operator->());
+        } else {
+            return p;
+        }
+    }
+}
+
+template <typename... Args>
+constexpr bool is_all_tensor_ptr = (... && std::is_same_v<decltype(resolve_arrow(std::declval<Args>())), Tensor*>);
 
-extern bool is_tracing;
+extern bool is_tracing; // FIXME: should use ApplyContext::global_enable
 extern bool is_symbolic;
 extern bool is_compiled;
 
+template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0>
+apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) {
+    ApplyContext ctx;
+    Tensor* arg_arr[] = {resolve_arrow(args)...};
+    ctx.flags = (0 | ... | args->m_flags);
+    ctx.flags |= is_tracing ? Tensor::Flags::TRACE : 0;
+    ctx.args = arg_arr;
+    ctx.nargs = sizeof...(args);
+    ctx.op = std::move(op);
+    return apply(ctx);
+}
+
+template <typename T>
+auto apply(std::shared_ptr<OpDef> op, T&& tensors)
+        -> std::enable_if_t<std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>,
+                            apply_result_t> {
+    ApplyContext ctx;
+    ctx.op = std::move(op);
+    ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0;
+    ctx.nargs = tensors.size();
+    Tensor* args[ctx.nargs];
+    ctx.args = args;
+    for (size_t i = 0; i < ctx.nargs; ++i) {
+        args[i] = resolve_arrow(tensors[i]);
+        ctx.flags |= args[i]->m_flags;
+    }
+    return apply(ctx);
+}
+
+void init_tensor(pybind11::module);
+
 extern pybind11::object cpp_apply_with_tracing, cpp_apply_compiled_mode;
 extern pybind11::object cpp_apply_backward_varnode;
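These two overloads are what let call sites such as the simplified `accum_grad` pass any mix of raw `Tensor*` and smart handles: `resolve_arrow` recursively peels `operator->` until it reaches a raw pointer, and `is_all_tensor_ptr` gates the variadic overload on every argument resolving to `Tensor*`. A self-contained sketch of the idiom, with `Thing` standing in for `Tensor`:

```cpp
#include <memory>
#include <type_traits>

struct Thing { int v = 42; };

// Mirrors the header's resolve_arrow: unwrap any chain of operator->
// (smart pointers, wrappers) down to a raw pointer; pass through otherwise.
template <typename T>
decltype(auto) resolve_arrow(T&& p) {
    if constexpr (std::is_pointer_v<std::remove_reference_t<T>>) {
        auto* ret = p;
        return ret;
    } else {
        auto probe = [](auto&& q) -> decltype(q.operator->()) {};
        if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) {
            return resolve_arrow(p.operator->());  // peel one level of ->
        } else {
            return p;                              // not pointer-like
        }
    }
}

int main() {
    Thing t;
    auto sp = std::make_shared<Thing>();
    static_assert(std::is_same_v<decltype(resolve_arrow(&t)), Thing*>);
    static_assert(std::is_same_v<decltype(resolve_arrow(sp)), Thing*>);
    return resolve_arrow(sp)->v - 42;              // exits with 0
}
```

With this in place, `apply(op, grad, std::forward<T>(delta))` in the new `accum_grad` compiles whether the arguments arrive as `std::shared_ptr<Tensor>` or raw pointers, and both overloads funnel into the existing `apply(ApplyContext&)`.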