Browse Source

perf(autograd): copy inputs before capture in backward_graph_grad_rule

GitOrigin-RevId: 8b9c067b2d
release-1.5
Megvii Engine Team huangxinda 4 years ago
parent
commit
b7e596b4a1
1 changed files with 14 additions and 3 deletions
  1. +14
    -3
      imperative/python/src/grad.cpp

+ 14
- 3
imperative/python/src/grad.cpp View File

@@ -315,14 +315,25 @@ public:
};

// Grad rule that runs the op described by `ctx` and, when a backward graph can
// be built for it, stores a BackwardGraphWithClosure in `ret_grad_fn`.
//
// The op is applied to copies of the inputs rather than to the inputs
// themselves; each copy carries over the original tensor's m_grad_info.
//
// ctx:         apply context; its `args` / `nargs` describe the input tensors.
//              The caller's context is only read here, never modified.
// ret_grad_fn: receives the BackwardGraphWithClosure. Left untouched when
//              make_backward_graph() yields a falsy result.
// returns:     the outputs of applying the op to the copied inputs.
apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
    // copy inputs first, or trace will make InputNodes for each usage
    SmallVector<std::shared_ptr<Tensor>> inputs_copy;  // owns the copies
    SmallVector<Tensor*> inputs_copy_weak;             // non-owning view handed to ctx_dup
    for (size_t i = 0; i < ctx.nargs; ++i) {
        inputs_copy.push_back(python::apply(FastpathCopy::make(), ctx.args[i]->shared_from_this())[0]);
        inputs_copy_weak.push_back(inputs_copy.back().get());
        // keep the copy attached to the same grad info as the original input
        inputs_copy.back()->m_grad_info = ctx.args[i]->m_grad_info;
    }

    // Same context as the caller's, but with args redirected to the copies.
    // data() is taken only after the loop, so no later push_back can
    // invalidate the pointer.
    ApplyContext ctx_dup = ctx;
    ctx_dup.args = inputs_copy_weak.data();

    auto outputs = apply(ctx_dup);

    auto backward_graph = make_backward_graph(ctx_dup, outputs);
    if (!backward_graph) {
        // no backward graph for this op: return the forward result only
        return outputs;
    }

    // NOTE(review): ctx_dup.args points into a local vector; this assumes
    // BackwardGraphWithClosure does not keep that raw pointer past this call
    // -- confirm against its constructor.
    ret_grad_fn.emplace<BackwardGraphWithClosure>(std::move(backward_graph), ctx_dup, outputs);

    return outputs;
}


Loading…
Cancel
Save