GitOrigin-RevId: 81e1eb0ebf
release-1.5
@@ -0,0 +1,25 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from ..core._imperative_rt.core2 import (
+    set_allow_higher_order_directive as _set_allow_higher_order_directive,
+)
+
+__all__ = [
+    "enable_higher_order_directive",
+    "disable_higher_order_directive",
+]
+
+
+def enable_higher_order_directive():
+    _set_allow_higher_order_directive(True)
+
+
+def disable_higher_order_directive():
+    _set_allow_higher_order_directive(False)
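For reviewers, a minimal usage sketch of the new switch, loosely mirroring `test_2nd_grad_with_manager` further down. The tensor values and the nested `GradManager` setup are illustrative assumptions, not part of this patch:

```python
import numpy as np

import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager
from megengine.experimental.autograd import (
    disable_higher_order_directive,
    enable_higher_order_directive,
)

enable_higher_order_directive()  # allow two grad recordings to nest
try:
    x = mge.tensor(np.random.rand(10).astype("float32"))
    gm, gm2 = GradManager(), GradManager()
    gm.attach([x])
    gm2.attach([x])
    with gm:
        with gm2:
            y = F.cos(x)
            gm2.backward(y)  # first-order: x.grad is roughly -sin(x)
        gm.backward(x.grad)  # second-order: differentiates the recorded x.grad w.r.t. x
finally:
    disable_higher_order_directive()
```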
@@ -271,7 +271,7 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
         pool.free(ptr);
     }
 
-    std::shared_ptr<GradFn> make() {
+    static std::shared_ptr<GradFn> make() {
         return std::shared_ptr<GradFn>(pool.alloc(), &deleter);
     }
@@ -316,14 +316,18 @@ public:
 apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
     // copy inputs first, or trace will make InputNodes for each usage
-    ApplyContext ctx_dup = ctx;
     SmallVector<std::shared_ptr<Tensor>> inputs_copy;
     SmallVector<Tensor*> inputs_copy_weak;
     for (size_t i = 0; i < ctx.nargs; ++i) {
-        inputs_copy.push_back(python::apply(FastpathCopy::make(), ctx.args[i]->shared_from_this())[0]);
+        Tensor* input = ctx.args[i];
+        inputs_copy.push_back(python::apply(FastpathCopy::make(), input)[0]);
         inputs_copy_weak.push_back(inputs_copy.back().get());
         inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict;
+        if (input->m_flags & Flags::GRAD) {
+            inputs_copy.back()->m_flags |= Flags::GRAD;
+        }
     }
+    ApplyContext ctx_dup = ctx;
     ctx_dup.args = inputs_copy_weak.data();
 
     auto outputs = apply(ctx_dup);
@@ -332,7 +336,6 @@ apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_gra
     if (!backward_graph) {
         return outputs;
     }
     ret_grad_fn.emplace<BackwardGraphWithClosure>(std::move(backward_graph), ctx_dup, outputs);
     return outputs;
@@ -389,6 +392,12 @@ apply_result_t apply_grad(ApplyContext& ctx) {
     if (grad_keys.empty()) {
         return apply(ctx);
+    } else if (grad_keys.size() > 1 && !GradKey::allow_higher_order_directive) {
+        PyErr_SetString(
+                PyExc_NotImplementedError,
+                "second order directive not enabled, please call "
+                "'megengine.experimental.enable_higher_order_directive'");
+        throw pyext17::py_err_set();
     }
 
     GradFnHelper grad_fn_holder;
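To illustrate the new guard from the Python side: with the directive left at its default (off), an operator applied while more than one grad key is recording should surface the `NotImplementedError` above instead of silently recording an unsupported higher-order graph. This is a hedged sketch of my reading of the hunk; the exact raise site and the setup are assumptions:

```python
import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager

x = mge.tensor([0.5, 1.5])
gm, gm2 = GradManager(), GradManager()
gm.attach([x])
gm2.attach([x])

with gm:
    with gm2:  # two grad keys now watch x, so grad_keys.size() > 1 at the next apply
        try:
            y = F.cos(x)  # assumed raise site: the first op seen by both recordings
        except NotImplementedError as e:
            print(e)  # "second order directive not enabled, please call ..."
```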
@@ -36,6 +36,7 @@ struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj {
     bool is_blocked() const {
         return priority < sm_min_priority;
     }
+    inline static bool allow_higher_order_directive = false;
 private:
     static int sm_min_priority;
 };
@@ -990,6 +990,9 @@ void init_tensor(py::module m) {
     m.def("set_tracing", &set_tracing);
     m.def("unset_tracing", &unset_tracing);
+    m.def("set_allow_higher_order_directive", [](bool value){
+        GradKey::allow_higher_order_directive = value;
+    });
 }
 
 #undef MGE_PY_INTERFACE
@@ -1,3 +1,10 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import os
 import platform
 import sys
@@ -9,6 +16,10 @@ import megengine.module
 from megengine import Parameter
 from megengine.core._imperative_rt.core2 import sync
 from megengine.distributed.helper import get_device_count_by_fork
+from megengine.experimental.autograd import (
+    disable_higher_order_directive,
+    enable_higher_order_directive,
+)
 from megengine.jit import trace as _trace
 from megengine.module import Linear, Module
@@ -34,3 +45,13 @@ def skip_distributed(request):
                     platform.system()
                 )
             )
+
+
+@pytest.fixture(autouse=True)
+def resolve_require_higher_order_directive(request):
+    marker = request.node.get_closest_marker("require_higher_order_directive")
+    if marker:
+        enable_higher_order_directive()
+    yield
+    if marker:
+        disable_higher_order_directive()
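One assumption-flagged note on the fixture: `require_higher_order_directive` is a custom marker, so pytest will warn about it unless it is registered somewhere, and this diff does not show whether `pytest.ini` or `setup.cfg` already declares it. A possible registration hook, purely as a sketch:

```python
# Hypothetical addition to the same conftest.py; only needed if the marker is
# not already registered in pytest.ini / setup.cfg.
def pytest_configure(config):
    config.addinivalue_line(
        "markers",
        "require_higher_order_directive: enable the higher-order directive for this test",
    )
```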
@@ -281,6 +281,7 @@ def test_broadcast_grad(trace_mode):
     worker()
 
 
+@pytest.mark.require_higher_order_directive()
 def test_2nd_grad_with_manager():
     x_np = np.random.rand(10).astype("float32")
     x = mge.tensor(x_np)
@@ -299,6 +300,7 @@ def test_2nd_grad_with_manager():
     )
 
 
+@pytest.mark.require_higher_order_directive()
 def test_grad_manager_group():
     x_np = np.random.rand(10).astype("float32")
     x = mge.tensor(x_np)
@@ -315,6 +317,7 @@ def test_grad_manager_group():
     x.grad = None
 
 
+@pytest.mark.require_higher_order_directive()
 def test_grad_manager_group_visibility():
     x_np = np.random.rand(10).astype("float32")
     x = mge.tensor(x_np)
@@ -330,6 +333,7 @@ def test_grad_manager_group_visibility():
     np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
 
 
+@pytest.mark.require_higher_order_directive()
 def test_grad_manager_visibility_by_order():
     x_np = np.random.rand(10).astype("float32")
     x = mge.tensor(x_np)
@@ -108,6 +108,7 @@ def test_grad_2():
     np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)
 
 
+@pytest.mark.require_higher_order_directive()
 def test_2nd_grad():
     x_np = np.random.rand(10).astype("float32")
     x = as_tensor(x_np)