GitOrigin-RevId: 141ff0a24f
release-1.2
@@ -70,7 +70,7 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph(
    for (size_t i = 0; i < ctx.nargs; ++i) {
        inputs[i].comp_node = ctx.args[i]->comp_node();
        inputs[i].layout.dtype = ctx.args[i]->dtype();
        input_requires_grad[i] = bool(ctx.args[i]->m_grad_info.grad_fn);
        input_requires_grad[i] = python::input_requires_grad(ctx, i);
    }
    auto result = std::make_shared<BackwardGraphResult>(
            proxy_graph_detail::make_backward_graph(
@@ -82,21 +82,6 @@ std::shared_ptr<BackwardGraphResult> make_backward_graph(
    return result;
}
struct BackwardContext {
    PyTypeObject* pytype = nullptr;
    auto wrap_tensor(std::shared_ptr<Tensor> t) {
        if (pytype) {
            return TensorWrapper::make(pytype, std::move(t));
        }
        return TensorWrapper::make(std::move(t));
    }
    auto wrap_tensor(Tensor* t) {
        return wrap_tensor(t->shared_from_this());
    }
};
struct BackwardGraphWithClosure {
    std::shared_ptr<BackwardGraphResult> backward_graph;
    SmallVector<std::shared_ptr<Tensor>> closure;
@@ -270,7 +255,7 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
    // same length as inputs (of forward op)
    SmallVector<GradSlotProducerPtr> dsts;
    // encapsulates actual function to compute gradient
    std::variant<std::monostate, BackwardGraphWithClosure, PythonBackward> backward;
    std::variant<std::monostate, BackwardGraphWithClosure, PythonBackward, CustomBackward> backward;
    // a flag used during backward
    bool in_ref_keeper = false;
@@ -335,8 +320,7 @@ apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
        pyin[i] = TensorWrapper::make(ctx.pytype, ctx.args[i]->shared_from_this());
    }
    auto grad_rule = py::getattr(op->obj, "_grad_rule");
    auto pyret = (scoped_disable(Flags::GRAD),
            py::reinterpret_steal<py::object>(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr))); // comma expression
    auto pyret = py::reinterpret_steal<py::object>(PyObject_Call(grad_rule.ptr(), pyin.ptr(), nullptr));
    auto [outputs, backward] = py::cast<std::tuple<py::object, py::function>>(pyret);
    ret_grad_fn.emplace<PythonBackward>(std::move(backward), ctx.nargs);
    if (auto* tw = TensorWrapper::try_cast(outputs.ptr())) {
@@ -388,9 +372,25 @@ apply_result_t apply_grad(ApplyContext& ctx) {
    }
    GradFnHelper grad_fn_holder;
    auto outputs = ctx.op->same_type<GenericPyOp>() ?
            python_grad_rule(ctx, grad_fn_holder) :
            backward_graph_grad_rule(ctx, grad_fn_holder);
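    // Dispatch order for the block below: GenericPyOp goes through the Python
    // grad rule; ops with an entry in grad_rule_registry() use their registered
    // CustomBackward rule, unless the rule throws GradRuleFallback, in which
    // case the generic backward-graph rule is used as a fallback.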
    auto outputs = [&]() {
        auto _ = scoped_disable(Flags::GRAD);
        if (ctx.op->same_type<GenericPyOp>()) {
            return python_grad_rule(ctx, grad_fn_holder);
        }
        auto&& registry = grad_rule_registry();
        auto&& it = registry.find(ctx.op->dyn_typeinfo());
        if (it != registry.end()) {
            auto&& maker = grad_fn_holder.emplace<CustomBackward>().maker(ctx);
            try {
                auto ret = it->second(ctx, maker);
                maker.finalize();
                return ret;
            } catch (GradRuleFallback&) {
                grad_fn_holder.emplace<std::monostate>();
            }
        }
        return backward_graph_grad_rule(ctx, grad_fn_holder);
    }();
    auto& grad_fn = grad_fn_holder.grad_fn;
    if (!grad_fn) {
@@ -407,7 +407,7 @@ apply_result_t apply_grad(ApplyContext& ctx) {
        mgb_assert(0);
    } else {
        for (size_t i = 0; i < ctx.nargs; ++i) {
            if (backward.input_has_grad(i)) {
            if (backward.input_has_grad(i) && input_requires_grad(ctx, i)) {
                auto& input_grad_info = ctx.args[i]->m_grad_info;
                grad_fn->dsts.emplace_back(input_grad_info);
                // register as grad producer
@@ -487,18 +487,8 @@ void accum_grad(std::shared_ptr<Tensor>& grad, T&& delta) {
        grad = std::forward<T>(delta);
        return;
    }
    static ApplyContext ctx;
    if (!ctx.op) {
        ctx.op = std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD));
        ctx.nargs = 2;
    }
    Tensor* args[2] = {grad.get(), delta.get()};
    ctx.args = args;
    ctx.flags = grad->m_flags | delta->m_flags;
    if (is_tracing) {
        ctx.flags |= Flags::TRACE;
    }
    grad = apply(ctx)[0];
    static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD));
    grad = apply(op, grad, std::forward<T>(delta))[0];
}
void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWrapper*> grads) {
@@ -582,4 +572,9 @@ GradKey::~GradKey() {
    cleanup();
}
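// Function-local static: the registry is initialized on first use, so it is
// safe to populate from other translation units' static initializers (e.g. the
// Init struct in grad_override.cpp below).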
std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry() {
    static std::unordered_map<Typeinfo*, GradRuleFn> registry;
    return registry;
}
} // namespace mgb::imperative::python
@@ -45,6 +45,117 @@ struct GradKeyWrapper {
    void backward(std::vector<TensorWrapper*>, std::vector<TensorWrapper*>);
};
struct BackwardContext {
    PyTypeObject* pytype = nullptr;
    auto wrap_tensor(std::shared_ptr<Tensor> t) {
        if (pytype) {
            return TensorWrapper::make(pytype, std::move(t));
        }
        return TensorWrapper::make(std::move(t));
    }
    auto wrap_tensor(Tensor* t) {
        return wrap_tensor(t->shared_from_this());
    }
};
struct CustomBackward {
    using BackwardFn = std::function<apply_result_t(BackwardContext&, Tensor*const*, size_t)>;
    BackwardFn m_backward;
    SmallVector<bool, 8> m_input_has_grad;
    struct OutputAttr {bool requires_grad = true, captured = true;};
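    // Presumed semantics: requires_grad marks outputs that gradient should flow
    // through, captured marks outputs the grad function keeps alive for the
    // backward pass (the elemwise ADD rule below clears it, since ADD's output
    // is not needed to compute its input gradients).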
    SmallVector<OutputAttr> m_output_attrs;
public:
    template<typename T, typename R>
    void operator()(BackwardContext& ctx, T&& grads, R&& receiver) {
        size_t nargs = grads.size();
        Tensor* args[nargs];
        for (size_t i = 0; i < nargs; ++i) {
            args[i] = grads[i];
        }
        auto ret = m_backward(ctx, args, nargs);
        for (size_t i = 0; i < ret.size(); ++i) {
            if (auto&& t = ret[i]) {
                receiver(i, std::move(t));
            }
        }
    }
    bool input_has_grad(size_t i) {return m_input_has_grad[i];}
    bool output_requires_grad(size_t i) {return m_output_attrs[i].requires_grad;}
    bool output_captured(size_t i) {return m_output_attrs[i].captured;}
    class Maker {
        bool output_size_set = false, input_has_grad_initialized = false;
        CustomBackward& target;
        ApplyContext& ctx;
        void init_input_has_grad() {
            if (!input_has_grad_initialized) {
                input_has_grad_initialized = true;
                target.m_input_has_grad.resize(ctx.nargs, true);
            }
        }
    public:
        Maker(CustomBackward& target_, ApplyContext& ctx_) : target(target_), ctx(ctx_) {}
        template<typename F>
        Maker& backward(F&& f) {
            mgb_assert(!target.m_backward);
            target.m_backward = std::forward<F>(f);
            return *this;
        }
        // mandatory
        Maker& output_size(size_t sz) {
            mgb_assert(!output_size_set);
            output_size_set = true;
            target.m_output_attrs.resize(sz);
            return *this;
        }
        // optional, defaults to all true
        Maker& input_has_grad(size_t i, bool v) {
            init_input_has_grad();
            target.m_input_has_grad.at(i) = v;
            return *this;
        }
        // optional, defaults to all true
        Maker& output_requires_grad(size_t i, bool v) {
            target.m_output_attrs.at(i).requires_grad = v;
            return *this;
        }
        // optional, defaults to all true
        Maker& output_captured(size_t i, bool v) {
            target.m_output_attrs.at(i).captured = v;
            return *this;
        }
        void finalize() {
            mgb_assert(output_size_set);
            init_input_has_grad();
        }
    };
    Maker maker(ApplyContext& ctx) {return {*this, ctx};}
};
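// Illustrative sketch (not part of this diff) of how a grad rule is expected to
// drive the Maker; the lambda body is a placeholder:
//     maker.output_size(1).output_captured(0, false);
//     maker.backward([](BackwardContext&, Tensor*const* grads, size_t ngrads) {
//         return apply_result_t(ngrads);
//     });
// output_size() is mandatory; finalize() is called by apply_grad (grad.cpp)
// after the registered rule returns.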
using GradRuleFn = std::function<apply_result_t(ApplyContext&, CustomBackward::Maker&)>;
std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry();
inline bool input_requires_grad(const ApplyContext& ctx, size_t i) {
    return bool(ctx.args[i]->m_grad_info.grad_fn);
}
struct GradRuleFallback : std::exception {};
template<typename T>
bool register_grad_rule(Typeinfo* typeinfo, T&& rule) {
    return grad_rule_registry().emplace(typeinfo, std::forward<T>(rule)).second;
}
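// Usage sketch for the helper above (grad_override.cpp below registers its rule
// via grad_rule_registry().emplace directly, which is equivalent):
//     static bool registered =
//             register_grad_rule(Elemwise::typeinfo(), elemwise_grad_rule);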
} // namespace mgb::imperative::python
namespace pybind11::detail {
@@ -0,0 +1,63 @@
/**
 * \file imperative/python/src/grad_override.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "./grad.h"
#include "megbrain/imperative/ops/autogen.h"
namespace mgb::imperative::python {
namespace {
std::shared_ptr<Tensor> get_shape(Tensor* x) {
    static auto op = GetVarShape::make();
    return python::apply(op, x)[0];
}
std::shared_ptr<Tensor> reduce_to(Tensor* x, Tensor* s) {
    static auto op = Reduce::make();
    return python::apply(op, x, s)[0];
}
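// reduce_to(x, s) collapses a (possibly broadcast) gradient x back to the shape
// carried by s; get_shape records that target shape at forward time, so the
// backward closure below only captures shapes rather than the input tensors.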
apply_result_t elemwise_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
    auto& op = ctx.op->cast_final_safe<Elemwise>();
    if (op.mode == Elemwise::Mode::ADD) {
        mgb_assert(ctx.nargs == 2);
        std::array<std::shared_ptr<Tensor>, 2> input_shapes;
        for (size_t i = 0; i < 2; ++i) {
            if (input_requires_grad(ctx, i)) {
                input_shapes[i] = get_shape(ctx.args[i]);
            }
        }
        maker.output_size(1).output_captured(0, false);
        maker.backward([shapes=std::move(input_shapes)](BackwardContext&, Tensor*const* grads, size_t ngrads) {
            mgb_assert(ngrads == 1);
            Tensor* grad = grads[0];
            apply_result_t ret(2);
            for (size_t i = 0; i < 2; ++i) {
                if (shapes[i]) {
                    ret[i] = reduce_to(grad, shapes[i].get());
                }
            }
            return ret;
        });
        return apply(ctx);
    }
    throw GradRuleFallback();
}
struct Init {
    Init() {
        auto& reg = grad_rule_registry();
        reg.emplace(Elemwise::typeinfo(), elemwise_grad_rule);
    }
} _;
} // namespace
} // namespace mgb::imperative::python
@@ -199,12 +199,59 @@ using apply_result_t = SmallVector<std::shared_ptr<Tensor>, 8>;
apply_result_t apply(ApplyContext& ctx);
void init_tensor(pybind11::module);
template <typename T>
decltype(auto) resolve_arrow(T&& p) {
    if constexpr (std::is_pointer_v<std::remove_reference_t<T>>) {
        auto* ret = p;
        return ret;
    } else {
        auto probe = [](auto&& p) -> decltype(p.operator->()) {};
        if constexpr (std::is_invocable_v<decltype(probe), decltype(p)>) {
            return resolve_arrow(p.operator->());
        } else {
            return p;
        }
    }
}
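// resolve_arrow yields a raw Tensor* whether it is handed a Tensor* directly or
// a smart-pointer-like type: the `probe` lambda is never called, it only feeds
// std::is_invocable_v as a detector for operator->, and the recursion keeps
// following operator-> until a plain pointer is reached. For example
// (illustration only):
//     std::shared_ptr<Tensor> t = ...;
//     Tensor* raw = resolve_arrow(t);     // via shared_ptr::operator->
//     Tensor* same = resolve_arrow(raw);  // identity for raw pointers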
template <typename... Args>
constexpr bool is_all_tensor_ptr = (... && std::is_same_v<decltype(resolve_arrow(std::declval<Args>())), Tensor*>);
extern bool is_tracing;
extern bool is_tracing; // FIXME: should use ApplyContext::global_enable
extern bool is_symbolic;
extern bool is_compiled;
template <typename... Args, std::enable_if_t<is_all_tensor_ptr<Args...>, int> = 0>
apply_result_t apply(std::shared_ptr<OpDef> op, Args&&... args) {
    ApplyContext ctx;
    Tensor* arg_arr[] = {resolve_arrow(args)...};
    ctx.flags = (0 | ... | args->m_flags);
    ctx.flags |= is_tracing ? Tensor::Flags::TRACE : 0;
    ctx.args = arg_arr;
    ctx.nargs = sizeof...(args);
    ctx.op = std::move(op);
    return apply(ctx);
}
template <typename T>
auto apply(std::shared_ptr<OpDef> op, T&& tensors)
        -> std::enable_if_t<std::is_same_v<decltype(resolve_arrow(tensors[0])), Tensor*>,
                            apply_result_t> {
    ApplyContext ctx;
    ctx.op = std::move(op);
    ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0;
    ctx.nargs = tensors.size();
    Tensor* args[ctx.nargs];
    ctx.args = args;
    for (size_t i = 0; i < ctx.nargs; ++i) {
        args[i] = resolve_arrow(tensors[i]);
        ctx.flags |= args[i]->m_flags;
    }
    return apply(ctx);
}
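// Usage sketch for the two overloads above (tensor names are illustrative):
//     auto op = std::shared_ptr<OpDef>(new Elemwise(Elemwise::Mode::ADD));
//     auto r1 = apply(op, x.get(), y.get());   // variadic Tensor* overload
//     SmallVector<std::shared_ptr<Tensor>> in{x, y};
//     auto r2 = apply(op, in);                 // container overload
// Both assemble the ApplyContext (op, args, nargs, flags, TRACE bit) that call
// sites such as accum_grad previously built by hand.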
void init_tensor(pybind11::module);
extern pybind11::object cpp_apply_with_tracing, cpp_apply_compiled_mode;
extern pybind11::object cpp_apply_backward_varnode;