
feat(optimizer): add AdamW

GitOrigin-RevId: e608b5d5b9
tags/v1.4.0-rc1
Megvii Engine Team 4 years ago
parent commit 7e22e9f06e
3 changed files with 243 additions and 42 deletions

  1. imperative/python/megengine/optimizer/__init__.py (+1, -0)
  2. imperative/python/megengine/optimizer/adamw.py (+128, -0)
  3. imperative/python/test/integration/test_optimizer.py (+114, -42)

imperative/python/megengine/optimizer/__init__.py (+1, -0)

@@ -9,6 +9,7 @@
 from .adadelta import Adadelta
 from .adagrad import Adagrad
 from .adam import Adam
+from .adamw import AdamW
 from .lr_scheduler import LRScheduler
 from .multi_step_lr import MultiStepLR
 from .optimizer import Optimizer


imperative/python/megengine/optimizer/adamw.py (+128, -0)

@@ -0,0 +1,128 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
from typing import Iterable, Tuple, Union

from ..functional.inplace import _inplace_add_
from ..tensor import Parameter, tensor
from .optimizer import Optimizer


class AdamW(Optimizer):
    r"""
    Implements the AdamW algorithm proposed in `"Decoupled Weight Decay Regularization" <https://arxiv.org/abs/1711.05101>`_.

    :param params: iterable of parameters to optimize or dicts defining
        parameter groups.
    :param lr: learning rate.
    :param betas: coefficients used for computing running averages of gradient
        and its square. Default: (0.9, 0.999)
    :param eps: term added to the denominator to improve numerical stability.
        Default: 1e-8
    :param weight_decay: weight decay (L2 penalty). Default: 1e-2
    """

    def __init__(
        self,
        params: Union[Iterable[Parameter], dict],
        lr: float,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        weight_decay: float = 1e-2,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))

        defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps)
        super().__init__(params, defaults)

    def _create_state(self, param_group):
        for param in param_group["params"]:
            self._add_state(param, "exp_avg")
            self._add_state(param, "exp_avg_sq")
            self._add_state(param, "step", initializer=0.0)

    def _updates(self, param_group):
        lr = param_group["lr"]
        weight_decay = param_group["weight_decay"]
        eps = param_group["eps"]
        beta0, beta1 = param_group["betas"]

        def make_scalar(val):
            return tensor(val)

        # since `convert_inputs` is disabled for param updates,
        # scalars should be explicitly converted to tensors

        _lr, _neg_lr = map(make_scalar, (lr, -lr))
        _weight_decay = make_scalar(weight_decay)
        _eps = make_scalar(eps)
        _beta0, _beta1 = map(make_scalar, (beta0, beta1))

        c1, c05 = map(make_scalar, (1.0, 0.5))

        inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0"))
        if inplace_mode:
            # reduce device sync
            c1_sub_beta0, c1_sub_beta1 = map(make_scalar, (1 - beta0, 1 - beta1))

        for param in param_group["params"]:

            if param.grad is None:
                continue

            grad = param.grad

            states = self._state[param]

            step, exp_avg, exp_avg_sq = (
                states["step"],
                states["exp_avg"],
                states["exp_avg_sq"],
            )

            if inplace_mode:
                _inplace_add_(step, c1, alpha=c1, beta=c1)
                _inplace_add_(exp_avg, grad, alpha=_beta0, beta=c1_sub_beta0)
                _inplace_add_(
                    exp_avg_sq, grad * grad, alpha=_beta1, beta=c1_sub_beta1,
                )

                delta = (exp_avg / (c1 - _beta0 ** step)) / (
                    (exp_avg_sq / (c1 - _beta1 ** step)) ** c05 + _eps
                )
                if weight_decay != 0.0:
                    delta += param * _weight_decay
                _inplace_add_(param, delta, alpha=c1, beta=_neg_lr)
                continue

            # step = step + c1
            step += c1

            # exp_avg = _beta0 * exp_avg + grad * (c1 - _beta0)
            exp_avg *= _beta0
            exp_avg += grad * (c1 - _beta0)

            # exp_avg_sq = _beta1 * exp_avg_sq + (c1 - _beta1) * (grad * grad)
            exp_avg_sq *= _beta1
            exp_avg_sq += (c1 - _beta1) * (grad * grad)

            delta = (exp_avg / (c1 - _beta0 ** step)) / (
                (exp_avg_sq / (c1 - _beta1 ** step)) ** c05 + _eps
            )
            if weight_decay != 0.0:
                delta += param * _weight_decay

            param -= _lr * delta
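
For reference, a minimal usage sketch of the new optimizer (not part of this commit). It builds a toy module, attaches a GradManager, and runs one update step; the module, data shapes, and loss below are illustrative placeholders, not code from this repository.

import numpy as np

import megengine as mge
import megengine.autodiff as ad
import megengine.module as M
import megengine.optimizer as optim

net = M.Linear(4, 2)  # toy module, for illustration only
gm = ad.GradManager().attach(net.parameters())
opt = optim.AdamW(net.parameters(), lr=1e-3, weight_decay=1e-2)

data = mge.tensor(np.random.randn(8, 4).astype("float32"))
label = mge.tensor(np.random.randn(8, 2).astype("float32"))

with gm:
    pred = net(data)
    loss = ((pred - label) ** 2).mean()  # plain squared error keeps the example self-contained
    gm.backward(loss)

opt.step()        # apply the AdamW update (decoupled weight decay)
opt.clear_grad()  # reset accumulated gradients for the next iteration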

imperative/python/test/integration/test_optimizer.py (+114, -42)

@@ -6,7 +6,10 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+import os
+
 import numpy as np
+import pytest
 
 import megengine.autodiff as ad
 import megengine.functional as F
@@ -110,7 +113,17 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
         }
 
 
-def test_sgd():
+@pytest.mark.parametrize(
+    "case",
+    [
+        {"momentum": 0.9, "lr": 0.01},  # SGD with momentum
+        {"lr": 0.01},  # simple SGD
+        {"weight_decay": 0.1, "lr": 0.01},  # with weight_decay
+    ],
+)
+@pytest.mark.parametrize("update_lr", [False, True])
+@pytest.mark.parametrize("inplace_mode", [False, True])
+def test_sgd(monkeypatch, case, update_lr, inplace_mode):
     class CheckValue:
         def __init__(self, net, **kwarg):
             self.slots = {}
@@ -131,17 +144,26 @@ def test_sgd():
                     param.numpy(), ori_params[param] + delta, decimal=6
                 )
 
-    cases = [
-        {"momentum": 0.9, "lr": 0.01},  # SGD with momentum
-        {"lr": 0.01},  # simple SGD
-        {"weight_decay": 0.1, "lr": 0.01},  # with weight_decay
-    ]
-    for case in cases:
-        _test_optimizer("SGD", case, CheckValue)
-        _test_optimizer("SGD", case, CheckValue, update_lr=True)
+    with monkeypatch.context() as mk:
+        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
+        _test_optimizer("SGD", case, CheckValue, update_lr=update_lr)
 
 
-def test_adam():
+@pytest.mark.parametrize(
+    "case",
+    [
+        {"betas": (0.8, 0.9), "eps": 1e-04, "lr": 0.01},
+        {
+            "betas": (0.8, 0.9),
+            "eps": 1e-04,
+            "lr": 0.01,
+            "weight_decay": 0.1,
+        },  # with weight_decay
+    ],
+)
+@pytest.mark.parametrize("update_lr", [False, True])
+@pytest.mark.parametrize("inplace_mode", [False, True])
+def test_adam(monkeypatch, case, update_lr, inplace_mode):
     class CheckValue:
         def __init__(self, net, **kwarg):
             self.m_slots = {}
@@ -168,21 +190,27 @@ def test_adam():
                     param.numpy(), ori_params[param] - self.lr * delta, decimal=6
                 )
 
-    cases = [
-        {"betas": (0.8, 0.9), "eps": 1e-04, "lr": 0.01},
-        {
-            "betas": (0.8, 0.9),
-            "eps": 1e-04,
-            "lr": 0.01,
-            "weight_decay": 0.1,
-        },  # with weight_decay
-    ]
-    for case in cases:
-        _test_optimizer("Adam", case, CheckValue)
-        _test_optimizer("Adam", case, CheckValue, update_lr=True)
+    with monkeypatch.context() as mk:
+        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
+        _test_optimizer("Adam", case, CheckValue, update_lr=update_lr)
 
 
-def test_adagrad():
+@pytest.mark.parametrize(
+    "case",
+    [
+        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.01},
+        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.0},  # without lr_decay
+        {
+            "lr": 0.01,
+            "eps": 1e-06,
+            "lr_decay": 0.01,
+            "weight_decay": 0.1,
+        },  # with weight_decay
+    ],
+)
+@pytest.mark.parametrize("update_lr", [False, True])
+@pytest.mark.parametrize("inplace_mode", [False, True])
+def test_adagrad(monkeypatch, case, update_lr, inplace_mode):
     class CheckValue:
         def __init__(self, net, **kwarg):
             self.s_slots = {}
@@ -201,22 +229,21 @@ def test_adagrad():
                     param.numpy(), ori_params[param] + delta, decimal=6
                 )
 
-    cases = [
-        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.01},
-        {"lr": 0.01, "eps": 1e-06, "lr_decay": 0.0},  # without lr_decay
-        {
-            "lr": 0.01,
-            "eps": 1e-06,
-            "lr_decay": 0.01,
-            "weight_decay": 0.1,
-        },  # with weight_decay
-    ]
-    for case in cases:
-        _test_optimizer("Adagrad", case, CheckValue)
-        _test_optimizer("Adagrad", case, CheckValue, update_lr=True)
+    with monkeypatch.context() as mk:
+        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
+        _test_optimizer("Adagrad", case, CheckValue, update_lr=update_lr)
 
 
-def test_adadelta():
+@pytest.mark.parametrize(
+    "case",
+    [
+        {"lr": 1.0, "eps": 1e-06, "rho": 0.9},
+        {"lr": 1.0, "eps": 1e-06, "rho": 0.9, "weight_decay": 0.9},  # with weight_decay
+    ],
+)
+@pytest.mark.parametrize("update_lr", [False, True])
+@pytest.mark.parametrize("inplace_mode", [False, True])
+def test_adadelta(monkeypatch, case, update_lr, inplace_mode):
     class CheckValue:
         def __init__(self, net, **kwarg):
             self.s_slots = {}
@@ -246,10 +273,55 @@ def test_adadelta():
                     param.numpy(), ori_params[param] + delta, decimal=6
                 )
 
-    cases = [
-        {"lr": 1.0, "eps": 1e-06, "rho": 0.9},
-        {"lr": 1.0, "eps": 1e-06, "rho": 0.9, "weight_decay": 0.9},  # with weight_decay
-    ]
-    for case in cases:
-        _test_optimizer("Adadelta", case, CheckValue)
-        _test_optimizer("Adadelta", case, CheckValue, update_lr=True)
+    with monkeypatch.context() as mk:
+        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
+        _test_optimizer("Adadelta", case, CheckValue, update_lr=update_lr)
+
+
+@pytest.mark.parametrize(
+    "case",
+    [
+        {"betas": (0.8, 0.9), "eps": 1e-08, "lr": 0.01},
+        {
+            "betas": (0.8, 0.9),
+            "eps": 1e-08,
+            "lr": 0.01,
+            "weight_decay": 0.1,
+        },  # with weight_decay
+    ],
+)
+@pytest.mark.parametrize("update_lr", [False, True])
+@pytest.mark.parametrize("inplace_mode", [False, True])
+def test_adamw(monkeypatch, case, update_lr, inplace_mode):
+    class CheckValue:
+        def __init__(self, net, **kwarg):
+            self.m_slots = {}
+            self.v_slots = {}
+            for param in net.parameters():
+                self.m_slots[param] = np.zeros(param.shape).astype(np.float32)
+                self.v_slots[param] = np.zeros(param.shape).astype(np.float32)
+            self.weight_decay = 0.01
+            for k, v in kwarg.items():
+                setattr(self, k, v)
+
+        def __call__(self, ori_params, new_params, step):
+            step = np.array(step).astype(np.float32)
+            for param in new_params:
+                grad = param.grad.numpy()
+                m = self.m_slots[param]
+                v = self.v_slots[param]
+                m *= self.betas[0]
+                m += (1 - self.betas[0]) * grad
+                v *= self.betas[1]
+                v += (1 - self.betas[1]) * grad * grad
+                delta = (m / (1 - self.betas[0] ** step)) / (
+                    np.sqrt(v / (1 - self.betas[1] ** step)) + self.eps
+                )
+                delta += ori_params[param] * self.weight_decay
+                np.testing.assert_almost_equal(
+                    param.numpy(), ori_params[param] - self.lr * delta, decimal=6
+                )
+
+    with monkeypatch.context() as mk:
+        mk.setenv("MEGENGINE_INPLACE_UPDATE", str(int(inplace_mode)))
+        _test_optimizer("AdamW", case, CheckValue, update_lr=update_lr)
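
The CheckValue reference above applies weight decay outside the adaptive rescaling, which is the point of AdamW versus Adam with an L2 penalty folded into the gradient. A standalone numpy sketch of that difference on a single scalar parameter (not part of the commit; all values and variable names below are made up for illustration):

import numpy as np

lr, wd, beta0, beta1, eps = 0.01, 0.1, 0.8, 0.9, 1e-8
w = 1.0      # parameter value
g = 0.5      # gradient of the loss w.r.t. w (without any decay term)
m = v = 0.0  # first/second moment state
step = 1.0

# Adam + L2: weight decay is added to the gradient before the moments,
# so it gets rescaled by the adaptive denominator.
g_l2 = g + wd * w
m_l2 = beta0 * m + (1 - beta0) * g_l2
v_l2 = beta1 * v + (1 - beta1) * g_l2 * g_l2
delta_l2 = (m_l2 / (1 - beta0 ** step)) / (np.sqrt(v_l2 / (1 - beta1 ** step)) + eps)
w_l2 = w - lr * delta_l2

# AdamW: moments see only the raw gradient; the decay term is added to delta
# directly, outside the adaptive rescaling, as in the new adamw.py.
m_w = beta0 * m + (1 - beta0) * g
v_w = beta1 * v + (1 - beta1) * g * g
delta_w = (m_w / (1 - beta0 ** step)) / (np.sqrt(v_w / (1 - beta1 ** step)) + eps)
w_adamw = w - lr * (delta_w + wd * w)

print(w_l2, w_adamw)  # the two schemes generally produce different updates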
