Browse Source

fix(mge): fix segfault with Function returning unused grads

GitOrigin-RevId: 0cce845923
release-1.2
Megvii Engine Team 4 years ago
parent
commit
3faba54f28
2 changed files with 9 additions and 1 deletions
  1. +8
    -1
      imperative/python/src/grad.cpp
  2. +1
    -0
      imperative/python/src/grad_info.h

+ 8
- 1
imperative/python/src/grad.cpp View File

@@ -282,6 +282,10 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
}
};

GradSlotPtr::operator bool() const {
return bool(grad_fn);
}

GradSlot* GradSlotPtr::operator->() {
return &grad_fn->slots[idx];
}
@@ -537,7 +541,10 @@ void GradKey::backward(std::vector<TensorWrapper*> tensors, std::vector<TensorWr
if (!grad_fn) continue;

auto grad_receiver = [&](size_t i, auto&& g) {
accum_grad(grad_fn->dsts[i]->grad, std::forward<decltype(g)>(g));
auto& dst = grad_fn->dsts[i];
if (dst) {
accum_grad(dst->grad, std::forward<decltype(g)>(g));
}
};
std::visit([&](auto&& backward) {
using T = std::decay_t<decltype(backward)>;


+ 1
- 0
imperative/python/src/grad_info.h View File

@@ -22,6 +22,7 @@ struct GradSlotPtr {
std::shared_ptr<GradFn> grad_fn;
size_t idx;

operator bool() const;
GradSlot* operator->();
};



Loading…
Cancel
Save