Browse Source

test(mge/optimizer): update optimizer test to make sure grad not change

GitOrigin-RevId: e207672116
release-1.4
Megvii Engine Team 4 years ago
parent
commit
5dbf3612b2
1 changed files with 15 additions and 0 deletions
  1. +15
    -0
      imperative/python/test/integration/test_optimizer.py

+ 15
- 0
imperative/python/test/integration/test_optimizer.py View File

@@ -66,10 +66,17 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
gm.backward(loss)

ori_params = {}
ori_grads = {}
for param in net.parameters():
assert param._tuple_shape is ()
ori_params[param] = np.copy(param.numpy())
ori_grads[param] = np.copy(param.grad.numpy())
opt.step()
# check grad not change
for param in net.parameters():
assert np.equal(
ori_grads[param], param.grad.numpy()
), "step should not change param.grad"
step += 1
check_func(ori_params, net.parameters(), step)

@@ -135,6 +142,8 @@ def test_sgd(monkeypatch, case, update_lr, inplace_mode):
def __call__(self, ori_params, new_params, step):
for param in new_params:
grad = param.grad.numpy()
if hasattr(self, "weight_decay") and self.weight_decay != 0.0:
grad = grad + ori_params[param] * self.weight_decay
if hasattr(self, "momentum"):
self.slots[param] = grad + self.slots[param] * self.momentum
delta = -self.lr * self.slots[param]
@@ -177,6 +186,8 @@ def test_adam(monkeypatch, case, update_lr, inplace_mode):
def __call__(self, ori_params, new_params, step):
for param in new_params:
grad = param.grad.numpy()
if hasattr(self, "weight_decay") and self.weight_decay != 0.0:
grad = grad + ori_params[param] * self.weight_decay
m = self.m_slots[param]
v = self.v_slots[param]
m *= self.betas[0]
@@ -222,6 +233,8 @@ def test_adagrad(monkeypatch, case, update_lr, inplace_mode):
def __call__(self, ori_params, new_params, step):
for param in new_params:
grad = param.grad.numpy()
if hasattr(self, "weight_decay") and self.weight_decay != 0.0:
grad = grad + ori_params[param] * self.weight_decay
self.s_slots[param] += grad ** 2
delta = grad / (self.s_slots[param] + self.eps) ** 0.5
delta *= -(self.lr / (1 + (step - 1) * self.lr_decay))
@@ -257,6 +270,8 @@ def test_adadelta(monkeypatch, case, update_lr, inplace_mode):
def __call__(self, ori_params, new_params, step):
for param in new_params:
grad = param.grad.numpy()
if hasattr(self, "weight_decay") and self.weight_decay != 0.0:
grad = grad + ori_params[param] * self.weight_decay
self.s_slots[param] = self.s_slots[param] * self.rho + grad ** 2 * (
1 - self.rho
)


Loading…
Cancel
Save