- import numpy as np
-
- import megengine as mge
- import megengine.autodiff as ad
- import megengine.functional as F
- import megengine.module as M
- import megengine.optimizer as optim
- from megengine import tensor
- from megengine.core._imperative_rt.core2 import apply
- from megengine.core.ops.builtin import LAMBUpdate
-
-
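- # Pure-Python reference implementation of one LAMB optimizer step; test_lamb
- # below checks the fused LAMBUpdate op against it.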
- def lamb_update(
-     param_group, step, exp_avg, exp_avg_sq, param, grad, bias_correction, always_adapt
- ):
-     lr = param_group["lr"]
-     weight_decay = param_group["weight_decay"]
-     eps = param_group["eps"]
-     beta0, beta1 = param_group["betas"]
-
-     # since `convert_inputs` is disabled for param updates,
-     # scalars must be explicitly converted to tensors
-
-     _lr, _neg_lr = map(tensor, (lr, -lr))
-     _weight_decay = tensor(weight_decay)
-     _eps = tensor(eps)
-     _beta0, _beta1 = map(tensor, (beta0, beta1))
-
-     c1, c05, c0 = map(tensor, (1.0, 0.5, 0.0))
-
-     def norm(vec):
-         # L2 norm of a flattened tensor
-         return F.sum(vec * vec) ** c05
-
-     p_norm = norm(param.flatten())
-
-     # step = step + c1
-     step += c1
-
-     # exp_avg = _beta0 * exp_avg + grad * (c1 - _beta0)
-     exp_avg *= _beta0
-     exp_avg += grad * (c1 - _beta0)
-
-     # exp_avg_sq = _beta1 * exp_avg_sq + (c1 - _beta1) * (grad * grad)
-     exp_avg_sq *= _beta1
-     exp_avg_sq += (c1 - _beta1) * (grad * grad)
-
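-     # Adam-style bias-corrected moments and the (optionally weight-decayed) update direction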
-     bias_correction1 = c1 - _beta0 ** step if bias_correction else c1
-     bias_correction2 = c1 - _beta1 ** step if bias_correction else c1
-     delta = (exp_avg / bias_correction1) / (
-         (exp_avg_sq / bias_correction2) ** c05 + _eps
-     )
-     if weight_decay != 0.0:
-         delta += param * _weight_decay
-
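-     # LAMB layer-wise trust ratio ||param|| / ||update||; fall back to 1 when
-     # adaptation is not requested or when either norm is zero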
-     d_norm = norm(delta.flatten())
-     trust_ratio = (
-         p_norm / d_norm
-         if (always_adapt or weight_decay > 0) and p_norm > c0 and d_norm > c0
-         else c1
-     )
-     new_param = param - _lr * trust_ratio * delta
-     return exp_avg, exp_avg_sq, new_param
-
-
- def test_lamb():
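-     # constructor arguments mirror param_group / lamb_update below:
-     # betas=(0.9, 0.999), step=1, lr=1e-3, weight_decay=0.4, eps=1e-8,
-     # bias_correction=True, always_adapt=False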
-     op = LAMBUpdate(0.9, 0.999, 1, 1e-3, 0.4, 1e-8, True, False)
-     m_t_1 = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32)
-     v_t_1 = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32)
-     params = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float32)
-     # note: grad is float16 while the moment states and params are float32
-     grad = mge.tensor(np.random.uniform(size=(256, 256)), dtype=np.float16)
-     (new_m_t, new_v_t, new_param) = apply(op, m_t_1, v_t_1, params, grad)
-
-     param_group = {
-         "betas": (0.9, 0.999),
-         "step": 1,
-         "lr": 1e-3,
-         "weight_decay": 0.4,
-         "eps": 1e-8,
-     }
-     gt_m_t, gt_v_t, gt_new_param = lamb_update(
-         param_group, 1, m_t_1, v_t_1, params, grad, True, False
-     )
-     np.testing.assert_allclose(new_m_t.numpy(), gt_m_t.numpy(), atol=1e-2)
-     np.testing.assert_allclose(new_v_t.numpy(), gt_v_t.numpy(), atol=1e-2)
-     np.testing.assert_allclose(new_param.numpy(), gt_new_param.numpy(), atol=1e-2)