|
@@ -11,6 +11,7 @@ import numpy as np |
|
|
import megengine |
|
|
import megengine |
|
|
import megengine.optimizer as optimizer |
|
|
import megengine.optimizer as optimizer |
|
|
from megengine import Parameter, tensor |
|
|
from megengine import Parameter, tensor |
|
|
|
|
|
from megengine.jit import trace |
|
|
from megengine.module import Module |
|
|
from megengine.module import Module |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -57,3 +58,45 @@ def test_sgd_momentum(): |
|
|
np.testing.assert_almost_equal( |
|
|
np.testing.assert_almost_equal( |
|
|
optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34 |
|
|
optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34 |
|
|
) |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_sgd_momentum_trace():
    """Check the SGD momentum buffer when train/eval steps run under ``trace``.

    The scenario is executed once with ``symbolic=True`` and once with
    ``symbolic=False``. For each mode it verifies that:

    * one traced train step leaves ``momentum_buffer`` equal to 2.34,
    * traced inference does not modify the optimizer state, and
    * a second traced train step accumulates momentum to ``0.9 * 2.34 + 2.34``.
    """
    for symbolic in (True, False):

        # NOTE: both traced functions use ``net`` from the enclosing scope;
        # the ``model`` keyword is accepted but not read by either body.
        @trace(symbolic=symbolic)
        def train_func(data, *, model=None, optim=None):
            optim.zero_grad()
            with optim.record():
                loss = net(data)
                optim.backward(loss)
            optim.step()
            return loss

        @trace(symbolic=symbolic)
        def eval_func(data, *, model=None, optim=None):
            loss = net(data)
            return loss

        net = Simple()
        optim = optimizer.SGD(net.parameters(), lr=1.0, momentum=0.9)
        data = tensor([2.34])

        # first train step: the momentum buffer should hold 2.34
        train_func(data, model=net, optim=optim)
        np.testing.assert_almost_equal(
            optim._state[net.a]["momentum_buffer"].numpy(), 2.34
        )

        # do 3 steps of infer; the optimizer state must stay untouched
        for _ in range(3):
            loss = eval_func(data)
            np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
            np.testing.assert_almost_equal(
                optim._state[net.a]["momentum_buffer"].numpy(), 2.34
            )

        # do a step of train
        # Capture the returned loss so the assertion below checks this train
        # step rather than the stale value left over from the last eval call.
        # train_func computes the loss before optim.step(), so the expected
        # value is the same as during inference.
        loss = train_func(data, model=net, optim=optim)
        np.testing.assert_almost_equal(loss.numpy(), 2.34 * (1.23 - 2.34), 5)
        np.testing.assert_almost_equal(
            optim._state[net.a]["momentum_buffer"].numpy(), 0.9 * 2.34 + 2.34
        )