
test(autograd): test jvp emulated by 2nd grad

GitOrigin-RevId: 47114fcd99
release-1.5
Megvii Engine Team (huangxinda), 3 years ago
commit 0245594156
2 changed files with 29 additions and 1 deletion:
  1. imperative/python/src/grad_override.cpp (+3, -1)
  2. imperative/python/test/unit/autodiff/test_grad_manger.py (+26, -0)

imperative/python/src/grad_override.cpp (+3, -1)

@@ -163,7 +163,9 @@ apply_result_t indexingMultiAxisVec_grad_rule(ApplyContext& ctx, CustomBackward:
 apply_result_t reduce_grad_rule(ApplyContext& ctx, CustomBackward::Maker& maker) {
     auto& op = ctx.op->cast_final_safe<Reduce>();
     if (op.mode == Reduce::Mode::SUM) {
-        mgb_assert(ctx.nargs == 1);
+        if (ctx.nargs != 1) {
+            throw GradRuleFallback();
+        }
         std::array<std::shared_ptr<Tensor>, 1> input_shapes;
         if (input_requires_grad(ctx, 0)) {
             input_shapes[0] = get_shape(ctx.args[0]);


imperative/python/test/unit/autodiff/test_grad_manger.py (+26, -0)

@@ -349,3 +349,29 @@ def test_grad_manager_visibility_by_order():
     gm.backward(x.grad)
 
     np.testing.assert_almost_equal(x.grad.numpy(), -np.sin(x_np), decimal=5)
+
+
+@pytest.mark.require_higher_order_directive()
+@pytest.mark.parametrize("target", [F.cos, F.sin, lambda x: x * 2 + 1])
+def test_emulate_forward_mode_with_reverse_mode(target):
+    def jvp(inp, expr):
+        with GradManager() as gm:
+            with GradManager().attach([inp]) as gm2:
+                oup = expr(inp)
+                oup_grad = F.zeros_like(oup)
+                gm.attach(oup_grad)
+                gm2.backward(oup, oup_grad)
+            gm.backward(inp.grad)
+        return oup, oup_grad.grad
+
+    def fake_jvp(inp, expr):
+        delta = 0.001
+        return expr(inp), (expr(inp + delta) - expr(inp - delta)) / (2 * delta)
+
+    x_np = np.random.rand(10).astype("float32")
+    x = mge.tensor(x_np)
+    y, dy = jvp(x, target)
+    y1, dy1 = fake_jvp(x, target)
+
+    np.testing.assert_almost_equal(y.numpy(), y1.numpy(), decimal=5)
+    np.testing.assert_almost_equal(dy.numpy(), dy1.numpy(), decimal=3)
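The jvp helper emulates a forward-mode Jacobian-vector product with two reverse-mode passes: the inner gm2.backward(oup, oup_grad) produces inp.grad = J^T * oup_grad as a function of the dummy cotangent oup_grad, and the outer gm.backward(inp.grad) (dy defaulting to ones) differentiates that result with respect to oup_grad, leaving J * 1 in oup_grad.grad. For the elementwise targets in the parametrization this is the pointwise derivative that fake_jvp approximates by central differences. Below is a minimal plain-NumPy sketch of the same identity, not MegEngine code; the function f, its hand-written jacobian, and the step sizes are made up for illustration.

# Illustration only: for f with Jacobian J at x, one reverse pass computes
# vjp(w) = J^T w. That map is linear in the cotangent w, so differentiating
# the scalar v . vjp(w) with respect to w gives J v, i.e. the forward-mode JVP.
# The outer gm.backward in the test plays the role of this second derivative,
# with v implicitly all ones.
import numpy as np

x = np.array([0.3, -1.2, 0.7])
v = np.ones(3)  # tangent direction; corresponds to the default dy of ones

def f(x):  # small nonlinear map R^3 -> R^2 with a hand-written Jacobian
    return np.array([np.sin(x[0]) * x[1], x[1] * x[2]])

def jacobian(x):
    return np.array([
        [np.cos(x[0]) * x[1], np.sin(x[0]), 0.0],
        [0.0,                 x[2],         x[1]],
    ])

J = jacobian(x)
vjp = lambda w: J.T @ w  # what the inner gm2.backward(oup, oup_grad) yields

# "Second backward": gradient of the scalar v . vjp(w) with respect to w,
# estimated here by central differences over w. It equals J @ v, the JVP.
eps = 1e-6
second_backward = np.array([
    (v @ vjp(np.eye(2)[i] * eps) - v @ vjp(-np.eye(2)[i] * eps)) / (2 * eps)
    for i in range(2)
])
np.testing.assert_allclose(second_backward, J @ v, atol=1e-6)

# J @ v also matches a forward finite-difference JVP, like fake_jvp in the test.
delta = 1e-4
np.testing.assert_allclose((f(x + delta * v) - f(x - delta * v)) / (2 * delta),
                           J @ v, atol=1e-5)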
