diff --git a/imperative/python/megengine/experimental/__init__.py b/imperative/python/megengine/experimental/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/imperative/python/megengine/experimental/autograd.py b/imperative/python/megengine/experimental/autograd.py
new file mode 100644
index 00000000..8c8b5d25
--- /dev/null
+++ b/imperative/python/megengine/experimental/autograd.py
@@ -0,0 +1,25 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+from ..core._imperative_rt.core2 import (
+    set_allow_higher_order_directive as _set_allow_higher_order_directive,
+)
+
+__all__ = [
+    "enable_higher_order_directive",
+    "disable_higher_order_directive",
+]
+
+
+def enable_higher_order_directive():
+    _set_allow_higher_order_directive(True)
+
+
+def disable_higher_order_directive():
+    _set_allow_higher_order_directive(False)
diff --git a/imperative/python/src/grad.cpp b/imperative/python/src/grad.cpp
index e47f733d..898919cc 100644
--- a/imperative/python/src/grad.cpp
+++ b/imperative/python/src/grad.cpp
@@ -271,7 +271,7 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
         pool.free(ptr);
     }
 
-    std::shared_ptr<GradFn> make() {
+    static std::shared_ptr<GradFn> make() {
         return std::shared_ptr<GradFn>(pool.alloc(), &deleter);
     }
 
@@ -316,14 +316,18 @@ public:
 
 apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
     // copy inputs first, or trace will make InputNodes for each usage
+    ApplyContext ctx_dup = ctx;
     SmallVector<std::shared_ptr<Tensor>> inputs_copy;
     SmallVector<Tensor*> inputs_copy_weak;
     for (size_t i = 0; i < ctx.nargs; ++i) {
-        inputs_copy.push_back(python::apply(FastpathCopy::make(), ctx.args[i]->shared_from_this())[0]);
+        Tensor* input = ctx.args[i];
+        inputs_copy.push_back(python::apply(FastpathCopy::make(), input)[0]);
         inputs_copy_weak.push_back(inputs_copy.back().get());
         inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict;
+        if (input->m_flags & Flags::GRAD) {
+            inputs_copy.back()->m_flags |= Flags::GRAD;
+        }
     }
-    ApplyContext ctx_dup = ctx;
     ctx_dup.args = inputs_copy_weak.data();
 
     auto outputs = apply(ctx_dup);
@@ -332,7 +336,6 @@ apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_gra
     if (!backward_graph) {
         return outputs;
     }
-
     ret_grad_fn.emplace(std::move(backward_graph), ctx_dup, outputs);
 
     return outputs;
@@ -389,6 +392,12 @@ apply_result_t apply_grad(ApplyContext& ctx) {
 
     if (grad_keys.empty()) {
         return apply(ctx);
+    } else if (grad_keys.size() > 1 && !GradKey::allow_higher_order_directive) {
+        PyErr_SetString(
+                PyExc_NotImplementedError,
+                "second order directive not enabled, please call "
+                "'megengine.experimental.enable_higher_order_directive'");
+        throw pyext17::py_err_set();
     }
 
     GradFnHelper grad_fn_holder;
diff --git a/imperative/python/src/grad.h b/imperative/python/src/grad.h
index 17f28dbb..c5068e20 100644
--- a/imperative/python/src/grad.h
+++ b/imperative/python/src/grad.h
@@ -36,6 +36,7 @@ struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj {
     bool is_blocked() const {
        return priority < sm_min_priority;
     }
+    inline static bool allow_higher_order_directive = false;
 private:
     static int sm_min_priority;
 };
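Note: the grad.cpp and grad.h changes above gate nested (second-order) differentiation behind the new GradKey::allow_higher_order_directive flag, which the experimental Python module toggles. A minimal sketch of the intended call pattern on the user side, using only the two functions this patch adds (the body inside try is illustrative, not part of the patch):

    from megengine.experimental.autograd import (
        disable_higher_order_directive,
        enable_higher_order_directive,
    )

    enable_higher_order_directive()
    try:
        pass  # code that records a gradient of a gradient goes here
    finally:
        # restore the default so unrelated code keeps the guarded behavior
        disable_higher_order_directive()

With the flag left at its default (false), applying an op whose inputs carry more than one active GradKey raises the NotImplementedError added in apply_grad above.
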
diff --git a/imperative/python/src/tensor.cpp b/imperative/python/src/tensor.cpp
index bec4cb77..c4db334e 100644
--- a/imperative/python/src/tensor.cpp
+++ b/imperative/python/src/tensor.cpp
@@ -990,6 +990,9 @@ void init_tensor(py::module m) {
 
     m.def("set_tracing", &set_tracing);
     m.def("unset_tracing", &unset_tracing);
+    m.def("set_allow_higher_order_directive", [](bool value){
+        GradKey::allow_higher_order_directive = value;
+    });
 }
 
 #undef MGE_PY_INTERFACE
diff --git a/imperative/python/test/conftest.py b/imperative/python/test/conftest.py
index f012c400..ed598ced 100644
--- a/imperative/python/test/conftest.py
+++ b/imperative/python/test/conftest.py
@@ -1,3 +1,10 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import os
 import platform
 import sys
@@ -9,6 +16,10 @@ import megengine.module
 from megengine import Parameter
 from megengine.core._imperative_rt.core2 import sync
 from megengine.distributed.helper import get_device_count_by_fork
+from megengine.experimental.autograd import (
+    disable_higher_order_directive,
+    enable_higher_order_directive,
+)
 from megengine.jit import trace as _trace
 from megengine.module import Linear, Module
 
@@ -34,3 +45,13 @@ def skip_distributed(request):
                 platform.system()
             )
         )
+
+
+@pytest.fixture(autouse=True)
+def resolve_require_higher_order_directive(request):
+    marker = request.node.get_closest_marker("require_higher_order_directive")
+    if marker:
+        enable_higher_order_directive()
+    yield
+    if marker:
+        disable_higher_order_directive()
diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py
index 2f82f6c0..ddb6ad5b 100644
--- a/imperative/python/test/unit/autodiff/test_grad_manger.py
+++ b/imperative/python/test/unit/autodiff/test_grad_manger.py
@@ -281,6 +281,7 @@ def test_broadcast_grad(trace_mode):
     worker()
 
 
+@pytest.mark.require_higher_order_directive()
 def test_2nd_grad_with_manager():
     x_np = np.random.rand(10).astype("float32")
     x = mge.tensor(x_np)
@@ -299,6 +300,7 @@
     )
 
 
+@pytest.mark.require_higher_order_directive()
 def test_grad_manager_group():
     x_np = np.random.rand(10).astype("float32")
     x = mge.tensor(x_np)
@@ -315,6 +317,7 @@
     x.grad = None
 
 
+@pytest.mark.require_higher_order_directive()
 def test_grad_manager_group_visibility():
     x_np = np.random.rand(10).astype("float32")
     x = mge.tensor(x_np)
@@ -330,6 +333,7 @@
     np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
 
 
+@pytest.mark.require_higher_order_directive()
 def test_grad_manager_visibility_by_order():
     x_np = np.random.rand(10).astype("float32")
     x = mge.tensor(x_np)
diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py
index 9765ba88..fa90ab40 100644
--- a/imperative/python/test/unit/core/test_autodiff.py
+++ b/imperative/python/test/unit/core/test_autodiff.py
@@ -108,6 +108,7 @@ def test_grad_2():
     np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)
 
 
+@pytest.mark.require_higher_order_directive()
 def test_2nd_grad():
     x_np = np.random.rand(10).astype("float32")
     x = as_tensor(x_np)
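
For context, tests opt in through the require_higher_order_directive marker, which the autouse conftest fixture resolves: it enables the directive before the test body runs and disables it afterwards, so unmarked tests still exercise the default guarded path. A new test would follow the same pattern as the ones marked above; this sketch mirrors test_2nd_grad_with_manager, with a hypothetical test name, and is illustrative rather than part of the patch:

    import numpy as np
    import pytest

    import megengine as mge
    import megengine.functional as F
    from megengine.autodiff import GradManager


    @pytest.mark.require_higher_order_directive()
    def test_second_order_sketch():  # hypothetical name
        x_np = np.random.rand(10).astype("float32")
        x = mge.tensor(x_np)

        gm = GradManager().attach([x])
        gm2 = GradManager().attach([x])

        with gm:
            with gm2:
                y = F.cos(x)
                gm2.backward(y)  # x.grad is now d(cos x)/dx = -sin(x)
            gm.backward(x.grad)  # differentiate -sin(x); grads accumulate
        np.testing.assert_almost_equal(
            x.grad.numpy(), -np.sin(x_np) - np.cos(x_np), decimal=5
        )

Without the marker, the nested GradManager setup would hit the new guard in apply_grad and raise NotImplementedError pointing at megengine.experimental.enable_higher_order_directive.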