|
|
@@ -271,7 +271,7 @@ struct GradFn : std::enable_shared_from_this<GradFn> { |
|
|
|
pool.free(ptr); |
|
|
|
} |
|
|
|
|
|
|
|
std::shared_ptr<GradFn> make() { |
|
|
|
static std::shared_ptr<GradFn> make() { |
|
|
|
return std::shared_ptr<GradFn>(pool.alloc(), &deleter); |
|
|
|
} |
|
|
|
|
|
|
@@ -316,14 +316,18 @@ public: |
|
|
|
|
|
|
|
apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) { |
|
|
|
// copy inputs first, or trace will make InputNodes for each usage |
|
|
|
ApplyContext ctx_dup = ctx; |
|
|
|
SmallVector<std::shared_ptr<Tensor>> inputs_copy; |
|
|
|
SmallVector<Tensor*> inputs_copy_weak; |
|
|
|
for (size_t i = 0; i < ctx.nargs; ++i) { |
|
|
|
inputs_copy.push_back(python::apply(FastpathCopy::make(), ctx.args[i]->shared_from_this())[0]); |
|
|
|
Tensor* input = ctx.args[i]; |
|
|
|
inputs_copy.push_back(python::apply(FastpathCopy::make(), input)[0]); |
|
|
|
inputs_copy_weak.push_back(inputs_copy.back().get()); |
|
|
|
inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict; |
|
|
|
if (input->m_flags & Flags::GRAD) { |
|
|
|
inputs_copy.back()->m_flags |= Flags::GRAD; |
|
|
|
} |
|
|
|
} |
|
|
|
ApplyContext ctx_dup = ctx; |
|
|
|
ctx_dup.args = inputs_copy_weak.data(); |
|
|
|
|
|
|
|
auto outputs = apply(ctx_dup); |
|
|
@@ -332,7 +336,6 @@ apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_gra |
|
|
|
if (!backward_graph) { |
|
|
|
return outputs; |
|
|
|
} |
|
|
|
|
|
|
|
ret_grad_fn.emplace<BackwardGraphWithClosure>(std::move(backward_graph), ctx_dup, outputs); |
|
|
|
|
|
|
|
return outputs; |
|
|
@@ -389,6 +392,12 @@ apply_result_t apply_grad(ApplyContext& ctx) { |
|
|
|
|
|
|
|
if (grad_keys.empty()) { |
|
|
|
return apply(ctx); |
|
|
|
} else if (grad_keys.size() > 1 && !GradKey::allow_higher_order_directive) { |
|
|
|
PyErr_SetString( |
|
|
|
PyExc_NotImplementedError, |
|
|
|
"second order directive not enabled, please call " |
|
|
|
"'megengine.experimental.enable_higher_order_directive'"); |
|
|
|
throw pyext17::py_err_set(); |
|
|
|
} |
|
|
|
|
|
|
|
GradFnHelper grad_fn_holder; |
|
|
|