
test_lamb.py

import numpy as np

import megengine as mge
import megengine.autodiff as ad
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
from megengine import tensor
from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops.builtin import LAMBUpdate
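

# For readability, the update rule that the reference implementation below
# computes, written out in plain ASCII math (derived from the code itself):
#   m_t = beta0 * m_{t-1} + (1 - beta0) * g_t
#   v_t = beta1 * v_{t-1} + (1 - beta1) * g_t ** 2
#   m_hat = m_t / (1 - beta0 ** t),  v_hat = v_t / (1 - beta1 ** t)   (if bias_correction)
#   delta = m_hat / (sqrt(v_hat) + eps) + weight_decay * w
#   w_new = w - lr * trust_ratio * delta
# where trust_ratio = ||w|| / ||delta|| when weight decay is active (or
# always_adapt is set) and both norms are positive, and 1 otherwise.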
def lamb_update(
    param_group, step, exp_avg, exp_avg_sq, param, grad, bias_correction, always_adapt
):
    lr = param_group["lr"]
    weight_decay = param_group["weight_decay"]
    eps = param_group["eps"]
    beta0, beta1 = param_group["betas"]

    # since `convert_inputs` is disabled for param updates,
    # scalars should be explicitly converted to tensors
    _lr, _neg_lr = map(tensor, (lr, -lr))
    _weight_decay = tensor(weight_decay)
    _eps = tensor(eps)
    _beta0, _beta1 = map(tensor, (beta0, beta1))

    c1, c05, c0 = map(tensor, (1.0, 0.5, 0.0))
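
    # L2 norm built from elementwise square, Python's sum over the flattened
    # tensor, and a tensor-valued sqrt, so the result stays a Tensor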
    def norm(vec):
        return sum(vec * vec) ** c05

    p_norm = norm(param.flatten())

    # step = step + c1
    step += c1

    # exp_avg = _beta0 * exp_avg + grad * (c1 - _beta0)
    exp_avg *= _beta0
    exp_avg += grad * (c1 - _beta0)

    # exp_avg_sq = _beta1 * exp_avg_sq + (c1 - _beta1) * (grad * grad)
    exp_avg_sq *= _beta1
    exp_avg_sq += (c1 - _beta1) * (grad * grad)

    bias_correction1 = c1 - _beta0 ** step if bias_correction else c1
    bias_correction2 = c1 - _beta1 ** step if bias_correction else c1
    delta = (exp_avg / bias_correction1) / (
        (exp_avg_sq / bias_correction2) ** c05 + _eps
    )
    if weight_decay != 0.0:
        delta += param * _weight_decay
    d_norm = norm(delta.flatten())
    trust_ratio = (
        p_norm / d_norm
        if (always_adapt or weight_decay > 0) and p_norm > c0 and d_norm > c0
        else c1
    )
    new_param = param - _lr * trust_ratio * delta
    return exp_avg, exp_avg_sq, new_param
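

# Note: judging from `param_group` and the `lamb_update` call below, the
# positional arguments of LAMBUpdate appear to map to
#   (beta_0, beta_1, step, lr, weight_decay, eps, bias_correction, always_adapt).
# This mapping is inferred from the test itself, not from a documented API.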
def test_lamb():
    op = LAMBUpdate(0.9, 0.999, 1, 1e-3, 0.4, 1e-8, True, False)
    m_t_1 = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32)
    v_t_1 = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32)
    params = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32)
    # grad is float16 while the optimizer state is float32, which is
    # presumably why the comparisons below use a loose atol of 1e-2
    grad = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float16)
    (new_m_t, new_v_t, new_param) = apply(op, m_t_1, v_t_1, params, grad)
    param_group = {
        "betas": (0.9, 0.999),
        "step": 1,
        "lr": 1e-3,
        "weight_decay": 0.4,
        "eps": 1e-8,
    }
    gt_m_t, gt_v_t, gt_new_param = lamb_update(
        param_group, 1, m_t_1, v_t_1, params, grad, True, False
    )
    np.testing.assert_allclose(new_m_t.numpy(), gt_m_t.numpy(), atol=1e-2)
    np.testing.assert_allclose(new_v_t.numpy(), gt_v_t.numpy(), atol=1e-2)
    np.testing.assert_allclose(new_param.numpy(), gt_new_param.numpy(), atol=1e-2)
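

# Convenience entry point for quick manual runs without pytest; this block is
# an illustrative addition, not part of the original test file.
if __name__ == "__main__":
    test_lamb()
    print("test_lamb passed")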