From 02455941566e66c224bdded38cdffc3aa69de715 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 7 Jul 2021 14:16:09 +0800 Subject: [PATCH] test(autograd): test jvp emulated by 2nd grad GitOrigin-RevId: 47114fcd99f53b335c1615896370dd6e45dd3693 --- imperative/python/src/grad_override.cpp | 4 +++- .../python/test/unit/autodiff/test_grad_manger.py | 26 ++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/imperative/python/src/grad_override.cpp b/imperative/python/src/grad_override.cpp index 5a54623e..b49b58cb 100644 --- a/imperative/python/src/grad_override.cpp +++ b/imperative/python/src/grad_override.cpp @@ -163,7 +163,9 @@ apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward: apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) { auto& op = ctx.op->cast_final_safe(); if (op.mode == Reduce::Mode::SUM) { - mgb_assert(ctx.nargs == 1); + if (ctx.nargs != 1) { + throw GradRuleFallback(); + } std::array, 1> input_shapes; if (input_requires_grad(ctx, 0)) { input_shapes[0] = get_shape(ctx.args[0]); diff --git a/imperative/python/test/unit/autodiff/test_grad_manger.py b/imperative/python/test/unit/autodiff/test_grad_manger.py index ddb6ad5b..112c37ba 100644 --- a/imperative/python/test/unit/autodiff/test_grad_manger.py +++ b/imperative/python/test/unit/autodiff/test_grad_manger.py @@ -349,3 +349,29 @@ def test_grad_manager_visibility_by_order(): gm.backward(x.grad) np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) + + +@pytest.mark.require_higher_order_directive() +@pytest.mark.parametrize("target", [F.cos, F.sin, lambda x: x * 2 + 1]) +def test_emulate_forward_mode_with_reverse_mode(target): + def jvp(inp, expr): + with GradManager() as gm: + with GradManager().attach([inp]) as gm2: + oup = expr(inp) + oup_grad = F.zeros_like(oup) + gm.attach(oup_grad) + gm2.backward(oup, oup_grad) + gm.backward(inp.grad) + return oup, oup_grad.grad + + def fake_jvp(inp, expr): + delta = 0.001 + return expr(inp), (expr(inp + delta) - expr(inp - delta)) / (2 * delta) + + x_np = np.random.rand(10).astype("float32") + x = mge.tensor(x_np) + y, dy = jvp(x, target) + y1, dy1 = fake_jvp(x, target) + + np.testing.assert_almost_equal(y.numpy(), y1.numpy(), decimal=5) + np.testing.assert_almost_equal(dy.numpy(), dy1.numpy(), decimal=3)