|
|
@@ -54,7 +54,7 @@ std::shared_ptr<OptimizedBackwardGraphResult> make_backward_graph( |
|
|
|
for (size_t i = 0; i < ctx.nargs; ++i) { |
|
|
|
*(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle()); |
|
|
|
*(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node()); |
|
|
|
*(bool_ptr++) = bool(ctx.args[i]->m_grad_info.grad_fn); |
|
|
|
*(bool_ptr++) = !ctx.args[i]->m_grad_info_dict.empty(); |
|
|
|
} |
|
|
|
mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) && |
|
|
|
bool_ptr == reinterpret_cast<bool*>(buf + buf_size)); |
|
|
@@ -321,7 +321,7 @@ apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_gra |
|
|
|
for (size_t i = 0; i < ctx.nargs; ++i) { |
|
|
|
inputs_copy.push_back(python::apply(FastpathCopy::make(), ctx.args[i]->shared_from_this())[0]); |
|
|
|
inputs_copy_weak.push_back(inputs_copy.back().get()); |
|
|
|
inputs_copy.back()->m_grad_info = ctx.args[i]->m_grad_info; |
|
|
|
inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict; |
|
|
|
} |
|
|
|
ApplyContext ctx_dup = ctx; |
|
|
|
ctx_dup.args = inputs_copy_weak.data(); |
|
|
@@ -365,25 +365,19 @@ apply_result_t python_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) { |
|
|
|
} // namespace |
|
|
|
|
|
|
|
apply_result_t apply_grad(ApplyContext& ctx) { |
|
|
|
std::shared_ptr<GradKey> grad_key; |
|
|
|
std::unordered_set<std::shared_ptr<GradKey>> grad_keys; |
|
|
|
for (size_t i = 0; i < ctx.nargs; ++i) { |
|
|
|
auto* tensor = ctx.args[i]; |
|
|
|
if (tensor->m_grad_info.grad_fn) { |
|
|
|
auto&& input_grad_key = tensor->m_grad_info.grad_fn->key.lock(); |
|
|
|
// tensor is attached to a live GradKey |
|
|
|
if (input_grad_key && input_grad_key->active) { |
|
|
|
if (grad_key) { |
|
|
|
if (grad_key != input_grad_key) { |
|
|
|
PyErr_SetString(PyExc_NotImplementedError, "second order grad"); |
|
|
|
throw pyext17::py_err_set(); |
|
|
|
} |
|
|
|
} else { |
|
|
|
grad_key = std::move(input_grad_key); |
|
|
|
if (!tensor->m_grad_info_dict.empty()) { |
|
|
|
size_t grad_cnt = 0; |
|
|
|
for (auto&& grad_info: tensor->m_grad_info_dict) { |
|
|
|
auto input_grad_key = grad_info.grad_fn->key.lock(); |
|
|
|
if (input_grad_key && input_grad_key->active && !input_grad_key->is_blocked()) { |
|
|
|
grad_keys.insert(input_grad_key); |
|
|
|
grad_cnt++; |
|
|
|
} |
|
|
|
} else { |
|
|
|
// cleanup stale grad info |
|
|
|
// under what condition? |
|
|
|
tensor->m_grad_info = {}; |
|
|
|
} |
|
|
|
if (!grad_cnt) { |
|
|
|
tensor->m_flags &= ~Flags::GRAD; |
|
|
|
} |
|
|
|
} else { |
|
|
@@ -393,7 +387,7 @@ apply_result_t apply_grad(ApplyContext& ctx) { |
|
|
|
|
|
|
|
ctx.flags &= ~Flags::GRAD; |
|
|
|
|
|
|
|
if (!grad_key) { |
|
|
|
if (grad_keys.empty()) { |
|
|
|
return apply(ctx); |
|
|
|
} |
|
|
|
|
|
|
@@ -418,54 +412,65 @@ apply_result_t apply_grad(ApplyContext& ctx) { |
|
|
|
return backward_graph_grad_rule(ctx, grad_fn_holder); |
|
|
|
}(); |
|
|
|
|
|
|
|
auto& grad_fn = grad_fn_holder.grad_fn; |
|
|
|
if (!grad_fn) { |
|
|
|
if (!grad_fn_holder.grad_fn) { |
|
|
|
return outputs; |
|
|
|
} |
|
|
|
|
|
|
|
grad_fn->key = grad_key; |
|
|
|
grad_fn->slots.resize(outputs.size()); |
|
|
|
grad_fn->dsts.reserve(ctx.nargs); |
|
|
|
for (auto&& grad_key: grad_keys) { |
|
|
|
auto grad_fn = std::make_shared<GradFn>(); |
|
|
|
grad_fn->backward = grad_fn_holder.grad_fn->backward; |
|
|
|
grad_fn->key = grad_key; |
|
|
|
grad_fn->slots.resize(outputs.size()); |
|
|
|
grad_fn->dsts.reserve(ctx.nargs); |
|
|
|
|
|
|
|
std::visit([&](auto& backward) { |
|
|
|
using T = std::decay_t<decltype(backward)>; |
|
|
|
if constexpr (std::is_same_v<T, std::monostate>) { |
|
|
|
mgb_assert(0); |
|
|
|
} else { |
|
|
|
for (size_t i = 0; i < ctx.nargs; ++i) { |
|
|
|
if (backward.input_has_grad(i) && input_requires_grad(ctx, i)) { |
|
|
|
auto& input_grad_info = ctx.args[i]->m_grad_info; |
|
|
|
grad_fn->dsts.emplace_back(input_grad_info); |
|
|
|
// register as grad producer |
|
|
|
grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head); |
|
|
|
} else { |
|
|
|
grad_fn->dsts.emplace_back(); |
|
|
|
std::visit([&](auto& backward) { |
|
|
|
using T = std::decay_t<decltype(backward)>; |
|
|
|
if constexpr (std::is_same_v<T, std::monostate>) { |
|
|
|
mgb_assert(0); |
|
|
|
} else { |
|
|
|
for (size_t i = 0; i < ctx.nargs; ++i) { |
|
|
|
if (backward.input_has_grad(i) && input_requires_grad(ctx, i) && ctx.args[i]->m_grad_info_dict.count(grad_key.get())) { |
|
|
|
auto& input_grad_info = ctx.args[i]->m_grad_info_dict.at(grad_key.get()); |
|
|
|
grad_fn->dsts.emplace_back(input_grad_info); |
|
|
|
// register as grad producer |
|
|
|
grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head); |
|
|
|
} else { |
|
|
|
grad_fn->dsts.emplace_back(); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
for (size_t i = 0; i < outputs.size(); ++i) { |
|
|
|
if (backward.output_requires_grad(i)) { |
|
|
|
if (backward.output_captured(i)) { |
|
|
|
// avoid reference cycle [Tensor <-> GradFn] |
|
|
|
static std::shared_ptr<OpDef> op = std::shared_ptr<OpDef>(new FastpathCopy()); |
|
|
|
outputs[i] = python::apply(op, outputs[i])[0]; |
|
|
|
for (size_t i = 0; i < outputs.size(); ++i) { |
|
|
|
if (backward.output_requires_grad(i)) { |
|
|
|
if (backward.output_captured(i)) { |
|
|
|
// avoid reference cycle [Tensor <-> GradFn] |
|
|
|
static std::shared_ptr<OpDef> op = std::make_shared<FastpathCopy>(); |
|
|
|
outputs[i] = python::apply(op, outputs[i])[0]; |
|
|
|
} |
|
|
|
// populate grad info of output tensor |
|
|
|
auto& grad_info = outputs[i]->m_grad_info_dict[grad_key.get()]; |
|
|
|
grad_info.grad_fn = grad_fn; |
|
|
|
grad_info.idx = i; |
|
|
|
grad_info.insert_after(grad_key->free_vars_head); |
|
|
|
outputs[i]->m_flags |= Flags::GRAD; |
|
|
|
} |
|
|
|
// populate grad info of output tensor |
|
|
|
auto& grad_info = outputs[i]->m_grad_info; |
|
|
|
grad_info.grad_fn = grad_fn; |
|
|
|
grad_info.idx = i; |
|
|
|
grad_info.insert_after(grad_key->free_vars_head); |
|
|
|
outputs[i]->m_flags |= Flags::GRAD; |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
}, grad_fn->backward); |
|
|
|
}, grad_fn->backward); |
|
|
|
|
|
|
|
// record forward history |
|
|
|
grad_key->tape.emplace_back(grad_fn); |
|
|
|
// record forward history |
|
|
|
grad_key->tape.emplace_back(grad_fn); |
|
|
|
} |
|
|
|
|
|
|
|
return outputs; |
|
|
|
} |
|
|
|
|
|
|
|
PyObject* GradKeyWrapper::get_priority() {
    // Convert the key's priority to a Python int and hand ownership
    // (a new reference) to the caller.
    py::object result = py::cast(m_key->priority);
    return result.release().ptr();
}
|
|
|
|
|
|
|
void GradKeyWrapper::set_priority(pybind11::handle priority) {
    // Store the priority on the underlying GradKey.
    // Bug fix: this previously assigned to m_key->name, while get_priority()
    // reads m_key->priority — the setter silently had no effect on priority
    // and clobbered the key's name instead.
    m_key->priority = py::cast<int>(priority);
}
|
|
|
|
|
|
|
void GradKeyWrapper::attach(PyObject*const* args, size_t nargs) { |
|
|
|
if (nargs != 2) { |
|
|
|
throw py::type_error("expect 2 arguments"); |
|
|
@@ -488,24 +493,21 @@ void GradKey::attach(Tensor* tensor, pybind11::object callback) { |
|
|
|
throw py::value_error("grad key finalized"); |
|
|
|
} |
|
|
|
|
|
|
|
if (tensor->m_grad_info.grad_fn) { |
|
|
|
if (tensor->m_grad_info.grad_fn->key.lock().get() != this) { |
|
|
|
PyErr_SetString(PyExc_NotImplementedError, "second order grad"); |
|
|
|
throw pyext17::py_err_set(); |
|
|
|
} |
|
|
|
if (tensor->m_grad_info->callback) { |
|
|
|
if (tensor->m_grad_info_dict.count(this)) { |
|
|
|
if (tensor->m_grad_info_dict.at(this)->callback) { |
|
|
|
throw py::value_error("callback already set on this tensor"); |
|
|
|
} |
|
|
|
} else { |
|
|
|
tensor->m_grad_info.idx = 0; |
|
|
|
auto& grad_fn = tensor->m_grad_info.grad_fn; |
|
|
|
auto& grad_info = tensor->m_grad_info_dict[this]; |
|
|
|
grad_info.idx = 0; |
|
|
|
auto& grad_fn = grad_info.grad_fn; |
|
|
|
grad_fn = std::make_shared<GradFn>(); |
|
|
|
grad_fn->key = shared_from_this(); |
|
|
|
grad_fn->slots.resize(1); |
|
|
|
tensor->m_grad_info.insert_after(free_vars_head); |
|
|
|
grad_info.insert_after(free_vars_head); |
|
|
|
tensor->m_flags |= Flags::GRAD; |
|
|
|
} |
|
|
|
tensor->m_grad_info.grad_fn->slots[0].callback = std::move(callback); |
|
|
|
tensor->m_grad_info_dict.at(this).grad_fn->slots[0].callback = std::move(callback); |
|
|
|
} |
|
|
|
|
|
|
|
template<typename T> |
|
|
@@ -530,8 +532,15 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr |
|
|
|
active = false; |
|
|
|
struct CleanupGuard { |
|
|
|
GradKey* owner; |
|
|
|
CleanupGuard(GradKey* this_) : owner(this_) {} |
|
|
|
~CleanupGuard() {owner->cleanup();} |
|
|
|
size_t priority_backup; |
|
|
|
CleanupGuard(GradKey* this_) : owner(this_) { |
|
|
|
priority_backup = sm_min_priority; |
|
|
|
sm_min_priority = owner->priority; |
|
|
|
} |
|
|
|
~CleanupGuard() { |
|
|
|
owner->cleanup(); |
|
|
|
sm_min_priority = priority_backup; |
|
|
|
} |
|
|
|
} _cleanup_guard(this); |
|
|
|
|
|
|
|
if (tape.empty()) return; |
|
|
@@ -542,14 +551,16 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr |
|
|
|
} |
|
|
|
|
|
|
|
for (size_t i = 0; i < tensors.size(); ++i) { |
|
|
|
auto& grad_info = tensors[i]->m_tensor->m_grad_info; |
|
|
|
if (grad_info.grad_fn && grad_info.grad_fn->key.lock().get() == this) { |
|
|
|
grad_info->grad = grads[i]->m_tensor; |
|
|
|
if (tensors[i]->m_tensor->m_grad_info_dict.count(this) == 0) { |
|
|
|
continue; |
|
|
|
} |
|
|
|
auto& grad_info = tensors[i]->m_tensor->m_grad_info_dict.at(this); |
|
|
|
grad_info->grad = grads[i]->m_tensor; |
|
|
|
} |
|
|
|
|
|
|
|
std::vector<std::shared_ptr<GradFn>> ref_keeper; |
|
|
|
ref_keeper.reserve(tape.size()); |
|
|
|
|
|
|
|
// back-propagation in reverse order |
|
|
|
for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) { |
|
|
|
auto&& grad_fn = tape[k].lock(); |
|
|
@@ -619,13 +630,14 @@ PyObject* GradKeyWrapper::is_attached_to(PyObject*const* args, size_t nargs) { |
|
|
|
PyErr_SetString(PyExc_TypeError, "expect Tensor"); |
|
|
|
return nullptr; |
|
|
|
} |
|
|
|
auto&& grad_fn = tw->m_tensor->m_grad_info.grad_fn; |
|
|
|
if (grad_fn && grad_fn->key.lock() == m_key) { |
|
|
|
if (tw->m_tensor->m_grad_info_dict.count(m_key.get())) { |
|
|
|
Py_RETURN_TRUE; |
|
|
|
} |
|
|
|
Py_RETURN_FALSE; |
|
|
|
} |
|
|
|
|
|
|
|
int GradKey::sm_min_priority = 0; |
|
|
|
|
|
|
|
GradKey::~GradKey() {
    // Release tape/graph bookkeeping on destruction. Note cleanup() is also
    // invoked by the CleanupGuard in backward(), so it is presumably safe to
    // run more than once — TODO confirm idempotence.
    cleanup();
}
|
|
@@ -635,4 +647,41 @@ std::unordered_map<Typeinfo*, GradRuleFn>& grad_rule_registry() { |
|
|
|
return registry; |
|
|
|
} |
|
|
|
|
|
|
|
void GradInfoCollection::_shrink() { |
|
|
|
auto pred = [](GradInfo& info){ return !(info.grad_fn) || info.grad_fn->key.expired(); }; |
|
|
|
auto iter = std::remove_if(m_storage.begin(), m_storage.end(), pred); |
|
|
|
m_storage.erase(iter, m_storage.end()); |
|
|
|
} |
|
|
|
|
|
|
|
bool GradInfoCollection::contains(GradKey* key) {
    // Purge dead entries first so every surviving grad_fn/key is alive and
    // safe to dereference below.
    _shrink();
    for (auto it = m_storage.begin(); it != m_storage.end(); ++it) {
        if (it->grad_fn->key.lock().get() == key) {
            return true;
        }
    }
    return false;
}
|
|
|
|
|
|
|
GradInfo& GradInfoCollection::operator[](GradKey* key) {
    // Drop stale entries, then return the slot associated with `key`,
    // default-constructing a fresh one on a miss (map-like semantics).
    _shrink();
    for (auto it = m_storage.begin(); it != m_storage.end(); ++it) {
        if (it->grad_fn->key.lock().get() == key) {
            return *it;
        }
    }
    m_storage.emplace_back();
    return m_storage.back();
}
|
|
|
|
|
|
|
GradInfo& GradInfoCollection::at(GradKey* key) {
    // Like operator[], but the entry must already exist; a miss is a
    // precondition violation.
    _shrink();
    for (auto it = m_storage.begin(); it != m_storage.end(); ++it) {
        if (it->grad_fn->key.lock().get() == key) {
            return *it;
        }
    }
    mgb_assert(false);  // key is not attached to this collection
}
|
|
|
|
|
|
|
} // namespace mgb::imperative::python |