
test(mge/function): fix test for new optimizer api

GitOrigin-RevId: 8ae7720fe6
tags/v1.0.0-rc1
Megvii Engine Team
commit 69a7c55f55
1 changed file with 16 additions and 12 deletions

imperative/python/test/unit/test_function.py  +16 -12

@@ -163,10 +163,11 @@ def test_skip_invalid_grad():

     net = Simple(av, bv)
     optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
-    with optim.record():
+    gm = ad.GradManager().register(net.parameters())
+    optim.clear_grad()
+    with gm.record():
         loss = net().sum()
-        optim.backward(loss)
+        gm.backward(loss)
     optim.step()
     np.testing.assert_almost_equal(net.a.numpy(), av - c)
     np.testing.assert_almost_equal(net.b.numpy(), bv - c)
@@ -197,11 +198,12 @@ def test_ste():
     av = np.random.random(data_shape).astype(np.float32)
     net = Simple(av)
     optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
+    gm = ad.GradManager().register(net.parameters())
+    optim.clear_grad()

-    with optim.record():
+    with gm.record():
         loss = net()
-        optim.backward(loss.sum())
+        gm.backward(loss.sum())
     optim.step()

     np.testing.assert_almost_equal(
@@ -254,10 +256,11 @@ def test_none_in_out_grad():
     b = tensor(np.array([2.0], dtype=np.float32))
     net = Simple(a, b)
     optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
-    with optim.record():
+    gm = ad.GradManager().register(net.parameters())
+    optim.clear_grad()
+    with gm.record():
         loss, _ = net()
-        optim.backward(loss)
+        gm.backward(loss)
     optim.step()

     np.testing.assert_almost_equal(
@@ -290,11 +293,12 @@ def test_zero_grad():
     a = tensor(np.array([1.0], dtype=np.float32))
     net = Simple(a)
     optim = optimizer.SGD(net.parameters(), lr=1.0)
-    optim.zero_grad()
+    gm = ad.GradManager().register(net.parameters())
+    optim.clear_grad()

-    with optim.record():
+    with gm.record():
         loss = net()
-        optim.backward(loss.sum())
+        gm.backward(loss.sum())
     optim.step()
     np.testing.assert_almost_equal(
         net.a.numpy(), np.array([1.0 - 4.0], dtype=np.float32),
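
The change is the same in all four tests: gradient recording moves off the optimizer onto a dedicated GradManager from megengine.autodiff, and optim.zero_grad() becomes optim.clear_grad(). Below is a minimal sketch of the migrated training step, assuming the rc-era API names this diff uses (register() and record(); later MegEngine releases replaced these with attach() and a plain "with gm:" block). The Simple module here is a hypothetical stand-in for the one defined in the test file.

import numpy as np
import megengine as mge
import megengine.autodiff as ad
import megengine.module as M
import megengine.optimizer as optimizer

class Simple(M.Module):  # hypothetical stand-in for the test's Simple module
    def __init__(self, a):
        super().__init__()
        self.a = mge.Parameter(a)

    def forward(self):
        return self.a * 2.0

av = np.array([1.5], dtype=np.float32)
net = Simple(av)
optim = optimizer.SGD(net.parameters(), lr=1.0)

gm = ad.GradManager().register(net.parameters())  # recording now owned by GradManager
optim.clear_grad()                                # was optim.zero_grad()
with gm.record():                                 # was optim.record()
    loss = net().sum()
    gm.backward(loss)                             # was optim.backward(loss)
optim.step()                                      # parameter update stays on the optimizer

# SGD with lr=1.0: a <- a - d(loss)/d(a) = 1.5 - 2.0
np.testing.assert_almost_equal(net.a.numpy(), np.array([-0.5], dtype=np.float32))

Note that net.parameters() is now threaded through both objects: the GradManager decides which tensors get gradients recorded, while the optimizer only applies updates, which is the separation of concerns this commit adapts the tests to.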

