Browse Source

test(mge/function): fix test for new optimizer api

GitOrigin-RevId: 8ae7720fe6
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
69a7c55f55
1 changed files with 16 additions and 12 deletions
  1. +16
    -12
      imperative/python/test/unit/test_function.py

+ 16
- 12
imperative/python/test/unit/test_function.py View File

@@ -163,10 +163,11 @@ def test_skip_invalid_grad():

net = Simple(av, bv)
optim = optimizer.SGD(net.parameters(), lr=1.0)
optim.zero_grad()
with optim.record():
gm = ad.GradManager().register(net.parameters())
optim.clear_grad()
with gm.record():
loss = net().sum()
optim.backward(loss)
gm.backward(loss)
optim.step()
np.testing.assert_almost_equal(net.a.numpy(), av - c)
np.testing.assert_almost_equal(net.b.numpy(), bv - c)
@@ -197,11 +198,12 @@ def test_ste():
av = np.random.random(data_shape).astype(np.float32)
net = Simple(av)
optim = optimizer.SGD(net.parameters(), lr=1.0)
optim.zero_grad()
gm = ad.GradManager().register(net.parameters())
optim.clear_grad()

with optim.record():
with gm.record():
loss = net()
optim.backward(loss.sum())
gm.backward(loss.sum())
optim.step()

np.testing.assert_almost_equal(
@@ -254,10 +256,11 @@ def test_none_in_out_grad():
b = tensor(np.array([2.0], dtype=np.float32))
net = Simple(a, b)
optim = optimizer.SGD(net.parameters(), lr=1.0)
optim.zero_grad()
with optim.record():
gm = ad.GradManager().register(net.parameters())
optim.clear_grad()
with gm.record():
loss, _ = net()
optim.backward(loss)
gm.backward(loss)
optim.step()

np.testing.assert_almost_equal(
@@ -290,11 +293,12 @@ def test_zero_grad():
a = tensor(np.array([1.0], dtype=np.float32))
net = Simple(a)
optim = optimizer.SGD(net.parameters(), lr=1.0)
optim.zero_grad()
gm = ad.GradManager().register(net.parameters())
optim.clear_grad()

with optim.record():
with gm.record():
loss = net()
optim.backward(loss.sum())
gm.backward(loss.sum())
optim.step()
np.testing.assert_almost_equal(
net.a.numpy(), np.array([1.0 - 4.0], dtype=np.float32),


Loading…
Cancel
Save