diff --git a/imperative/python/megengine/optimizer/__init__.py b/imperative/python/megengine/optimizer/__init__.py
index 19c5f6e5..f121ff9a 100644
--- a/imperative/python/megengine/optimizer/__init__.py
+++ b/imperative/python/megengine/optimizer/__init__.py
@@ -9,6 +9,7 @@
 from .adadelta import Adadelta
 from .adagrad import Adagrad
 from .adam import Adam
+from .adamw import AdamW
 from .lr_scheduler import LRScheduler
 from .multi_step_lr import MultiStepLR
 from .optimizer import Optimizer
diff --git a/imperative/python/megengine/optimizer/adamw.py b/imperative/python/megengine/optimizer/adamw.py
new file mode 100644
index 00000000..aec655e0
--- /dev/null
+++ b/imperative/python/megengine/optimizer/adamw.py
@@ -0,0 +1,128 @@
+# -*- coding: utf-8 -*-
+# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+#
+# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import os
+from typing import Iterable, Tuple, Union
+
+from ..functional.inplace import _inplace_add_
+from ..tensor import Parameter, tensor
+from .optimizer import Optimizer
+
+
+class AdamW(Optimizer):
+    r"""
+    Implements the AdamW algorithm proposed in `"Decoupled Weight Decay Regularization" <https://arxiv.org/abs/1711.05101>`_.
+
+    :param params: iterable of parameters to optimize or dicts defining
+        parameter groups.
+    :param lr: learning rate.
+    :param betas: coefficients used for computing running averages of the gradient
+        and its square. Default: (0.9, 0.999)
+    :param eps: term added to the denominator to improve numerical stability.
+        Default: 1e-8
+    :param weight_decay: weight decay (L2 penalty).
+        Default: 1e-2
+    """
+
+    def __init__(
+        self,
+        params: Union[Iterable[Parameter], dict],
+        lr: float,
+        betas: Tuple[float, float] = (0.9, 0.999),
+        eps: float = 1e-8,
+        weight_decay: float = 1e-2,
+    ):
+        if lr < 0.0:
+            raise ValueError("Invalid learning rate: {}".format(lr))
+        if weight_decay < 0.0:
+            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
+        if not 0.0 <= betas[0] < 1.0:
+            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
+        if not 0.0 <= betas[1] < 1.0:
+            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
+
+        defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps)
+        super().__init__(params, defaults)
+
+    def _create_state(self, param_group):
+        for param in param_group["params"]:
+            self._add_state(param, "exp_avg")
+            self._add_state(param, "exp_avg_sq")
+            self._add_state(param, "step", initializer=0.0)
+
+    def _updates(self, param_group):
+        lr = param_group["lr"]
+        weight_decay = param_group["weight_decay"]
+        eps = param_group["eps"]
+        beta0, beta1 = param_group["betas"]
+
+        def make_scalar(val):
+            return tensor(val)
+
+        # since `convert_inputs` is disabled for param updates,
+        # scalars must be explicitly converted to tensors
+
+        _lr, _neg_lr = map(make_scalar, (lr, -lr))
+        _weight_decay = make_scalar(weight_decay)
+        _eps = make_scalar(eps)
+        _beta0, _beta1 = map(make_scalar, (beta0, beta1))
+
+        c1, c05 = map(make_scalar, (1.0, 0.5))
+
+        inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0"))
+        if inplace_mode:
+            # reduce device sync
+            c1_sub_beta0, c1_sub_beta1 = map(make_scalar, (1 - beta0, 1 - beta1))
+
+        for param in param_group["params"]:
+
+            if param.grad is None:
+                continue
+
+            grad = param.grad
+
+            states = self._state[param]
+
+            step, exp_avg, exp_avg_sq = (
+                states["step"],
+                states["exp_avg"],
+                states["exp_avg_sq"],
+            )
+
+            if inplace_mode:
+                _inplace_add_(step, c1, alpha=c1, beta=c1)
+                _inplace_add_(exp_avg, grad, alpha=_beta0, beta=c1_sub_beta0)
+                _inplace_add_(
+                    exp_avg_sq, grad * grad, alpha=_beta1, beta=c1_sub_beta1,
+                )
+
+                delta = (exp_avg / (c1 - _beta0 ** step)) / (
+                    (exp_avg_sq / (c1 - _beta1 ** step)) ** c05 + _eps
+                )
+                if weight_decay != 0.0:
+                    delta += param * _weight_decay
+                _inplace_add_(param, delta, alpha=c1, beta=_neg_lr)
+                continue
+
+            # step = step + c1
+            step += c1
+
+            # exp_avg = _beta0 * exp_avg + grad * (c1 - _beta0)
+            exp_avg *= _beta0
+            exp_avg += grad * (c1 - _beta0)
+
+            # exp_avg_sq = _beta1 * exp_avg_sq + (c1 - _beta1) * (grad * grad)
+            exp_avg_sq *= _beta1
+            exp_avg_sq += (c1 - _beta1) * (grad * grad)
+
+            delta = (exp_avg / (c1 - _beta0 ** step)) / (
+                (exp_avg_sq / (c1 - _beta1 ** step)) ** c05 + _eps
+            )
+            if weight_decay != 0.0:
+                delta += param * _weight_decay
+
+            param -= _lr * delta
diff --git a/imperative/python/test/integration/test_optimizer.py b/imperative/python/test/integration/test_optimizer.py
index fd51d567..da71816b 100644
--- a/imperative/python/test/integration/test_optimizer.py
+++ b/imperative/python/test/integration/test_optimizer.py
@@ -6,7 +6,10 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import os
+
 import numpy as np
+import pytest
 
 import megengine.autodiff as ad
 import megengine.functional as F
@@ -110,7 +113,17 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
     }
 
 
-def test_sgd():
+@pytest.mark.parametrize(
+    "case",
+    [
+        {"momentum": 0.9, "lr": 0.01},  # SGD with momentum
+        {"lr": 0.01},  # simple SGD
+        {"weight_decay": 0.1, "lr": 0.01},  # with weight_decay
+    ],
+)
+@pytest.mark.parametrize("update_lr", [False, True])
+@pytest.mark.parametrize("inplace_mode", [False, True])
+def test_sgd(monkeypatch, case, update_lr, inplace_mode):
     class CheckValue:
         def __init__(self, net, **kwarg):
             self.slots = {}
@@ -131,17 +144,26 @@ def test_sgd():
                     param.numpy(), ori_params[param] + delta, decimal=6
                 )
 
-    cases = [
-        {"momentum": 0.9, "lr": 0.01},  # SGD with momentum
-        {"lr": 0.01},  # simple SGD
-        {"weight_decay": 0.1, "lr": 0.01},  # with weight_decay
-    ]
-    for case in cases:
-        _test_optimizer("SGD", case, CheckValue)
-        _test_optimizer("SGD", case, CheckValue, update_lr=True)
+    with monkeypatch.context() as mk:
+        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
+        _test_optimizer("SGD", case, CheckValue, update_lr=update_lr)
 
 
-def test_adam():
+@pytest.mark.parametrize(
+    "case",
+    [
+        {"betas": (0.8, 0.9), "eps": 1e-04, "lr": 0.01},
+        {
+            "betas": (0.8, 0.9),
+            "eps": 1e-04,
+            "lr": 0.01,
+            "weight_decay": 0.1,
+        },  # with weight_decay
+    ],
+)
+@pytest.mark.parametrize("update_lr", [False, True])
+@pytest.mark.parametrize("inplace_mode", [False, True])
+def test_adam(monkeypatch, case, update_lr, inplace_mode):
     class CheckValue:
         def __init__(self, net, **kwarg):
             self.m_slots = {}
@@ -168,21 +190,27 @@ def test_adam():
                     param.numpy(), ori_params[param] - self.lr * delta, decimal=6
                 )
 
-    cases = [
-        {"betas": (0.8, 0.9), "eps": 1e-04, "lr": 0.01},
+    with monkeypatch.context() as mk:
+        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
+        _test_optimizer("Adam", case, CheckValue, update_lr=update_lr)
+
+
+@pytest.mark.parametrize(
+    "case",
+    [
+        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.01},
+        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.0},  # without lr_decay
         {
-            "betas": (0.8, 0.9),
-            "eps": 1e-04,
             "lr": 0.01,
+            "eps": 1e-06,
+            "lr_decay": 0.01,
             "weight_decay": 0.1,
         },  # with weight_decay
-    ]
-    for case in cases:
-        _test_optimizer("Adam", case, CheckValue)
-        _test_optimizer("Adam", case, CheckValue, update_lr=True)
-
-
-def test_adagrad():
+    ],
+)
+@pytest.mark.parametrize("update_lr", [False, True])
+@pytest.mark.parametrize("inplace_mode", [False, True])
+def test_adagrad(monkeypatch, case, update_lr, inplace_mode):
     class CheckValue:
         def __init__(self, net, **kwarg):
             self.s_slots = {}
@@ -201,22 +229,21 @@ def test_adagrad():
                     param.numpy(), ori_params[param] + delta, decimal=6
                 )
 
-    cases = [
-        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.01},
-        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.0},  # without lr_decay
-        {
-            "lr": 0.01,
-            "eps": 1e-06,
-            "lr_decay": 0.01,
-            "weight_decay": 0.1,
-        },  # with weight_decay
-    ]
-    for case in cases:
-        _test_optimizer("Adagrad", case, CheckValue)
-        _test_optimizer("Adagrad", case, CheckValue, update_lr=True)
+    with monkeypatch.context() as mk:
+        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
+        _test_optimizer("Adagrad", case, CheckValue, update_lr=update_lr)
 
 
-def test_adadelta():
+@pytest.mark.parametrize(
+    "case",
+    [
+        {"lr": 1.0, "eps": 1e-06, "rho": 0.9},
+        {"lr": 1.0, "eps": 1e-06, "rho": 0.9, "weight_decay": 0.9},  # with weight_decay
+    ],
+)
+@pytest.mark.parametrize("update_lr", [False, True])
+@pytest.mark.parametrize("inplace_mode", [False, True])
+def test_adadelta(monkeypatch, case, update_lr, inplace_mode):
     class CheckValue:
         def __init__(self, net, **kwarg):
             self.s_slots = {}
@@ -246,10 +273,55 @@ def test_adadelta():
                     param.numpy(), ori_params[param] + delta, decimal=6
                 )
 
-    cases = [
-        {"lr": 1.0, "eps": 1e-06, "rho": 0.9},
-        {"lr": 1.0, "eps": 1e-06, "rho": 0.9, "weight_decay": 0.9},  # with weight_decay
-    ]
-    for case in cases:
-        _test_optimizer("Adadelta", case, CheckValue)
-        _test_optimizer("Adadelta", case, CheckValue, update_lr=True)
+    with monkeypatch.context() as mk:
+        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
+        _test_optimizer("Adadelta", case, CheckValue, update_lr=update_lr)
+
+
+@pytest.mark.parametrize(
+    "case",
+    [
+        {"betas": (0.8, 0.9), "eps": 1e-08, "lr": 0.01},
+        {
+            "betas": (0.8, 0.9),
+            "eps": 1e-08,
+            "lr": 0.01,
+            "weight_decay": 0.1,
+        },  # with weight_decay
+    ],
+)
+@pytest.mark.parametrize("update_lr", [False, True])
+@pytest.mark.parametrize("inplace_mode", [False, True])
+def test_adamw(monkeypatch, case, update_lr, inplace_mode):
+    class CheckValue:
+        def __init__(self, net, **kwarg):
+            self.m_slots = {}
+            self.v_slots = {}
+            for param in net.parameters():
+                self.m_slots[param] = np.zeros(param.shape).astype(np.float32)
+                self.v_slots[param] = np.zeros(param.shape).astype(np.float32)
+            self.weight_decay = 0.01
+            for k, v in kwarg.items():
+                setattr(self, k, v)
+
+        def __call__(self, ori_params, new_params, step):
+            step = np.array(step).astype(np.float32)
+            for param in new_params:
+                grad = param.grad.numpy()
+                m = self.m_slots[param]
+                v = self.v_slots[param]
+                m *= self.betas[0]
+                m += (1 - self.betas[0]) * grad
+                v *= self.betas[1]
+                v += (1 - self.betas[1]) * grad * grad
+                delta = (m / (1 - self.betas[0] ** step)) / (
+                    np.sqrt(v / (1 - self.betas[1] ** step)) + self.eps
+                )
+                delta += ori_params[param] * self.weight_decay
+                np.testing.assert_almost_equal(
+                    param.numpy(), ori_params[param] - self.lr * delta, decimal=6
+                )
+
+    with monkeypatch.context() as mk:
+        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
+        _test_optimizer("AdamW", case, CheckValue, update_lr=update_lr)
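
For reviewers, a minimal usage sketch of the new optimizer follows (it is not part of the patch). The toy model, random data, and loss function below are illustrative assumptions; only the AdamW import, its constructor arguments, and the MEGENGINE_INPLACE_UPDATE switch come from the diff above.

# Minimal AdamW usage sketch (toy Linear model and random data are hypothetical).
# Setting MEGENGINE_INPLACE_UPDATE=1 in the environment before running would take
# the fused in-place update path that the parametrized tests above exercise.
import numpy as np

import megengine as mge
import megengine.autodiff as ad
import megengine.functional as F
import megengine.module as M
from megengine.optimizer import AdamW

net = M.Linear(4, 1)  # hypothetical toy model
gm = ad.GradManager().attach(net.parameters())

# Weight decay is decoupled: per the update rule in adamw.py, the decay term is
# added to the update, not to the gradient:
#   delta = m_hat / (sqrt(v_hat) + eps) + weight_decay * param
#   param -= lr * delta
opt = AdamW(net.parameters(), lr=0.01, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2)

x = mge.tensor(np.random.randn(8, 4).astype("float32"))
y = mge.tensor(np.random.randn(8, 1).astype("float32"))

for _ in range(3):
    with gm:
        loss = F.loss.square_loss(net(x), y)
        gm.backward(loss)
    opt.step()
    opt.clear_grad()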