@@ -66,10 +66,17 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
             gm.backward(loss)

         ori_params = {}
+        ori_grads = {}
         for param in net.parameters():
             assert param._tuple_shape is ()
             ori_params[param] = np.copy(param.numpy())
+            ori_grads[param] = np.copy(param.grad.numpy())
         opt.step()
+        # check grad not change
+        for param in net.parameters():
+            assert np.equal(
+                ori_grads[param], param.grad.numpy()
+            ), "step should not change param.grad"
         step += 1
         check_func(ori_params, net.parameters(), step)
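Note: the lines added in this hunk encode the invariant that Optimizer.step() may rewrite parameter values but must leave param.grad untouched. A minimal, framework-free sketch of the same snapshot-and-compare pattern (fake_step and all values below are illustrative stand-ins, not part of the patch):

    import numpy as np

    params = {"w": np.array(0.5, dtype=np.float32)}
    grads = {"w": np.array(0.1, dtype=np.float32)}

    def fake_step(params, grads, lr=0.01):
        # stand-in for opt.step(): may rewrite parameter values,
        # but must not touch the gradient buffers
        for k in params:
            params[k] = params[k] - lr * grads[k]

    ori_grads = {k: np.copy(v) for k, v in grads.items()}
    fake_step(params, grads)
    for k in grads:
        assert np.array_equal(ori_grads[k], grads[k]), "step should not change grads"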
|
|
@@ -135,6 +142,8 @@ def test_sgd(monkeypatch, case, update_lr, inplace_mode):
         def __call__(self, ori_params, new_params, step):
             for param in new_params:
                 grad = param.grad.numpy()
+                if hasattr(self, "weight_decay") and self.weight_decay != 0.0:
+                    grad = grad + ori_params[param] * self.weight_decay
                 if hasattr(self, "momentum"):
                     self.slots[param] = grad + self.slots[param] * self.momentum
                     delta = -self.lr * self.slots[param]
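Note: with the added lines, the SGD reference update checked by CheckValue becomes grad <- grad + wd * param, buf <- grad + momentum * buf, param <- param - lr * buf (L2 weight decay folded into the gradient). A rough NumPy sketch of one such reference step, with illustrative hyperparameter defaults:

    import numpy as np

    def ref_sgd_step(param, grad, slot, lr=0.01, momentum=0.9, weight_decay=0.1):
        grad = grad + param * weight_decay   # L2 weight decay folded into the gradient
        slot = grad + slot * momentum        # momentum buffer, as in CheckValue.slots
        return param - lr * slot, slot

    p, g, buf = np.float32(0.5), np.float32(0.1), np.float32(0.0)
    p, buf = ref_sgd_step(p, g, buf)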
|
|
@@ -177,6 +186,8 @@ def test_adam(monkeypatch, case, update_lr, inplace_mode):
         def __call__(self, ori_params, new_params, step):
             for param in new_params:
                 grad = param.grad.numpy()
+                if hasattr(self, "weight_decay") and self.weight_decay != 0.0:
+                    grad = grad + ori_params[param] * self.weight_decay
                 m = self.m_slots[param]
                 v = self.v_slots[param]
                 m *= self.betas[0]
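Note: this hunk is cut off after the moment buffers m and v are loaded; the reference presumably continues with the usual bias-corrected Adam update. A rough textbook sketch under that assumption (function name and defaults are illustrative, not the test's exact arithmetic):

    import numpy as np

    def ref_adam_step(param, grad, m, v, step, lr=0.01,
                      betas=(0.9, 0.999), eps=1e-8, weight_decay=0.1):
        grad = grad + param * weight_decay             # L2 weight decay folded into the gradient
        m = m * betas[0] + grad * (1 - betas[0])       # first-moment running average
        v = v * betas[1] + grad ** 2 * (1 - betas[1])  # second-moment running average
        m_hat = m / (1 - betas[0] ** step)             # bias correction
        v_hat = v / (1 - betas[1] ** step)
        return param - lr * m_hat / (np.sqrt(v_hat) + eps), m, v

    p, m, v = np.float32(0.5), np.float32(0.0), np.float32(0.0)
    p, m, v = ref_adam_step(p, np.float32(0.1), m, v, step=1)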
|
|
@@ -222,6 +233,8 @@ def test_adagrad(monkeypatch, case, update_lr, inplace_mode):
         def __call__(self, ori_params, new_params, step):
             for param in new_params:
                 grad = param.grad.numpy()
+                if hasattr(self, "weight_decay") and self.weight_decay != 0.0:
+                    grad = grad + ori_params[param] * self.weight_decay
                 self.s_slots[param] += grad ** 2
                 delta = grad / (self.s_slots[param] + self.eps) ** 0.5
                 delta *= -(self.lr / (1 + (step - 1) * self.lr_decay))
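Note: here the full reference Adagrad update is visible: squared gradients accumulate in s_slots and the effective step size decays as lr / (1 + (step - 1) * lr_decay). A standalone NumPy sketch of that update (defaults are illustrative):

    import numpy as np

    def ref_adagrad_step(param, grad, s, step, lr=0.01,
                         lr_decay=0.01, eps=1e-10, weight_decay=0.1):
        grad = grad + param * weight_decay   # L2 weight decay folded into the gradient
        s = s + grad ** 2                    # accumulated squared gradients (s_slots)
        delta = grad / (s + eps) ** 0.5
        delta *= -(lr / (1 + (step - 1) * lr_decay))
        return param + delta, s

    p, s = np.float32(0.5), np.float32(0.0)
    p, s = ref_adagrad_step(p, np.float32(0.1), s, step=1)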
|
|
@@ -257,6 +270,8 @@ def test_adadelta(monkeypatch, case, update_lr, inplace_mode):
         def __call__(self, ori_params, new_params, step):
             for param in new_params:
                 grad = param.grad.numpy()
+                if hasattr(self, "weight_decay") and self.weight_decay != 0.0:
+                    grad = grad + ori_params[param] * self.weight_decay
                 self.s_slots[param] = self.s_slots[param] * self.rho + grad ** 2 * (
                     1 - self.rho
                 )
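Note: this hunk stops after the running average of squared gradients; a textbook Adadelta step also keeps a second accumulator of squared updates and rescales the gradient by the ratio of the two. A rough sketch under that assumption (the accumulator a, function name, and defaults are illustrative):

    import numpy as np

    def ref_adadelta_step(param, grad, s, a, lr=1.0, rho=0.9,
                          eps=1e-6, weight_decay=0.1):
        grad = grad + param * weight_decay               # L2 weight decay folded into the gradient
        s = s * rho + grad ** 2 * (1 - rho)              # running avg of squared gradients (s_slots)
        delta = grad * ((a + eps) / (s + eps)) ** 0.5    # rescale by ratio of running averages
        a = a * rho + delta ** 2 * (1 - rho)             # running avg of squared updates
        return param - lr * delta, s, a

    p, s, a = np.float32(0.5), np.float32(0.0), np.float32(0.0)
    p, s, a = ref_adadelta_step(p, np.float32(0.1), s, a)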
|
|
|