|
|
@@ -282,6 +282,10 @@ struct GradFn : std::enable_shared_from_this<GradFn> { |
|
|
|
} |
|
|
|
}; |
|
|
|
|
|
|
|
GradSlotPtr::operator bool() const { |
|
|
|
return bool(grad_fn); |
|
|
|
} |
|
|
|
|
|
|
|
GradSlot* GradSlotPtr::operator->() { |
|
|
|
return &grad_fn->slots[idx]; |
|
|
|
} |
|
|
@@ -537,7 +541,10 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr |
|
|
|
if (!grad_fn) continue; |
|
|
|
|
|
|
|
auto grad_receiver = [&](size_t i, auto&& g) { |
|
|
|
accum_grad(grad_fn->dsts[i]->grad, std::forward<decltype(g)>(g)); |
|
|
|
auto& dst = grad_fn->dsts[i]; |
|
|
|
if (dst) { |
|
|
|
accum_grad(dst->grad, std::forward<decltype(g)>(g)); |
|
|
|
} |
|
|
|
}; |
|
|
|
std::visit([&](auto&& backward) { |
|
|
|
using T = std::decay_t<decltype(backward)>; |
|
|
|