
fix(mge): fix backward graph optimization

GitOrigin-RevId: 28fd00ac54
tags/v1.3.0
Megvii Engine Team 4 years ago
parent commit: 0a3ca25337

2 changed files with 16 additions and 2 deletions:
  1. imperative/python/test/unit/core/test_autodiff.py (+15, -1)
  2. imperative/src/impl/backward_graph_opt.cpp (+1, -1)

imperative/python/test/unit/core/test_autodiff.py (+15, -1)

@@ -19,7 +19,7 @@ import megengine.functional as F
 from megengine.core._imperative_rt import CompNode, TensorAttr, imperative
 from megengine.core._imperative_rt.core2 import TensorWeakRef, apply, sync
 from megengine.core.autodiff.grad import Grad
-from megengine.core.ops.builtin import Elemwise
+from megengine.core.ops.builtin import Elemwise, Identity
 from megengine.distributed.helper import get_device_count_by_fork
 from megengine.functional.distributed import remote_recv, remote_send
 


@@ -193,6 +193,20 @@ def test_grad_inplace():
     np.testing.assert_almost_equal(x.grad.numpy(), 4 * x_np ** 3, decimal=6)
 
 
+def test_identity():
+    x_np = np.random.rand(10).astype("float32")
+    x = mge.Tensor(x_np)
+    dy_np = np.random.rand(*x.shape).astype("float32")
+    dy = mge.Tensor(dy_np)
+
+    grad = Grad().wrt(x, callback=save_to(x))
+
+    (y,) = apply(Identity(), x)
+
+    grad(y, dy)
+    np.testing.assert_array_equal(x.grad.numpy(), dy_np)
+
+
 def test_elemwise_add():
     x_np = np.random.rand(10).astype("float32")
     y_np = np.random.rand(10, 10).astype("float32")


imperative/src/impl/backward_graph_opt.cpp (+1, -1)

@@ -58,7 +58,7 @@ OptimizedBackwardGraphResult::OptimizedBackwardGraphResult(const BackwardGraphRe
     // should be marked as always appears in backward
     for (size_t i = 0, j = 0; i < mask.size(); ++i) {
         if (!mask[i]) continue;
-        if (i > input_size + output_size) {
+        if (i >= input_size + output_size) {
             vinfo[graph.inputs[j]].appears_in_backward = true;
         }
         ++j;
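
The one-character change fixes an off-by-one in how the optimizer classifies the backward graph's inputs. Below is a minimal, hypothetical sketch (plain Python, not MegEngine source) of the layout the loop assumes: candidate inputs are ordered as forward inputs, then forward outputs, then output gradients, and filtered by `mask`. Gradient inputs exist only in the backward pass, so every index at or beyond `input_size + output_size` must be flagged `appears_in_backward`. The old `>` skipped the boundary index itself, which is exactly where the sole gradient input of a one-input, one-output op such as Identity lands, hence the new regression test above.

# Sketch under stated assumptions; illustrates only the boundary
# condition, not MegEngine's actual data structures.
input_size, output_size = 1, 1  # e.g. Identity: one forward input, one output

# Candidate backward-graph inputs, by index:
#   0 .. input_size-1                        -> forward inputs
#   input_size .. input_size+output_size-1   -> forward outputs
#   input_size+output_size ..                -> output gradients (backward-only)
for i in range(input_size + output_size + 1):
    old_rule = i > input_size + output_size   # buggy: False at i == 2
    new_rule = i >= input_size + output_size  # fixed: True at i == 2
    print(f"i={i}  old={old_rule}  new={new_rule}")

# Only new_rule marks i == 2 (the Identity grad input) as appearing
# in backward, which is what the C++ fix restores.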

