From 69a7c55f5515ff34b02f2d3b51475c648b8bb78a Mon Sep 17 00:00:00 2001
From: Megvii Engine Team
Date: Thu, 10 Sep 2020 12:25:09 +0800
Subject: [PATCH] test(mge/function): fix test for new optimizer api

GitOrigin-RevId: 8ae7720fe6340d6ff60ce86981111173a8c1e447
---
 imperative/python/test/unit/test_function.py | 28 ++++++++++++++++------------
 1 file changed, 16 insertions(+), 12 deletions(-)

diff --git a/imperative/python/test/unit/test_function.py b/imperative/python/test/unit/test_function.py
index 990ced26..cef30bd9 100644
--- a/imperative/python/test/unit/test_function.py
+++ b/imperative/python/test/unit/test_function.py
@@ -163,10 +163,11 @@ def test_skip_invalid_grad():
 
     net = Simple(av, bv)
     optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
-    with optim.record():
+    gm = ad.GradManager().register(net.parameters())
+    optim.clear_grad()
+    with gm.record():
         loss = net().sum()
-        optim.backward(loss)
+        gm.backward(loss)
     optim.step()
     np.testing.assert_almost_equal(net.a.numpy(), av - c)
     np.testing.assert_almost_equal(net.b.numpy(), bv - c)
@@ -197,11 +198,12 @@ def test_ste():
     av = np.random.random(data_shape).astype(np.float32)
     net = Simple(av)
     optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
+    gm = ad.GradManager().register(net.parameters())
+    optim.clear_grad()
 
-    with optim.record():
+    with gm.record():
         loss = net()
-        optim.backward(loss.sum())
+        gm.backward(loss.sum())
     optim.step()
 
     np.testing.assert_almost_equal(
@@ -254,10 +256,11 @@ def test_none_in_out_grad():
     b = tensor(np.array([2.0], dtype=np.float32))
     net = Simple(a, b)
     optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
-    with optim.record():
+    gm = ad.GradManager().register(net.parameters())
+    optim.clear_grad()
+    with gm.record():
         loss, _ = net()
-        optim.backward(loss)
+        gm.backward(loss)
     optim.step()
 
     np.testing.assert_almost_equal(
@@ -290,11 +293,12 @@ def test_zero_grad():
     a = tensor(np.array([1.0], dtype=np.float32))
     net = Simple(a)
     optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
+    gm = ad.GradManager().register(net.parameters())
+    optim.clear_grad()
 
-    with optim.record():
+    with gm.record():
         loss = net()
-        optim.backward(loss.sum())
+        gm.backward(loss.sum())
     optim.step()
     np.testing.assert_almost_equal(
         net.a.numpy(), np.array([1.0 - 4.0], dtype=np.float32),
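
Note (editor's sketch, not part of the patch): all four tests move from the old
optimizer-driven autodiff calls (optim.zero_grad(), optim.record(),
optim.backward()) to a GradManager plus optim.clear_grad(). A minimal
training-step sketch of the new pattern follows; the import paths and the
Simple module below are assumptions for illustration only, while the
optimizer/GradManager calls mirror exactly what the patch introduces.

    # Sketch only: import paths and the Simple module are assumed, not taken
    # from this patch; only the optimizer/GradManager call sequence is.
    import numpy as np

    import megengine.autodiff as ad           # assumed home of GradManager
    import megengine.optimizer as optimizer
    from megengine import Parameter
    from megengine.module import Module


    class Simple(Module):
        # hypothetical stand-in for the tests' Simple module
        def __init__(self, a):
            super().__init__()
            self.a = Parameter(a, dtype=np.float32)

        def forward(self):
            return self.a * 2.0


    net = Simple(np.random.random((4,)).astype(np.float32))
    optim = optimizer.SGD(net.parameters(), lr=1.0)
    gm = ad.GradManager().register(net.parameters())  # register() as used in this patch

    optim.clear_grad()             # replaces the old optim.zero_grad()
    with gm.record():              # replaces the old optim.record()
        loss = net().sum()
        gm.backward(loss)          # replaces the old optim.backward(loss)
    optim.step()                   # the parameter update itself is unchanged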