|
|
@@ -349,3 +349,29 @@ def test_grad_manager_visibility_by_order(): |
|
|
|
gm.backward(x.grad) |
|
|
|
|
|
|
|
np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5) |
|
|
|
|
|
|
|
|
|
|
|
@pytest.mark.require_higher_order_directive() |
|
|
|
@pytest.mark.parametrize("target", [F.cos, F.sin, lambda x: x * 2 + 1]) |
|
|
|
def test_emulate_forward_mode_with_reverse_mode(target): |
|
|
|
def jvp(inp, expr): |
|
|
|
with GradManager() as gm: |
|
|
|
with GradManager().attach([inp]) as gm2: |
|
|
|
oup = expr(inp) |
|
|
|
oup_grad = F.zeros_like(oup) |
|
|
|
gm.attach(oup_grad) |
|
|
|
gm2.backward(oup, oup_grad) |
|
|
|
gm.backward(inp.grad) |
|
|
|
return oup, oup_grad.grad |
|
|
|
|
|
|
|
def fake_jvp(inp, expr): |
|
|
|
delta = 0.001 |
|
|
|
return expr(inp), (expr(inp + delta) - expr(inp - delta)) / (2 * delta) |
|
|
|
|
|
|
|
x_np = np.random.rand(10).astype("float32") |
|
|
|
x = mge.tensor(x_np) |
|
|
|
y, dy = jvp(x, target) |
|
|
|
y1, dy1 = fake_jvp(x, target) |
|
|
|
|
|
|
|
np.testing.assert_almost_equal(y.numpy(), y1.numpy(), decimal=5) |
|
|
|
np.testing.assert_almost_equal(dy.numpy(), dy1.numpy(), decimal=3) |