From 0a3ca253373186368d46676e357d1198e7f3726a Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Thu, 14 Jan 2021 17:15:37 +0800 Subject: [PATCH] fix(mge): fix backward graph optimization GitOrigin-RevId: 28fd00ac548e7cd663a87ce74070862e9a24b55f --- imperative/python/test/unit/core/test_autodiff.py | 16 +++++++++++++++- imperative/src/impl/backward_graph_opt.cpp | 2 +- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/imperative/python/test/unit/core/test_autodiff.py b/imperative/python/test/unit/core/test_autodiff.py index 889aee7e..701aca37 100644 --- a/imperative/python/test/unit/core/test_autodiff.py +++ b/imperative/python/test/unit/core/test_autodiff.py @@ -19,7 +19,7 @@ import megengine.functional as F from megengine.core._imperative_rt import CompNode, TensorAttr, imperative from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync from megengine.core.autodiff.grad import Grad -from megengine.core.ops.builtin import Elemwise +from megengine.core.ops.builtin import Elemwise, Identity from megengine.distributed.helper import get_device_count_by_fork from megengine.functional.distributed import remote_recv, remote_send @@ -193,6 +193,20 @@ def test_grad_inplace(): np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6) +def test_identity(): + x_np = np.random.rand(10).astype("float32") + x = mge.Tensor(x_np) + dy_np = np.random.rand(*x.shape).astype("float32") + dy = mge.Tensor(dy_np) + + grad = Grad().wrt(x, callback=save_to(x)) + + (y,) = apply(Identity(), x) + + grad(y, dy) + np.testing.assert_array_equal(x.grad.numpy(), dy_np) + + def test_elemwise_add(): x_np = np.random.rand(10).astype("float32") y_np = np.random.rand(10, 10).astype("float32") diff --git a/imperative/src/impl/backward_graph_opt.cpp b/imperative/src/impl/backward_graph_opt.cpp index 838fd49e..61b3bc54 100644 --- a/imperative/src/impl/backward_graph_opt.cpp +++ b/imperative/src/impl/backward_graph_opt.cpp @@ -58,7 +58,7 @@ OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphRe // should be marked as always appears in backward for (size_t i = 0, j = 0; i < mask.size(); ++i) { if (!mask[i]) continue; - if (i > input_size + output_size) { + if (i >= input_size + output_size) { vinfo[graph.inputs[j]].appears_in_backward = true; } ++j;