@@ -14,7 +14,10 @@
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/utils/mempool.h"

#include "range/v3/all.hpp"

namespace py = pybind11;
namespace views = ranges::views;

namespace mgb::imperative::python {

@@ -25,6 +28,152 @@ struct GradSlotWeakPtr {
    size_t idx;
};

struct BackwardGraphCache : std::unordered_map<size_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject {
    std::shared_ptr<void> on_comp_node_finalize() override {
        clear();
        return {};
    }
} backward_graph_cache;

std::shared_ptr<BackwardGraphResult> make_backward_graph(
        ApplyContext& ctx, const apply_result_t& outputs) {
    // hash
    static_assert(alignof(size_t) % alignof(bool) == 0);
    size_t buf_size = (1 + ctx.nargs * 2) * sizeof(size_t) + ctx.nargs * sizeof(bool);
    alignas(alignof(size_t)) std::byte buf[buf_size];
    size_t* size_t_ptr = reinterpret_cast<size_t*>(buf);
    bool* bool_ptr = reinterpret_cast<bool*>(size_t_ptr + (1 + ctx.nargs * 2));
    bool* bool_ptr0 = bool_ptr;
    *(size_t_ptr++) = ctx.op->hash();
    for (size_t i = 0; i < ctx.nargs; ++i) {
        *(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle());
        *(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node());
        *(bool_ptr++) = bool(ctx.args[i]->m_grad_info.grad_fn);
    }
    mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
               bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
    size_t key = XXHash{}.update(buf, buf_size).digest();

    auto&& iter = backward_graph_cache.find(key);
    if (iter != backward_graph_cache.end()) {
        return iter->second;
    }

    // slow path
    SmallVector<LogicalTensorDesc> inputs(ctx.nargs);
    SmallVector<bool> input_requires_grad(ctx.nargs, false);
    SmallVector<bool> output_has_grad(outputs.size(), true);
    for (size_t i = 0; i < ctx.nargs; ++i) {
        inputs[i].comp_node = ctx.args[i]->comp_node();
        inputs[i].layout.dtype = ctx.args[i]->dtype();
        input_requires_grad[i] = bool(ctx.args[i]->m_grad_info.grad_fn);
    }
    auto result = std::make_shared<BackwardGraphResult>(
            proxy_graph_detail::make_backward_graph(
                    *ctx.op, inputs, input_requires_grad, output_has_grad));
    if (!result->backward) {
        result.reset();
    }
    backward_graph_cache.emplace(key, result);
    return result;
}
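// Cache key layout sketched for illustration (assuming nargs == 2):
//   size_t section: [ op->hash(),
//                     hash(dtype(arg0)), hash(comp_node(arg0)),
//                     hash(dtype(arg1)), hash(comp_node(arg1)) ]
//   bool section:   [ requires_grad(arg0), requires_grad(arg1) ]
// i.e. (1 + 2 * nargs) size_t values followed by nargs bools; the asserts in
// make_backward_graph above check exactly this packing before the buffer is
// fed to XXHash to produce the cache key.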
struct BackwardGraphWithClosure {
    std::shared_ptr<BackwardGraphResult> backward_graph;
    SmallVector<std::shared_ptr<Tensor>> closure;
    size_t output_mask_offset;
    size_t grad_mask_offset;

    BackwardGraphWithClosure(std::shared_ptr<BackwardGraphResult> backward_graph_,
                             ApplyContext& ctx, const apply_result_t& outputs)
            : backward_graph(backward_graph_),
              output_mask_offset(ctx.nargs),
              grad_mask_offset(ctx.nargs + outputs.size()) {
        // save_for_backward[0:nargs]:
        //     whether input is kept for backward
        //
        // save_for_backward[nargs:nargs+outputs.size()]:
        //     whether output is kept for backward
        //
        // save_for_backward[-outputs.size():]:
        //     whether gradient of output can propagate to any input
        //
        // Example:
        //     perform c = a * b, with a.requires_grad == True and
        //     b.requires_grad == False, save_for_backward = [0, 1, 0, 1]
        auto& save_for_backward = backward_graph->save_for_backward;
        mgb_assert(save_for_backward.size() == ctx.nargs + 2 * outputs.size());
        closure.reserve(std::count_if(save_for_backward.begin(),
                                      save_for_backward.end(),
                                      ranges::identity{}));
        for (size_t i = 0; i < ctx.nargs; ++i) {
            if (save_for_backward[i]) {
                closure.push_back(ctx.args[i]->shared_from_this());
            }
        }
        for (size_t i = 0; i < outputs.size(); ++i) {
            if (save_for_backward[ctx.nargs + i]) {
                closure.push_back(outputs[i]);
            }
        }
    }

    template <typename T, typename R>
    void operator()(T&& grads, R&& receiver) {
        Tensor* args[closure.size() + grads.size()];
        size_t nargs = 0;
        for (auto&& t : closure) {
            args[nargs++] = t.get();
        }
        bool null_grad = false;
        for (size_t i = 0; i < grads.size(); ++i) {
            if (backward_graph->save_for_backward[grad_mask_offset + i]) {
                if (grads[i]) {
                    if (null_grad) {
                        PyErr_SetString(PyExc_NotImplementedError, "report to devs");
                        throw py::error_already_set();
                    }
                    args[nargs++] = grads[i];
                } else {
                    null_grad = true;
                }
            }
        }
        if (null_grad) return;

        ApplyContext ctx;
        ctx.op = backward_graph->backward;
        ctx.flags = is_tracing ? Tensor::Flags::TRACE : 0;
        ctx.nargs = nargs;
        ctx.args = args;
        for (size_t i = 0; i < nargs; ++i) {
            ctx.flags |= args[i]->m_flags;
            mgb_assert(args[i]);
        }

        auto igrads = apply(ctx);
        auto&& it = igrads.begin();
        for (auto [i, p] : views::enumerate(backward_graph->input_has_grad)) {
            if (p) {
                receiver(i, std::move(*it));
                ++it;
            }
        }
    }

    bool input_has_grad(size_t i) {
        return backward_graph->input_has_grad[i];
    }

    bool output_requires_grad(size_t i) {
        return backward_graph->save_for_backward[grad_mask_offset + i];
    }

    bool output_captured(size_t i) {
        return backward_graph->save_for_backward[output_mask_offset + i];
    }
};
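// A worked illustration of operator() above, reusing the example from the
// constructor comment (c = a * b, only a requires grad, so save_for_backward
// = [0, 1, 0, 1]): closure holds { b }, no output is captured, and the output
// grad dc is appended, so the backward op sees args = { b, dc }; since only
// input 0 has a grad, the single computed gradient is routed via
// receiver(0, ...).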
} // namespace
struct GradProducerRecord : intrusive_list::Node<GradProducerRecord> {

@@ -54,10 +203,15 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
    static MemPool<GradFn> pool;

    std::weak_ptr<GradKey> key;
    // slots for receiving and accumulating grads
    // same length as outputs (of forward op)
    SmallVector<GradSlot> slots;
    // where to send and accumulate grads
    // same length as inputs (of forward op)
    SmallVector<GradSlotProducerPtr> dsts;
    SmallVector<std::shared_ptr<Tensor>> closure;
    std::shared_ptr<BackwardGraphResult> backward_graph;
    // encapsulates the actual function used to compute the gradient
    std::variant<std::monostate, BackwardGraphWithClosure> backward;
    // a flag used during backward
    bool in_ref_keeper = false;

    static void deleter(GradFn* ptr) {

@@ -72,8 +226,7 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
        key.reset();
        slots.clear();
        dsts.clear();
        closure.clear();
        backward_graph.reset();
        backward.emplace<std::monostate>();
    }
};
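// Sizes above, illustrated on the same c = a * b example: slots would hold
// one entry (the grad received for output c) while dsts would hold two (one
// per forward input a and b, where accumulated grads are sent); `backward`
// holds the rule that turns the former into the latter.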
@@ -83,54 +236,36 @@ GradSlot* GradSlotPtr::operator->() {

 namespace {

-struct BackwardGraphCache : std::unordered_map<size_t, std::shared_ptr<BackwardGraphResult>>, CompNodeDepedentObject {
-    std::shared_ptr<void> on_comp_node_finalize() override {
-        clear();
-        return {};
-    }
-} backward_graph_cache;
-
-std::shared_ptr<BackwardGraphResult> make_backward_graph(
-        ApplyContext& ctx, const apply_result_t& outputs) {
-    // hash
-    static_assert(alignof(size_t) % alignof(bool) == 0);
-    size_t buf_size = (1 + ctx.nargs * 2) * sizeof(size_t) + ctx.nargs * sizeof(bool);
-    alignas(alignof(size_t)) std::byte buf[buf_size];
-    size_t* size_t_ptr = reinterpret_cast<size_t*>(buf);
-    bool* bool_ptr = reinterpret_cast<bool*>(size_t_ptr + (1 + ctx.nargs * 2));
-    bool* bool_ptr0 = bool_ptr;
-    *(size_t_ptr++) = ctx.op->hash();
-    for (size_t i = 0; i < ctx.nargs; ++i) {
-        *(size_t_ptr++) = mgb::hash(ctx.args[i]->dtype().handle());
-        *(size_t_ptr++) = mgb::hash(ctx.args[i]->comp_node());
-        *(bool_ptr++) = bool(ctx.args[i]->m_grad_info.grad_fn);
-    }
-    mgb_assert(bool_ptr0 == reinterpret_cast<bool*>(size_t_ptr) &&
-               bool_ptr == reinterpret_cast<bool*>(buf + buf_size));
-    size_t key = XXHash{}.update(buf, buf_size).digest();
-
-    auto&& iter = backward_graph_cache.find(key);
-    if (iter != backward_graph_cache.end()) {
-        return iter->second;
-    }
-
-    // slow path
-    SmallVector<LogicalTensorDesc> inputs(ctx.nargs);
-    SmallVector<bool> input_requires_grad(ctx.nargs, false);
-    SmallVector<bool> output_has_grad(outputs.size(), true);
-    for (size_t i = 0; i < ctx.nargs; ++i) {
-        inputs[i].comp_node = ctx.args[i]->comp_node();
-        inputs[i].layout.dtype = ctx.args[i]->dtype();
-        input_requires_grad[i] = bool(ctx.args[i]->m_grad_info.grad_fn);
-    }
-    auto result = std::make_shared<BackwardGraphResult>(
-            proxy_graph_detail::make_backward_graph(
-                    *ctx.op, inputs, input_requires_grad, output_has_grad));
-    if (!result->backward) {
-        result.reset();
-    }
-    backward_graph_cache.emplace(key, result);
-    return result;
-}
+class GradFnHelper {
+    std::shared_ptr<GradFn> grad_fn;
+
+    GradFn* get() {
+        if (!grad_fn) {
+            grad_fn = std::make_shared<GradFn>();
+        }
+        return grad_fn.get();
+    }
+
+    friend apply_result_t imperative::python::apply_grad(ApplyContext&);
+
+public:
+    template<typename T, typename... Args>
+    auto& emplace(Args&&... args) {
+        return get()->backward.emplace<T>(std::forward<Args>(args)...);
+    }
+};
+
+apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
+    auto outputs = apply(ctx);
+
+    auto backward_graph = make_backward_graph(ctx, outputs);
+    if (!backward_graph) {
+        return outputs;
+    }
+
+    ret_grad_fn.emplace<BackwardGraphWithClosure>(std::move(backward_graph), ctx, outputs);
+
+    return outputs;
+}

 } // namespace
@@ -164,76 +299,53 @@ apply_result_t apply_grad(ApplyContext& ctx) {

     ctx.flags &= ~Tensor::Flags::GRAD;

-    // perform forward apply_op or trace
-    auto outputs = apply(ctx);
-
     if (!grad_key) {
-        return outputs;
+        return apply(ctx);
     }

-    auto backward_graph = make_backward_graph(ctx, outputs);
-    if (!backward_graph) {
+    GradFnHelper grad_fn_holder;
+    auto outputs = backward_graph_grad_rule(ctx, grad_fn_holder);
+
+    auto& grad_fn = grad_fn_holder.grad_fn;
+    if (!grad_fn) {
         return outputs;
     }

-    auto grad_fn = std::make_shared<GradFn>();
     grad_fn->key = grad_key;
     grad_fn->slots.resize(outputs.size());
-    grad_fn->backward_graph = std::move(backward_graph);
-
     grad_fn->dsts.reserve(ctx.nargs);
-    for (size_t i = 0; i < ctx.nargs; ++i) {
-        if (grad_fn->backward_graph->input_has_grad[i]) {
-            auto& input_grad_info = ctx.args[i]->m_grad_info;
-            grad_fn->dsts.emplace_back(input_grad_info);
-            grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
-        } else {
-            grad_fn->dsts.emplace_back();
-        }
-    }
-
-    auto& save_for_backward = grad_fn->backward_graph->save_for_backward;
-    grad_fn->closure.reserve(std::count_if(save_for_backward.begin(), save_for_backward.end(), [](bool p){return p;}));
-
-    // given op, taking gradient of output_tensor_list wrt input_tensor_list:
-    //
-    // save_for_backward[0:nargs-1]: whether input tensor requires gradient,
-    // i.e., whether it is in input_tensor_list
-    //
-    // save_for_backward[nargs:nargs+outputs.size()-1]: whether output tensor is
-    // needed to calculate gradients
-    //
-    // save_for_backward[-outputs.size():]: whether output tensor is in
-    // output_tensor_list
-    //
-    // Example: perform c = a * b, where a is input data, b is parameter to be
-    // optimized, save_for_backward = [1, 1, 0, 1]
-    mgb_assert(ctx.nargs + 2 * outputs.size() == save_for_backward.size());
-
-    // record input tensors needed to take grad
-    for (size_t i = 0; i < ctx.nargs; ++i) {
-        if (save_for_backward[i]) {
-            grad_fn->closure.push_back(ctx.args[i]->shared_from_this());
-        }
-    }
-    // record output tensors needed to take grad
-    for (size_t i = 0; i < outputs.size(); ++i) {
-        bool requires_grad = save_for_backward[ctx.nargs + outputs.size() + i];
-        if (save_for_backward[ctx.nargs + i]) {
-            grad_fn->closure.push_back(outputs[i]);
-            if (requires_grad) {
-                // avoid reference cycle [Tensor <-> GradFn]
-                outputs[i] = outputs[i]->copy();
-            }
-        }
-        if (requires_grad) {
-            auto& grad_info = outputs[i]->m_grad_info;
-            grad_info.grad_fn = grad_fn;
-            grad_info.idx = i;
-            grad_info.insert_after(grad_key->free_vars_head);
-            outputs[i]->m_flags |= Tensor::Flags::GRAD;
-        }
-    }
+
+    std::visit([&](auto& backward) {
+        using T = std::decay_t<decltype(backward)>;
+        if constexpr (std::is_same_v<T, std::monostate>) {
+            mgb_assert(0);
+        } else {
+            for (size_t i = 0; i < ctx.nargs; ++i) {
+                if (backward.input_has_grad(i)) {
+                    auto& input_grad_info = ctx.args[i]->m_grad_info;
+                    grad_fn->dsts.emplace_back(input_grad_info);
+                    // register as grad producer
+                    grad_fn->dsts.back().producer_record.insert_after(input_grad_info->producer_head);
+                } else {
+                    grad_fn->dsts.emplace_back();
+                }
+            }
+            for (size_t i = 0; i < outputs.size(); ++i) {
+                if (backward.output_requires_grad(i)) {
+                    if (backward.output_captured(i)) {
+                        // avoid reference cycle [Tensor <-> GradFn]
+                        outputs[i] = outputs[i]->copy();
+                    }
+                    // populate grad info of output tensor
+                    auto& grad_info = outputs[i]->m_grad_info;
+                    grad_info.grad_fn = grad_fn;
+                    grad_info.idx = i;
+                    grad_info.insert_after(grad_key->free_vars_head);
+                    outputs[i]->m_flags |= Tensor::Flags::GRAD;
+                }
+            }
+        }
+    }, grad_fn->backward);

     // record forward history
     grad_key->tape.emplace_back(grad_fn);
@@ -334,54 +446,30 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr

     for (std::ptrdiff_t k = tape.size() - 1; k >= 0; --k) {
         auto&& grad_fn = tape[k].lock();
         if (!grad_fn) continue;
-        if (grad_fn->backward_graph) {
-            for (size_t i = 0; i < grad_fn->slots.size(); ++i) {
-                // grad_fn->dsts correspond to input tensors during forward
-                // calculation, grad_fn->slots correspond to output tensors.
-                // condition true means the output tensor has gradient for
-                // back-propagation
-                if (grad_fn->backward_graph->save_for_backward[grad_fn->dsts.size() + grad_fn->slots.size() + i]) {
-                    grad_fn->closure.push_back(std::move(grad_fn->slots[i].grad));
-                }
-            }
-            ApplyContext ctx;
-            ctx.op = grad_fn->backward_graph->backward;
-            ctx.flags = 0;
-            ctx.nargs = grad_fn->closure.size();
-            Tensor* args[ctx.nargs];
-            for (size_t i = 0; i < ctx.nargs; ++i) {
-                args[i] = grad_fn->closure[i].get();
-                mgb_assert(args[i]);
-                ctx.flags |= args[i]->m_flags;
-            }
-            ctx.args = args;
-
-            if (is_tracing)
-                ctx.flags |= Tensor::Flags::TRACE;
-
-            auto grads = apply(ctx);
-
-            size_t j = 0;
-            for (size_t i = 0; i < grad_fn->dsts.size(); ++i) {
-                if (grad_fn->backward_graph->input_has_grad[i]) {
-                    auto& dst = grad_fn->dsts[i];
-                    // grads[j] is consumed in accum_grad
-                    accum_grad(dst->grad, std::move(grads[j]));
-                    ++j;
-                }
-            }
-            mgb_assert(j == grads.size());
-        }
+
+        auto grad_receiver = [&](size_t i, auto&& g) {
+            accum_grad(grad_fn->dsts[i]->grad, std::forward<decltype(g)>(g));
+        };
+        std::visit([&](auto&& backward) {
+            using T = std::decay_t<decltype(backward)>;
+            if constexpr (std::is_same_v<T, std::monostate>) {
+                mgb_assert(0);
+            } else {
+                auto&& grads = views::transform(grad_fn->slots, [](auto&& slot) {return slot.grad.get();});
+                backward(std::forward<decltype(grads)>(grads), grad_receiver);
+            }
+        }, grad_fn->backward);

         for (auto&& dst : grad_fn->dsts) {
             if (!dst.grad_fn) continue;
             if (!dst.grad_fn->in_ref_keeper) {
                 // after grad_fn is cleared, refcnt of subsequent grad_fn
                 // could drop to 0
                 dst.grad_fn->in_ref_keeper = true;
                 ref_keeper.push_back(dst.grad_fn);
             }
             // grad_fn->clear will unlink current dst.producer_record
             // such that if dst.producer_record.next is false, dst accumulates
             // all the gradients
             if (!dst.producer_record.next && dst->callback && dst->grad) {
                 // I'm the last grad producer, invoke callback
                 dst->callback(TensorWrapper::make(pytype, dst->grad));
             }
         }