@@ -315,14 +315,25 @@ public:
 };
 
 apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
-    auto outputs = apply(ctx);
+    // copy inputs first, or trace will make InputNodes for each usage
+    SmallVector<std::shared_ptr<Tensor>> inputs_copy;
+    SmallVector<Tensor*> inputs_copy_weak;
+    for (size_t i = 0; i < ctx.nargs; ++i) {
+        inputs_copy.push_back(python::apply(FastpathCopy::make(), ctx.args[i]->shared_from_this())[0]);
+        inputs_copy_weak.push_back(inputs_copy.back().get());
+        inputs_copy.back()->m_grad_info = ctx.args[i]->m_grad_info;
+    }
+    ApplyContext ctx_dup = ctx;
+    ctx_dup.args = inputs_copy_weak.data();
+
+    auto outputs = apply(ctx_dup);
 
-    auto backward_graph = make_backward_graph(ctx, outputs);
+    auto backward_graph = make_backward_graph(ctx_dup, outputs);
     if (!backward_graph) {
         return outputs;
     }
-    ret_grad_fn.emplace<BackwardGraphWithClosure>(std::move(backward_graph), ctx, outputs);
+    ret_grad_fn.emplace<BackwardGraphWithClosure>(std::move(backward_graph), ctx_dup, outputs);
 
     return outputs;
 }
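
For context on the hunk: the forward apply, make_backward_graph, and the stored BackwardGraphWithClosure now all go through ctx_dup, whose args point at FastpathCopy copies of the original inputs (with m_grad_info carried over), so a trace records one copy per input rather than a fresh InputNode for each usage. Below is a minimal standalone sketch of that ownership pattern; Tensor, ApplyContext, fastpath_copy, and duplicate_onto_copies here are simplified hypothetical stand-ins, not the real MegEngine definitions.

// Minimal sketch, not MegEngine code: Tensor, ApplyContext, and
// fastpath_copy are simplified stand-ins for the types in the hunk above.
#include <cstddef>
#include <memory>
#include <vector>

struct Tensor {
    int m_grad_info = 0;  // stands in for the real grad-tracking state
};

struct ApplyContext {
    std::size_t nargs = 0;
    Tensor** args = nullptr;  // non-owning views of the inputs
};

// Stand-in for python::apply(FastpathCopy::make(), ...): a cheap copy
// that a trace would record once per input.
std::shared_ptr<Tensor> fastpath_copy(Tensor* src) {
    return std::make_shared<Tensor>(*src);
}

// Mirrors the hunk: copy each input once, keep the copies alive in an
// owning vector, and point the duplicated context at raw views of them.
ApplyContext duplicate_onto_copies(const ApplyContext& ctx,
                                   std::vector<std::shared_ptr<Tensor>>& copies,
                                   std::vector<Tensor*>& copies_weak) {
    for (std::size_t i = 0; i < ctx.nargs; ++i) {
        copies.push_back(fastpath_copy(ctx.args[i]));
        copies_weak.push_back(copies.back().get());
        copies.back()->m_grad_info = ctx.args[i]->m_grad_info;  // keep grad tracking
    }
    ApplyContext ctx_dup = ctx;
    ctx_dup.args = copies_weak.data();  // same op, now applied to the copies
    return ctx_dup;
}

int main() {
    Tensor a{1}, b{2};
    Tensor* args[] = {&a, &b};
    ApplyContext ctx{2, args};

    std::vector<std::shared_ptr<Tensor>> copies;
    std::vector<Tensor*> copies_weak;
    ApplyContext ctx_dup = duplicate_onto_copies(ctx, copies, copies_weak);

    // Every downstream consumer (forward apply, backward-graph build,
    // stored closure) can now share ctx_dup and hence the same copies.
    (void)ctx_dup;
    return 0;
}

The two vectors mirror the diff's pair of SmallVectors: ApplyContext::args is a plain pointer array, so the shared_ptr vector keeps the copies alive for the duration of the call while the raw-pointer vector provides the storage that args points at.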