
feat(sgd): sgd supports nesterov momentum

Megvii Engine Team · 3 years ago · release-1.7
GitOrigin-RevId: 13eda179da
parent commit: 4e95c13617
2 changed files with 23 additions and 13 deletions:

  1. imperative/python/megengine/optimizer/sgd.py (+17, -11)
  2. imperative/python/test/integration/test_optimizer.py (+6, -2)
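
In update-rule terms (a reading of the diffs below, not text from the commit), with momentum factor \mu, learning rate \eta, gradient g_t after weight decay, and momentum buffer v: both modes update the buffer the same way and differ only in the parameter step, as the sketch shows.

    v_{t+1}      = \mu v_t + g_t
    \theta_{t+1} = \theta_t - \eta \, v_{t+1}                 % nesterov=False
    \theta_{t+1} = \theta_t - \eta \, (g_t + \mu v_{t+1})     % nesterov=True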

imperative/python/megengine/optimizer/sgd.py (+17, -11)

@@ -16,7 +16,7 @@ from .optimizer import Optimizer
 
 class SGD(Optimizer):
     r"""Implements stochastic gradient descent.
+
+    Nesterov momentum is based on the formula from
+    `"On the importance of initialization and momentum in deep learning" <http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf>`_ .
 
@@ -25,6 +25,7 @@ class SGD(Optimizer):
             parameter groups.
         lr: learning rate.
         momentum: momentum factor. Default: 0.0
+        nesterov: enables Nesterov momentum. Default: False
         weight_decay: weight decay (L2 penalty). Default: 0.0
     """
 
@@ -33,6 +34,7 @@ class SGD(Optimizer):
         params: Union[Iterable[Parameter], dict],
         lr: float,
         momentum: float = 0.0,
+        nesterov: bool = False,
         weight_decay: float = 0.0,
     ):
         assert lr >= 0.0, "Invalid learning rate: {}".format(lr)
@@ -40,9 +42,11 @@ class SGD(Optimizer):
         assert weight_decay >= 0.0, "Invalid weight_decay value: {}".format(
             weight_decay
         )
+        assert not nesterov or momentum > 0.0, "Nesterov momentum requires a momentum"
 
         defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
         super().__init__(params, defaults)
+        self.nesterov = nesterov
         self._disable_type_convert = True
 
     def _create_state(self, param_group):
@@ -76,20 +80,22 @@ class SGD(Optimizer):
                 grad = grad + param * _weight_decay
 
             if inplace_mode:
-                if momentum:
+                if momentum != 0.0:
                     v = self._state[param]["momentum_buffer"]
                     _inplace_add_(v, grad, alpha=_momentum, beta=c1)
-                    _inplace_add_(param, v, alpha=c1, beta=_neg_lr)
-                else:
-                    _inplace_add_(param, grad, alpha=c1, beta=_neg_lr)
+                    if self.nesterov:
+                        grad = grad + v * _momentum
+                    else:
+                        grad = v
+                _inplace_add_(param, grad, alpha=c1, beta=_neg_lr)
                 continue
 
-            if momentum:
+            if momentum != 0.0:
                 v = self._state[param]["momentum_buffer"]
                 # v = v * _momentum + grad
                 v *= _momentum
                 v += grad
 
-                param -= _lr * v
-            else:
-                param -= _lr * grad
+                if self.nesterov:
+                    grad = grad + v * _momentum
+                else:
+                    grad = v
+            param -= _lr * grad
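
For illustration, a minimal NumPy sketch of the step above (the helper name sgd_step and its signature are ours, not MegEngine API; the inplace_mode path and the _inplace_add_ helper are elided):

import numpy as np

def sgd_step(param, grad, v, lr, momentum=0.0, nesterov=False, weight_decay=0.0):
    # Mirrors the non-inplace branch of SGD._updates in the diff above.
    if weight_decay != 0.0:
        grad = grad + param * weight_decay  # L2 penalty folded into the gradient
    if momentum != 0.0:
        v = v * momentum + grad             # momentum_buffer update
        if nesterov:
            grad = grad + v * momentum      # Nesterov look-ahead correction
        else:
            grad = v                        # classical heavy-ball step
    return param - lr * grad, v

# One step from zero state: v becomes 1.0, so the step is -lr * (1 + 0.9 * 1) = -0.19.
p, v = sgd_step(np.zeros(3), np.ones(3), np.zeros(3), lr=0.1, momentum=0.9, nesterov=True)
print(p)  # [-0.19 -0.19 -0.19]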

imperative/python/test/integration/test_optimizer.py (+6, -2)

@@ -124,6 +124,7 @@ def _test_optimizer(opt_str, test_case, check_class, update_lr=False):
     "case",
     [
         {"momentum": 0.9, "lr": 0.01},  # SGD with momentum
+        {"momentum": 0.9, "lr": 0.01, "nesterov": True},  # with nesterov momentum
         {"lr": 0.01},  # simple SGD
         {"weight_decay": 0.1, "lr": 0.01},  # with weight_decay
     ],
@@ -144,9 +145,12 @@ def test_sgd(monkeypatch, case, update_lr, inplace_mode):
             grad = param.grad.numpy()
             if hasattr(self, "weight_decay") and self.weight_decay != 0.0:
                 grad = grad + ori_params[param] * self.weight_decay
-            if hasattr(self, "momentum"):
+            if hasattr(self, "momentum") and self.momentum != 0.0:
                 self.slots[param] = grad + self.slots[param] * self.momentum
-                delta = -self.lr * self.slots[param]
+                if hasattr(self, "nesterov") and self.nesterov:
+                    delta = -self.lr * (grad + self.slots[param] * self.momentum)
+                else:
+                    delta = -self.lr * self.slots[param]
             else:
                 delta = -self.lr * grad
             np.testing.assert_almost_equal(
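
A usage sketch matching the new test case (assumes a MegEngine Module net plus input data and target label, and the usual GradManager loop; none of this is part of the diff):

import megengine.functional as F
import megengine.optimizer as optim
from megengine.autodiff import GradManager

# nesterov=True requires momentum > 0.0, per the assert added in this commit.
opt = optim.SGD(net.parameters(), lr=0.01, momentum=0.9, nesterov=True)
gm = GradManager().attach(net.parameters())

with gm:
    loss = F.nn.cross_entropy(net(data), label)
    gm.backward(loss)
opt.step()
opt.clear_grad()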

