@@ -0,0 +1,543 @@
/**
 * \file imperative/src/impl/transformations/grad.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "megbrain/imperative/transformations/grad.h"

#include "megbrain/imperative/graph_cache.h"

#include <range/v3/all.hpp>

namespace mgb {
namespace imperative {

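// Build (and memoize) the optimized backward graph for `op`. The cache is
// thread_local and keyed by the op, the input dtypes/devices and the
// requires-grad mask, so repeated applications with a matching signature can
// reuse the cached result instead of regenerating the backward graph.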
static std::shared_ptr<OptimizedBackwardGraphResult> make_optimized_backward_graph(
        std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs,
        Span<bool> inputs_require_grad) {
    // hash
    using OptimizedBackwardGraphCache = OpMethResultCache<
            std::shared_ptr<OptimizedBackwardGraphResult>, SmallVector<bool>>;
    thread_local auto cache = std::make_unique<OptimizedBackwardGraphCache>();
    OptimizedBackwardGraphCache::key_t cache_key{op};
    SmallVector<LogicalTensorDesc>& input_descs = cache_key.inputs;
    std::get<0>(cache_key.extras) = inputs_require_grad.copy_into<SmallVector<bool>>();
    input_descs.resize(inputs.size());
    for (size_t i = 0; i < inputs.size(); ++i) {
        input_descs[i].layout.dtype = inputs[i].dtype().cast<DTypeValue>();
        input_descs[i].comp_node = inputs[i].device().cast<CompNodeValue>();
    }

    auto iter = cache->find(cache_key);
    if (iter != cache->end()) {
        return iter->second;
    }

    // slow path
    SmallVector<bool> output_has_grad(outputs.size(), true);
    std::shared_ptr<OptimizedBackwardGraphResult> ret;
    auto bg = OpDef::make_backward_graph(
            *op, input_descs, std::get<0>(cache_key.extras), output_has_grad);
    if (!bg.graph.empty()) {
        ret = std::make_shared<OptimizedBackwardGraphResult>(bg);
    }
    cache->emplace(cache_key, ret);
    return ret;
}

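// The closure holds everything the backward graph will need later: values
// produced by the `precomp` graph (if any), plus the forward inputs and
// outputs selected by the `save_for_backward` mask.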
BackwardGraphWithClosure::BackwardGraphWithClosure(
        std::shared_ptr<OptimizedBackwardGraphResult> backward_graph,
        std::shared_ptr<OpDef> op, Span<ValueRef> inputs, Span<ValueRef> outputs)
        : backward_graph(backward_graph),
          output_mask_offset(inputs.size()),
          grad_mask_offset(inputs.size() + outputs.size()) {
    auto& save_for_backward = backward_graph->save_for_backward;
    mgb_assert(save_for_backward.size() == inputs.size() + 2 * outputs.size());
    size_t count = std::count_if(
            save_for_backward.begin(), save_for_backward.end(), ranges::identity{});
    if (!backward_graph->precomp.empty()) {
        SmallVector<ValueRef> inputs_and_outputs;
        for (auto&& input : inputs) {
            inputs_and_outputs.push_back(input);
        }
        for (auto&& output : outputs) {
            inputs_and_outputs.push_back(output);
        }
        auto precomp = imperative::apply(backward_graph->precomp, inputs_and_outputs);
        closure.reserve(precomp.size() + count);
        std::copy(precomp.begin(), precomp.end(), std::back_inserter(closure));
    } else {
        closure.reserve(count);
    }
    for (size_t i = 0; i < inputs.size(); ++i) {
        if (save_for_backward[i]) {
            closure.push_back(inputs[i]);
        }
    }
    for (size_t i = 0; i < outputs.size(); ++i) {
        if (save_for_backward[inputs.size() + i]) {
            closure.push_back(outputs[i]);
        }
    }
}
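// Invoke the optimized backward graph: the argument list is the saved closure
// followed by the output gradients that were marked in save_for_backward. If a
// required output gradient is missing, the whole backward is skipped and no
// input gradient is produced.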
void BackwardGraphWithClosure::operator()(
        std::vector<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) {
    ValueRef args[closure.size() + grads.size()];
    size_t nargs = 0;
    for (auto&& value : closure) {
        args[nargs++] = value;
    }
    bool null_grad = false;
    for (size_t i = 0; i < grads.size(); ++i) {
        if (backward_graph->save_for_backward[grad_mask_offset + i]) {
            if (grads[i]) {
                mgb_assert(!null_grad, "null_grad");
                args[nargs++] = grads[i];
            } else {
                null_grad = true;
            }
        }
    }
    if (null_grad) {
        return;
    }
    auto igrads = imperative::apply(backward_graph->backward, Span(args, nargs));
    auto&& iter = igrads.begin();
    for (auto [i, p] : ranges::views::enumerate(backward_graph->input_has_grad)) {
        if (p) {
            receiver(i, std::move(*iter));
            ++iter;
        }
    }
}

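// CustomBackward forwards the output gradients to the stored m_backward
// function and hands each non-null returned gradient to `receiver` under the
// corresponding input index.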
void CustomBackward::operator()(
        std::vector<ValueRef> grads, std::function<void(size_t, ValueRef)> receiver) {
    size_t nargs = grads.size();
    ValueRef args[nargs];
    for (size_t i = 0; i < nargs; ++i) {
        args[i] = grads[i];
    }
    auto ret = m_backward({args, nargs});
    for (size_t i = 0; i < ret.size(); ++i) {
        if (auto&& t = ret[i]) {
            receiver(i, std::move(t));
        }
    }
}

std::string GradSlot::to_string() const {
    bool has_callback = bool(callback);
    return ssprintf(
            "GradSlot{grad=%s, has_callback=%d}", m_grad.to_string().c_str(),
            (int)has_callback);
}

std::string GradFn::to_string() const {
    return ssprintf("GradFn{dests=%s}", imperative::to_string(m_dests).c_str());
}

std::string GradSlotPtr::to_string() const {
    if (!m_fn) {
        return "<empty>";
    }
    return (*this)->to_string();
}

std::string GradValue::to_string() const {
    return ssprintf(
            "GradValue{key=\"%s\", slot=%s, value=%s}", m_key->name().c_str(),
            m_slot.to_string().c_str(), m_value.to_string().c_str());
}

static std::unordered_map<Typeinfo*, CustomBackward::BackwardRule>&
get_backward_rule_storage() {
    static std::unordered_map<Typeinfo*, CustomBackward::BackwardRule> sl_storage;
    return sl_storage;
}

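// Hypothetical usage sketch (not part of this file): a per-op rule would
// normally be installed once, e.g. from a static initializer,
//
//     bool registered = CustomBackward::register_grad_rule(
//             SomeOp::typeinfo(), some_op_backward_rule);
//
// where `SomeOp` and `some_op_backward_rule` are placeholders for a concrete
// OpDef type and a function matching CustomBackward::BackwardRule. Note that
// registration returns false if a rule for that type already exists.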
bool CustomBackward::register_grad_rule(Typeinfo* typeinfo, BackwardRule rule) {
    return get_backward_rule_storage().insert({typeinfo, rule}).second;
}

auto CustomBackward::lookup_grad_rule(Typeinfo* typeinfo) -> BackwardRule {
    auto iter = get_backward_rule_storage().find(typeinfo);
    if (iter == get_backward_rule_storage().end()) {
        return {};
    }
    return iter->second;
}

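// Walk the frozen tape in reverse. For every recorded GradFn, run its backward
// (graph-based or custom), accumulate gradients into the destination slots
// (summing with Elemwise ADD when a slot already holds a gradient), and fire
// the slot callback once the last producer has contributed.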
void GradKey::backward() {
    mgb_assert(m_frozen);
    auto& tape = m_frozen_tape;
    for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
        auto& [grad_fn, op] = tape[k];
        auto grad_receiver = [&, grad_fn = grad_fn](size_t i, ValueRef grad) {
            auto& dest = grad_fn->m_dests[i];
            if (dest) {
                auto& existing_grad = dest->m_grad;
                if (!existing_grad) {
                    existing_grad = grad;
                } else {
                    existing_grad = imperative::apply(
                            ApplyOp(*Elemwise::make(Elemwise::Mode::ADD)),
                            existing_grad, grad)[0];
                }
            }
        };
        // clang-format off
        std::visit([&, grad_fn = grad_fn, op = op](auto&& backward) {
            using T = std::decay_t<decltype(backward)>;
            if constexpr (std::is_same_v<T, std::monostate>) {
                mgb_throw(AssertionError, "invalid backward");
            } else {
                mgb_assert(grad_fn->m_slots.size() > 0);
                std::vector<ValueRef> grads;
                for (auto&& slot : grad_fn->m_slots) {
                    grads.push_back(slot.m_grad);
                }
                backward(grads, grad_receiver);
            }
        }, grad_fn->m_backward);
        // clang-format on
        for (auto&& dest : grad_fn->m_dests) {
            if (!dest) {
                continue;
            }
            if (!dest.m_producer_record.next && dest->callback && dest->m_grad) {
                // I'm the last grad producer, invoke callback
                dest->callback(dest->m_grad);
            }
        }
        grad_fn->clear();
    }
    tape.clear();
}

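// Attach `callback` to `tensor` under this GradKey. If the tensor is already
// tracked by this key, we only require that no callback was set before;
// otherwise a fresh single-slot GradFn is created and the tensor is wrapped
// into a GradValue. The callback is then stored on the corresponding slot.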
GradValue::ref_t GradKey::attach(
        ValueRef tensor, std::function<void(ValueRef)> callback) {
    auto grad_value = tensor.as_ref<GradValue>();
    if (grad_value && grad_value->has_key(shared_from_this())) {
        mgb_assert(
                !tensor.cast<GradValue>().slot_for(shared_from_this())->callback,
                "callback exists");
    } else {
        GradSlotPtr grad_slot;
        auto& grad_fn = grad_slot.m_fn;
        grad_fn = std::make_shared<GradFn>();
        grad_fn->m_key = shared_from_this();
        grad_fn->m_slots.resize(1);
        grad_slot.m_index = 0;
        grad_value = GradValue::make(tensor, shared_from_this(), grad_slot);
    }
    grad_value->slot_for(shared_from_this()).m_fn->m_slots[0].callback = callback;
    return grad_value;
}

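// Freeze the tape before backward: weak GradFn entries that are still alive
// are promoted to strong references, expired ones simply fall off.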
void GradKey::freeze() {
    mgb_assert(m_frozen_tape.empty() && !m_frozen);
    for (auto&& [grad_fn, op] : m_tape) {
        if (auto valid_grad_fn = grad_fn.lock()) {
            m_frozen_tape.push_back({valid_grad_fn, op});
        }
    }
    m_tape.clear();
    m_frozen = true;
}

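// Main dispatch of the grad transformation: ApplyOp applications are recorded
// as GradFns on the tape; AttachGrad, GradBackward, IsAttachedTo, SetGrad and
// GetBackwardColsure implement the autodiff-specific operators; everything
// else is forwarded with GradValue wrappers peeled off where needed.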
std::vector<ValueRef> GradTransformation::apply_transformation(
        const Operator& op, Span<ValueRef> inputs) {
    auto unwrap_inputs = [this](Span<ValueRef> inputs) -> SmallVector<ValueRef> {
        SmallVector<ValueRef> unwrapped_inputs;
        for (auto&& input : inputs) {
            if (auto grad_value = as_grad_value(input)) {
                unwrapped_inputs.push_back(grad_value->m_value);
            } else {
                unwrapped_inputs.push_back(input);
            }
        }
        return unwrapped_inputs;
    };
    if (m_suppressed) {
        return imperative::apply(op, unwrap_inputs(inputs));
    }
    if (auto* op_val = op.as<ApplyOp>()) {
        size_t nr_require_grad = 0;
        SmallVector<bool> require_grads;
        for (auto&& input : inputs) {
            if (is_grad_value(input)) {
                nr_require_grad++;
                require_grads.push_back(true);
            } else {
                require_grads.push_back(false);
            }
        }
        if (nr_require_grad == 0) {
            return imperative::apply(op, inputs);
        }
        SmallVector<ValueRef> captured_inputs;
        SmallVector<bool> inputs_require_grad;
        // capture value so that trace could assume input as same
        auto capture_value = [](ValueRef value) {
            // TODO: fastpath copy shouldn't be an OpDef
            return imperative::apply(ApplyOp(*FastpathCopy::make()), {&value, 1})[0];
        };
        for (auto& input : inputs) {
            if (auto grad_value = as_grad_value(input)) {
                captured_inputs.push_back(capture_value(grad_value->m_value));
                inputs_require_grad.push_back(true);
            } else {
                captured_inputs.push_back(capture_value(input));
                inputs_require_grad.push_back(false);
            }
        }
        decltype(std::declval<GradFn>().m_backward) backward_storage;
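        // Pick a backward strategy: prefer a registered CustomBackward rule
        // for this op type; otherwise fall back to the generated backward
        // graph from make_optimized_backward_graph. If neither applies, the op
        // has no backward and the outputs are returned untouched.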
        auto outputs = [&] {
            auto backward_rule =
                    CustomBackward::lookup_grad_rule(op_val->op().dyn_typeinfo());
            if (backward_rule) {
                CustomBackward backward;
                auto optional_outputs = backward_rule(
                        op_val->op(), {captured_inputs.data(), captured_inputs.size()},
                        {inputs_require_grad.data(), inputs_require_grad.size()},
                        backward);
                if (optional_outputs) {
                    backward_storage = backward;
                    // backward by rule
                    return *optional_outputs;
                }
            }
            auto outputs = imperative::apply(
                    op, {captured_inputs.begin(), captured_inputs.end()});
            auto backward_graph = make_optimized_backward_graph(
                    op.cast<ApplyOp>().op().shared_from_this(),
                    {captured_inputs.begin(), captured_inputs.end()},
                    {outputs.data(), outputs.size()},
                    {inputs_require_grad.data(), inputs_require_grad.size()});
            if (backward_graph) {
                backward_storage = BackwardGraphWithClosure(
                        backward_graph, op.cast<ApplyOp>().op().shared_from_this(),
                        {captured_inputs.begin(), captured_inputs.end()},
                        {outputs.data(), outputs.size()});
                // backward by make_backward_graph
                return outputs;
            } else {
                // no backward
                return outputs;
            }
        }();
        if (std::holds_alternative<std::monostate>(backward_storage)) {
            return outputs;
        }
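        // A backward was produced: record a GradFn on the tape with one slot
        // per output and one dest per input, and wrap outputs that may require
        // grad into GradValues bound to this GradFn.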
        auto grad_fn = std::make_shared<GradFn>();
        grad_fn->m_key = m_key;
        grad_fn->m_slots.resize(outputs.size());
        grad_fn->m_backward = backward_storage;
        mgb_assert(!outputs.empty());
        grad_fn->m_dests.reserve(inputs.size());
        // clang-format off
        std::visit([&](auto& backward) {
            using T = std::decay_t<decltype(backward)>;
            if constexpr (std::is_same_v<T, std::monostate>) {
                mgb_throw(AssertionError, "invalid backward");
            } else {
                for (size_t i = 0; i < inputs.size(); ++i) {
                    if (backward.input_has_grad(i) && require_grads[i]) {
                        auto& input_grad_slot =
                                inputs[i].cast<GradValue>().slot_for(m_key);
                        grad_fn->m_dests.emplace_back(input_grad_slot);
                        grad_fn->m_dests.back().m_producer_record.insert_after(
                                input_grad_slot->m_producer_head);
                    } else {
                        grad_fn->m_dests.emplace_back();
                    }
                }
                for (size_t i = 0; i < outputs.size(); ++i) {
                    if (backward.output_requires_grad(i)) {
                        auto grad_value = GradValue::make(outputs[i], m_key, GradSlotPtr{grad_fn, i});
                        outputs[i] = record_grad(grad_value);
                    }
                }
            }
        }, grad_fn->m_backward);
        // clang-format on
        mgb_assert(!grad_fn->m_slots.empty());
        m_key->m_tape.push_back({grad_fn, op_val->op().shared_from_this()});
        return outputs;
    } else if (auto* attach_grad = op.as<AttachGrad>()) {
        if (!has_key(attach_grad->key())) {
            return imperative::apply(op, unwrap_inputs(inputs));
        }
        auto tensor = inputs[0];
        GenericFunction callback = (GenericFunction&)inputs[1].cast<FunctionValue>();
        auto output = attach_grad->key()->attach(tensor, [callback](ValueRef grad) {
            auto ret = callback({&grad, 1});
            assert(ret.empty());
        });
        return {record_grad(output)};
    } else if (auto* grad_backward = op.as<GradBackward>()) {
        if (!has_key(grad_backward->key())) {
            return imperative::apply(op, unwrap_inputs(inputs));
        }
        size_t nr_grads = inputs.size() / 2;
        mgb_assert(nr_grads * 2 == inputs.size());
        auto values = inputs.sub(0, nr_grads);
        auto grads = inputs.sub(nr_grads, nr_grads);
        make_backward_closure(values)(grads);
        return {};
    } else if (auto* is_attached_to = op.as<IsAttachedTo>()) {
        if (has_key(is_attached_to->key())) {
            if (auto grad_value = as_grad_value(inputs[0])) {
                // TODO: assert grad_fn
                return {BoolValue::make(true)};
            }
        }
        return {BoolValue::make(false)};
    } else if (auto* set_grad = op.as<SetGrad>()) {
        // TODO: merge SetGrad and ApplyOp
        auto grad_fn = std::make_shared<GradFn>();
        auto& backward =
                std::get<CustomBackward>(grad_fn->m_backward = CustomBackward());
        size_t nr_inputs = set_grad->nr_inputs();
        mgb_assert(inputs.size() > nr_inputs);
        size_t nr_outputs = inputs.size() - nr_inputs;
        Span<ValueRef> inputs_ = {inputs.data(), nr_inputs};
        Span<ValueRef> outputs_ = {inputs.data() + nr_inputs, nr_outputs};
        backward.m_input_has_grad = SmallVector(nr_inputs, true);
        backward.m_output_attrs =
                SmallVector(nr_outputs, CustomBackward::OutputAttr{true, true});
        backward.m_backward = set_grad->grad_fn();
        std::vector<ValueRef> outputs;
        grad_fn->m_key = m_key;
        grad_fn->m_slots.resize(nr_outputs);
        grad_fn->m_dests.reserve(nr_inputs);
        for (size_t i = 0; i < nr_inputs; ++i) {
            if (auto grad_value = as_grad_value(inputs_[i])) {
                auto& input_grad_slot = grad_value->m_slot;
                grad_fn->m_dests.emplace_back(grad_value->m_slot);
                grad_fn->m_dests.back().m_producer_record.insert_after(
                        input_grad_slot->m_producer_head);
            } else {
                grad_fn->m_dests.emplace_back();
            }
        }
        for (size_t i = 0; i < nr_outputs; ++i) {
            auto& output = outputs_[i];
            auto grad_value = as_grad_value(output);
            if (grad_value) {
                grad_value = GradValue::make(
                        grad_value->m_value, m_key, GradSlotPtr(grad_fn, i));
            } else {
                grad_value = GradValue::make(output, m_key, GradSlotPtr(grad_fn, i));
            }
            outputs.push_back(record_grad(grad_value));
        }
        m_key->m_tape.push_back({grad_fn, nullptr});
        return outputs;
    } else if (auto* gbc = op.as<GetBackwardColsure>()) {
        if (gbc->key() != m_key) {
            return imperative::apply(op, unwrap_inputs(inputs));
        }
        return {FunctionValue::make(make_backward_closure(inputs))};
    } else if (op.is<DetachGrad>()) {
        if (auto grad_value = as_grad_value(inputs[0])) {
            return {grad_value->m_value};
        } else {
            return {inputs[0]};
        }
    } else if (op.is<GetGradKey>()) {
        for (auto&& input : inputs) {
            if (auto grad_value = as_grad_value(input)) {
                return {GradKeyValue::make(grad_value->m_key)};
            }
        }
        return imperative::apply(op, inputs);
    } else if (op.kind() == Operator::IdentityLike) {
        mgb_assert(inputs.size() == 1);
        if (auto grad_value = as_grad_value(inputs[0])) {
            auto output = imperative::apply(op, grad_value->m_value)[0];
            auto grad_output = GradValue::make(
                    output, grad_value->key(), grad_value->slot_for(m_key));
            return {record_grad(grad_output)};
        } else {
            return imperative::apply(op, inputs);
        }
    } else if (op.is<CreateTensor>()) {
        return imperative::apply(op, inputs);
    } else {
        SmallVector<ValueRef> unwrapped_inputs;
        for (auto&& input : inputs) {
            if (auto grad_value = as_grad_value(input)) {
                unwrapped_inputs.push_back(grad_value->m_value);
            } else {
                unwrapped_inputs.push_back(input);
            }
        }
        auto outputs = imperative::apply(
                op, {unwrapped_inputs.data(), unwrapped_inputs.size()});
        mgb_assert(op.kind() == Operator::GetAttrLike || outputs.empty());
        return outputs;
    }
}

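// Build the closure handed back to the user: it captures the grad slots of the
// requested ys, stores the incoming dys into those slots and runs
// GradKey::backward(). Creating the closure also freezes the key and cleans up
// this transformation's bookkeeping.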
GenericFunction GradTransformation::make_backward_closure(Span<ValueRef> ys) {
    // reset GradKey
    auto grad_key = m_key;
    std::vector<GradSlotPtr> y_slots;
    for (auto&& y : ys) {
        if (auto grad_value = as_grad_value(y)) {
            y_slots.push_back(grad_value->slot_for(grad_key));
        } else {
            y_slots.emplace_back();
        }
    }
    GenericFunction closure = [grad_key,
                               y_slots](Span<ValueRef> dys) -> std::vector<ValueRef> {
        size_t nr_grads = y_slots.size();
        mgb_assert(dys.size() == nr_grads);
        for (size_t i = 0; i < nr_grads; ++i) {
            if (y_slots[i]) {
                y_slots[i]->m_grad = dys[i];
            }
        }
        grad_key->backward();
        return {};
    };
    grad_key->freeze();
    cleanup();
    return closure;
}

void GradTransformation::on_unregister() noexcept {
    cleanup();
}

void GradTransformation::cleanup() {
    for (auto&& weak_value : m_weak_values) {
        auto grad_value = weak_value.lock();
        if (grad_value) {
            mgb_assert(grad_value->m_key == m_key);
            grad_value.reset(grad_value->m_value);
        }
    }
    m_weak_values.clear();
    m_key = {};
}

void GradTransformation::suppress() {
    m_suppressed++;
}

void GradTransformation::resume() {
    m_suppressed--;
}

} // namespace imperative
} // namespace mgb