|
|
@@ -163,10 +163,11 @@ def test_skip_invalid_grad(): |
|
|
|
|
|
|
|
net = Simple(av, bv) |
|
|
|
optim = optimizer.SGD(net.parameters(), lr=1.0) |
|
|
|
optim.zero_grad() |
|
|
|
with optim.record(): |
|
|
|
gm = ad.GradManager().register(net.parameters()) |
|
|
|
optim.clear_grad() |
|
|
|
with gm.record(): |
|
|
|
loss = net().sum() |
|
|
|
optim.backward(loss) |
|
|
|
gm.backward(loss) |
|
|
|
optim.step() |
|
|
|
np.testing.assert_almost_equal(net.a.numpy(), av - c) |
|
|
|
np.testing.assert_almost_equal(net.b.numpy(), bv - c) |
|
|
@@ -197,11 +198,12 @@ def test_ste(): |
|
|
|
av = np.random.random(data_shape).astype(np.float32) |
|
|
|
net = Simple(av) |
|
|
|
optim = optimizer.SGD(net.parameters(), lr=1.0) |
|
|
|
optim.zero_grad() |
|
|
|
gm = ad.GradManager().register(net.parameters()) |
|
|
|
optim.clear_grad() |
|
|
|
|
|
|
|
with optim.record(): |
|
|
|
with gm.record(): |
|
|
|
loss = net() |
|
|
|
optim.backward(loss.sum()) |
|
|
|
gm.backward(loss.sum()) |
|
|
|
optim.step() |
|
|
|
|
|
|
|
np.testing.assert_almost_equal( |
|
|
@@ -254,10 +256,11 @@ def test_none_in_out_grad(): |
|
|
|
b = tensor(np.array([2.0], dtype=np.float32)) |
|
|
|
net = Simple(a, b) |
|
|
|
optim = optimizer.SGD(net.parameters(), lr=1.0) |
|
|
|
optim.zero_grad() |
|
|
|
with optim.record(): |
|
|
|
gm = ad.GradManager().register(net.parameters()) |
|
|
|
optim.clear_grad() |
|
|
|
with gm.record(): |
|
|
|
loss, _ = net() |
|
|
|
optim.backward(loss) |
|
|
|
gm.backward(loss) |
|
|
|
optim.step() |
|
|
|
|
|
|
|
np.testing.assert_almost_equal( |
|
|
@@ -290,11 +293,12 @@ def test_zero_grad(): |
|
|
|
a = tensor(np.array([1.0], dtype=np.float32)) |
|
|
|
net = Simple(a) |
|
|
|
optim = optimizer.SGD(net.parameters(), lr=1.0) |
|
|
|
optim.zero_grad() |
|
|
|
gm = ad.GradManager().register(net.parameters()) |
|
|
|
optim.clear_grad() |
|
|
|
|
|
|
|
with optim.record(): |
|
|
|
with gm.record(): |
|
|
|
loss = net() |
|
|
|
optim.backward(loss.sum()) |
|
|
|
gm.backward(loss.sum()) |
|
|
|
optim.step() |
|
|
|
np.testing.assert_almost_equal( |
|
|
|
net.a.numpy(), np.array([1.0 - 4.0], dtype=np.float32), |
|
|
|