|
|
@@ -16,7 +16,7 @@ from .optimizer import Optimizer |
|
|
|
|
|
|
|
class SGD(Optimizer): |
|
|
|
r"""Implements stochastic gradient descent. |
|
|
|
|
|
|
|
|
|
|
|
Nesterov momentum is based on the formula from |
|
|
|
`"On the importance of initialization and momentum in deep learning" <http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf>`_ . |
|
|
|
|
|
|
@@ -25,6 +25,7 @@ class SGD(Optimizer): |
|
|
|
parameter groups. |
|
|
|
lr: learning rate. |
|
|
|
momentum: momentum factor. Default: 0.0 |
|
|
|
nesterov: enables Nesterov momentum. Default: False |
|
|
|
weight_decay: weight decay (L2 penalty). Default: 0.0 |
|
|
|
""" |
|
|
|
|
|
|
@@ -33,6 +34,7 @@ class SGD(Optimizer): |
|
|
|
params: Union[Iterable[Parameter], dict], |
|
|
|
lr: float, |
|
|
|
momentum: float = 0.0, |
|
|
|
nesterov: bool = False, |
|
|
|
weight_decay: float = 0.0, |
|
|
|
): |
|
|
|
assert lr >= 0.0, "Invalid learning rate: {}".format(lr) |
|
|
@@ -40,9 +42,11 @@ class SGD(Optimizer): |
|
|
|
assert weight_decay >= 0.0, "Invalid weight_decay value: {}".format( |
|
|
|
weight_decay |
|
|
|
) |
|
|
|
assert not nesterov or momentum > 0.0, "Nesterov momentum requires a momentum" |
|
|
|
|
|
|
|
defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay) |
|
|
|
super().__init__(params, defaults) |
|
|
|
self.nesterov = nesterov |
|
|
|
self._disable_type_convert = True |
|
|
|
|
|
|
|
def _create_state(self, param_group): |
|
|
@@ -76,20 +80,22 @@ class SGD(Optimizer): |
|
|
|
grad = grad + param * _weight_decay |
|
|
|
|
|
|
|
if inplace_mode: |
|
|
|
if momentum: |
|
|
|
if momentum != 0.0: |
|
|
|
v = self._state[param]["momentum_buffer"] |
|
|
|
_inplace_add_(v, grad, alpha=_momentum, beta=c1) |
|
|
|
_inplace_add_(param, v, alpha=c1, beta=_neg_lr) |
|
|
|
else: |
|
|
|
_inplace_add_(param, grad, alpha=c1, beta=_neg_lr) |
|
|
|
if self.nesterov: |
|
|
|
grad = grad + v * _momentum |
|
|
|
else: |
|
|
|
grad = v |
|
|
|
_inplace_add_(param, grad, alpha=c1, beta=_neg_lr) |
|
|
|
continue |
|
|
|
|
|
|
|
if momentum: |
|
|
|
if momentum != 0.0: |
|
|
|
v = self._state[param]["momentum_buffer"] |
|
|
|
# v = v * _momentum + grad |
|
|
|
v *= _momentum |
|
|
|
v += grad |
|
|
|
|
|
|
|
param -= _lr * v |
|
|
|
else: |
|
|
|
param -= _lr * grad |
|
|
|
if self.nesterov: |
|
|
|
grad = grad + v * _momentum |
|
|
|
else: |
|
|
|
grad = v |
|
|
|
param -= _lr * grad |