|
|
@@ -32,7 +32,7 @@ class MLP(Module): |
|
|
|
class Simple(Module): |
|
|
|
def __init__(self): |
|
|
|
super().__init__() |
|
|
|
self.a = Parameter([1.23], dtype=np.float32) |
|
|
|
self.a = Parameter(1.23, dtype=np.float32) |
|
|
|
|
|
|
|
def forward(self, x): |
|
|
|
x = x * self.a |
|
|
@@ -64,6 +64,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): |
|
|
|
|
|
|
|
ori_params = {} |
|
|
|
for param in net.parameters(): |
|
|
|
assert param._tuple_shape is () |
|
|
|
ori_params[param] = np.copy(param.numpy()) |
|
|
|
opt.step() |
|
|
|
step += 1 |
|
|
@@ -95,6 +96,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False): |
|
|
|
|
|
|
|
ori_params = {} |
|
|
|
for param in net.parameters(): |
|
|
|
assert param._tuple_shape is () |
|
|
|
ori_params[param] = np.copy(param.numpy()) |
|
|
|
|
|
|
|
train_func( |
|
|
@@ -121,7 +123,9 @@ def test_sgd(): |
|
|
|
delta = -self.lr * self.slots[param] |
|
|
|
else: |
|
|
|
delta = -self.lr * grad |
|
|
|
np.testing.assert_almost_equal(param.numpy(), ori_params[param] + delta) |
|
|
|
np.testing.assert_almost_equal( |
|
|
|
param.numpy(), ori_params[param] + delta, decimal=6 |
|
|
|
) |
|
|
|
|
|
|
|
cases = [ |
|
|
|
{"momentum": 0.9, "lr": 0.01}, # SGD with momentum |
|
|
@@ -157,7 +161,7 @@ def test_adam(): |
|
|
|
np.sqrt(v / (1 - self.betas[1] ** step)) + self.eps |
|
|
|
) |
|
|
|
np.testing.assert_almost_equal( |
|
|
|
param.numpy(), ori_params[param] - self.lr * delta |
|
|
|
param.numpy(), ori_params[param] - self.lr * delta, decimal=6 |
|
|
|
) |
|
|
|
|
|
|
|
cases = [ |
|
|
@@ -189,7 +193,9 @@ def test_adagrad(): |
|
|
|
self.s_slots[param] += grad ** 2 |
|
|
|
delta = grad / (self.s_slots[param] + self.eps) ** 0.5 |
|
|
|
delta *= -(self.lr / (1 + (step - 1) * self.lr_decay)) |
|
|
|
np.testing.assert_almost_equal(param.numpy(), ori_params[param] + delta) |
|
|
|
np.testing.assert_almost_equal( |
|
|
|
param.numpy(), ori_params[param] + delta, decimal=6 |
|
|
|
) |
|
|
|
|
|
|
|
cases = [ |
|
|
|
{"lr": 0.01, "eps": 1e-06, "lr_decay": 0.01}, |
|
|
@@ -232,7 +238,9 @@ def test_adadelta(): |
|
|
|
1 - self.rho |
|
|
|
) |
|
|
|
delta *= -self.lr |
|
|
|
np.testing.assert_almost_equal(param.numpy(), ori_params[param] + delta) |
|
|
|
np.testing.assert_almost_equal( |
|
|
|
param.numpy(), ori_params[param] + delta, decimal=6 |
|
|
|
) |
|
|
|
|
|
|
|
cases = [ |
|
|
|
{"lr": 1.0, "eps": 1e-06, "rho": 0.9}, |
|
|
|