@@ -67,23 +67,24 @@ def test_sgd_momentum_trace():
     for symbolic in (True, False):
 
         @trace(symbolic=symbolic)
-        def train_func(data, *, model=None, optim=None):
-            optim.zero_grad()
-            with optim.record():
+        def train_func(data, *, model=None, optim=None, gm=None):
+            optim.clear_grad()
+            with gm.record():
                 loss = net(data)
-                optim.backward(loss)
+                gm.backward(loss)
             optim.step()
             return loss
 
         @trace(symbolic=symbolic)
-        def eval_func(data, *, model=None, optim=None):
+        def eval_func(data, *, model=None, optim=None, gm=None):
             loss = net(data)
             return loss
 
         net = Simple()
         optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
+        gm = ad.GradManager().register(net.parameters())
         data = tensor([2.34])
-        train_func(data, model=net, optim=optim)
+        train_func(data, model=net, optim=optim, gm=gm)
         np.testing.assert_almost_equal(
             optim._state[net.a]["momentum_buffer"].numpy(), 2.34
         )
@@ -97,7 +98,7 @@ def test_sgd_momentum_trace():
         )
 
         # do a step of train
-        train_func(data, model=net, optim=optim)
+        train_func(data, model=net, optim=optim, gm=gm)
         np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
         np.testing.assert_almost_equal(
             optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34
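For reference, here is a minimal, self-contained sketch of the training step on the "+" side of these hunks: a GradManager records the forward pass and runs backward, while the optimizer keeps clear_grad() and step(). The imports and the Simple module body below are assumptions reconstructed from the test's constants (parameter 1.23, input 2.34); only the gm/optim calls are taken verbatim from the patch.

# Minimal sketch, not part of the patch. Class body and imports are assumed;
# the gm/optim calls mirror the "+" side of the diff above.
import numpy as np

import megengine.autodiff as ad
import megengine.optimizer as optimizer
from megengine import Parameter, tensor
from megengine.module import Module


class Simple(Module):
    def __init__(self):
        super().__init__()
        self.a = Parameter([1.23], dtype=np.float32)  # single learnable weight

    def forward(self, x):
        return x * self.a


net = Simple()
optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
gm = ad.GradManager().register(net.parameters())  # register() as used in the patch

data = tensor([2.34])
optim.clear_grad()      # replaces optim.zero_grad()
with gm.record():       # replaces optim.record(): GradManager owns autodiff now
    loss = net(data)
    gm.backward(loss)   # replaces optim.backward(loss)
optim.step()            # the optimizer still applies the parameter update

print(loss.numpy(), net.a.numpy())  # loss before the step; parameter after it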