
fix(autograd): make higher order grad experimental

GitOrigin-RevId: 81e1eb0ebf
Branch: release-1.5
Author: Megvii Engine Team (huangxinda)
Commit: 8480302da8
8 changed files with 68 additions and 4 deletions
  1. imperative/python/megengine/experimental/__init__.py (+0, -0)
  2. imperative/python/megengine/experimental/autograd.py (+25, -0)
  3. imperative/python/src/grad.cpp (+13, -4)
  4. imperative/python/src/grad.h (+1, -0)
  5. imperative/python/src/tensor.cpp (+3, -0)
  6. imperative/python/test/conftest.py (+21, -0)
  7. imperative/python/test/unit/autodiff/test_grad_manger.py (+4, -0)
  8. imperative/python/test/unit/core/test_autodiff.py (+1, -0)

imperative/python/megengine/experimental/__init__.py (+0, -0; new empty file)


imperative/python/megengine/experimental/autograd.py (+25, -0)

@@ -0,0 +1,25 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+
+from ..core._imperative_rt.core2 import (
+    set_allow_higher_order_directive as _set_allow_higher_order_directive,
+)
+
+__all__ = [
+    "enable_higher_order_directive",
+    "disable_higher_order_directive",
+]
+
+
+def enable_higher_order_directive():
+    _set_allow_higher_order_directive(True)
+
+
+def disable_higher_order_directive():
+    _set_allow_higher_order_directive(False)

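This new module is the user-facing switch introduced by the commit. A minimal usage sketch, modeled on test_2nd_grad_with_manager from the test changes further down; tensor values and shapes are illustrative only:

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    from megengine.autodiff import GradManager
    from megengine.experimental.autograd import (
        disable_higher_order_directive,
        enable_higher_order_directive,
    )

    enable_higher_order_directive()  # opt in before nesting recordings

    x = mge.tensor(np.random.rand(10).astype("float32"))
    gm = GradManager().attach([x])   # outer recording, for the second-order grad
    gm2 = GradManager().attach([x])  # inner recording, for the first-order grad

    with gm:
        with gm2:
            y = F.cos(x)
            gm2.backward(y)  # first order: x.grad is now -sin(x)
        gm.backward(x.grad)  # second order: accumulates -cos(x) into x.grad

    disable_higher_order_directive()  # restore the default guard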
imperative/python/src/grad.cpp (+13, -4)

@@ -271,7 +271,7 @@ struct GradFn : std::enable_shared_from_this<GradFn> {
         pool.free(ptr);
     }
 
-    std::shared_ptr<GradFn> make() {
+    static std::shared_ptr<GradFn> make() {
         return std::shared_ptr<GradFn>(pool.alloc(), &deleter);
     }
 
@@ -316,14 +316,18 @@ public:
 
 apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_grad_fn) {
     // copy inputs first, or trace will make InputNodes for each usage
-    ApplyContext ctx_dup = ctx;
     SmallVector<std::shared_ptr<Tensor>> inputs_copy;
     SmallVector<Tensor*> inputs_copy_weak;
     for (size_t i = 0; i < ctx.nargs; ++i) {
-        inputs_copy.push_back(python::apply(FastpathCopy::make(), ctx.args[i]->shared_from_this())[0]);
+        Tensor* input = ctx.args[i];
+        inputs_copy.push_back(python::apply(FastpathCopy::make(), input)[0]);
         inputs_copy_weak.push_back(inputs_copy.back().get());
         inputs_copy.back()->m_grad_info_dict = ctx.args[i]->m_grad_info_dict;
+        if (input->m_flags & Flags::GRAD) {
+            inputs_copy.back()->m_flags |= Flags::GRAD;
+        }
     }
+    ApplyContext ctx_dup = ctx;
     ctx_dup.args = inputs_copy_weak.data();
 
     auto outputs = apply(ctx_dup);
@@ -332,7 +336,6 @@ apply_result_t backward_graph_grad_rule(ApplyContext& ctx, GradFnHelper& ret_gra
     if (!backward_graph) {
         return outputs;
     }
-
     ret_grad_fn.emplace<BackwardGraphWithClosure>(std::move(backward_graph), ctx_dup, outputs);
 
     return outputs;
@@ -389,6 +392,12 @@ apply_result_t apply_grad(ApplyContext& ctx) {
 
     if (grad_keys.empty()) {
        return apply(ctx);
+    } else if (grad_keys.size() > 1 && !GradKey::allow_higher_order_directive) {
+        PyErr_SetString(
+                PyExc_NotImplementedError,
+                "second order directive not enabled, please call "
+                "'megengine.experimental.enable_higher_order_directive'");
+        throw pyext17::py_err_set();
     }
 
     GradFnHelper grad_fn_holder;


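With the directive left disabled (the new default), recording an op under more than one live grad key, e.g. two GradManagers attached to the same tensor, now fails fast instead of silently building a higher-order graph. A sketch of the guarded path, assuming the same setup as the tests below:

    import numpy as np
    import megengine as mge
    import megengine.functional as F
    from megengine.autodiff import GradManager

    x = mge.tensor(np.random.rand(10).astype("float32"))
    gm = GradManager().attach([x])
    gm2 = GradManager().attach([x])

    try:
        with gm:
            with gm2:
                y = F.cos(x)  # two grad keys are live here, so the guard fires
    except NotImplementedError as exc:
        print(exc)  # "second order directive not enabled, please call ..."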
imperative/python/src/grad.h (+1, -0)

@@ -36,6 +36,7 @@ struct GradKey : std::enable_shared_from_this<GradKey>, NonCopyableObj {
     bool is_blocked() const {
         return priority < sm_min_priority;
     }
+    inline static bool allow_higher_order_directive = false;
 private:
     static int sm_min_priority;
 };


imperative/python/src/tensor.cpp (+3, -0)

@@ -990,6 +990,9 @@ void init_tensor(py::module m) {
 
     m.def("set_tracing", &set_tracing);
     m.def("unset_tracing", &unset_tracing);
+    m.def("set_allow_higher_order_directive", [](bool value){
+        GradKey::allow_higher_order_directive = value;
+    });
 }
 
 #undef MGE_PY_INTERFACE

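This binding is what the Python wrapper in megengine/experimental/autograd.py calls under the hood; the experimental module is the intended entry point, but toggling the raw binding is equivalent:

    from megengine.core._imperative_rt.core2 import set_allow_higher_order_directive

    set_allow_higher_order_directive(True)   # what enable_higher_order_directive() does
    set_allow_higher_order_directive(False)  # what disable_higher_order_directive() does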

imperative/python/test/conftest.py (+21, -0)

@@ -1,3 +1,10 @@
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import os
 import platform
 import sys
@@ -9,6 +16,10 @@ import megengine.module
 from megengine import Parameter
 from megengine.core._imperative_rt.core2 import sync
 from megengine.distributed.helper import get_device_count_by_fork
+from megengine.experimental.autograd import (
+    disable_higher_order_directive,
+    enable_higher_order_directive,
+)
 from megengine.jit import trace as _trace
 from megengine.module import Linear, Module
 
@@ -34,3 +45,13 @@ def skip_distributed(request):
                 platform.system()
             )
         )
+
+
+@pytest.fixture(autouse=True)
+def resolve_require_higher_order_directive(request):
+    marker = request.node.get_closest_marker("require_higher_order_directive")
+    if marker:
+        enable_higher_order_directive()
+    yield
+    if marker:
+        disable_higher_order_directive()

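Marked tests get the directive enabled only for their own duration. A sketch of how a hypothetical new test would opt in (test name and body are illustrative):

    import pytest

    @pytest.mark.require_higher_order_directive()
    def test_needs_nested_recording():
        # the autouse fixture sees this marker, calls
        # enable_higher_order_directive() before this body runs and
        # disable_higher_order_directive() once it finishes
        ...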
imperative/python/test/unit/autodiff/test_grad_manger.py (+4, -0)

@@ -281,6 +281,7 @@ def test_broadcast_grad(trace_mode):
     worker()
 
 
+@pytest.mark.require_higher_order_directive()
 def test_2nd_grad_with_manager():
     x_np = np.random.rand(10).astype("float32")
     x = mge.tensor(x_np)
@@ -299,6 +300,7 @@ def test_2nd_grad_with_manager():
     )
 
 
+@pytest.mark.require_higher_order_directive()
 def test_grad_manager_group():
     x_np = np.random.rand(10).astype("float32")
     x = mge.tensor(x_np)
@@ -315,6 +317,7 @@ def test_grad_manager_group():
     x.grad = None
 
 
+@pytest.mark.require_higher_order_directive()
 def test_grad_manager_group_visibility():
     x_np = np.random.rand(10).astype("float32")
     x = mge.tensor(x_np)
@@ -330,6 +333,7 @@ def test_grad_manager_group_visibility():
     np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
 
 
+@pytest.mark.require_higher_order_directive()
 def test_grad_manager_visibility_by_order():
     x_np = np.random.rand(10).astype("float32")
     x = mge.tensor(x_np)


imperative/python/test/unit/core/test_autodiff.py (+1, -0)

@@ -108,6 +108,7 @@ def test_grad_2():
     np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)
 
 
+@pytest.mark.require_higher_order_directive()
 def test_2nd_grad():
     x_np = np.random.rand(10).astype("float32")
     x = as_tensor(x_np)

