Browse Source

fix(mge): fix dumping backward graph

GitOrigin-RevId: 430f110053
release-1.4
Megvii Engine Team 4 years ago
parent
commit
cb84efa74f
3 changed files with 33 additions and 2 deletions
  1. +1
    -1
      imperative/python/megengine/core/tensor/megbrain_graph.py
  2. +4
    -1
      imperative/python/megengine/jit/tracing.py
  3. +28
    -0
      imperative/python/test/unit/jit/test_tracing.py

+ 1
- 1
imperative/python/megengine/core/tensor/megbrain_graph.py View File

@@ -489,7 +489,7 @@ def apply_backward_varnode(op: BackwardGraph, *args: VarNode):
graph._make_const_for_backward,
args,
)
return _unwrap(outputs)
return outputs


set_cpp_apply_backward_varnode(apply_backward_varnode)


+ 4
- 1
imperative/python/megengine/jit/tracing.py View File

@@ -830,7 +830,10 @@ class trace:
name=info.name,
)
ivars.append(h2v[h])
ovars = G.apply_normal_varnode(op, *ivars)
if isinstance(op, BackwardGraph):
ovars = G.apply_backward_varnode(op, *ivars)
else:
ovars = G.apply_normal_varnode(op, *ivars)

AutoNaming.record_opnode(ovars[0].op)



+ 28
- 0
imperative/python/test/unit/jit/test_tracing.py View File

@@ -247,6 +247,34 @@ def test_dump_volatile():
)


def test_dump_backward_graph():
x0 = tensor(np.random.randn(3, 4))
x1 = tensor(np.random.randn(3, 4))

gm = GradManager().attach(x0)

@trace(symbolic=True, capture_as_const=True)
def f(x0, x1):
with gm:
y = x0 * x1
gm.backward(y, F.ones_like(y))
dx0 = x0.grad
return y, dx0

y, dx0 = f(x0, x1)
np.testing.assert_equal(dx0.numpy(), x1)

file = io.BytesIO()
f.dump(file, optimize_for_inference=False)
file.seek(0)

infer_cg = cgtools.GraphInference(file)
results = list((infer_cg.run(x0, x1)).values())

np.testing.assert_equal(results[0], y)
np.testing.assert_equal(results[1], dx0)


@pytest.mark.parametrize("trace_mode", [False, True])
def test_trace_profiler(trace_mode):
@trace(symbolic=trace_mode, profiling=True)


Loading…
Cancel
Save