# -*- coding: utf-8 -*-
import os
from typing import Iterable, Union

from ..core import _config
from ..functional.inplace import _inplace_add_
from ..tensor import Parameter, tensor
from .optimizer import Optimizer


class SGD(Optimizer):
    r"""Implements stochastic gradient descent.

    Nesterov momentum is based on the formula from
    `"On the importance of initialization and momentum in deep learning"
    <http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf>`_ .

    Args:
        params: iterable of parameters to optimize or dicts defining
            parameter groups.
        lr: learning rate.
        momentum: momentum factor. Default: 0.0
        nesterov: enables Nesterov momentum. Default: False
        weight_decay: weight decay (L2 penalty). Default: 0.0
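
    Example (a minimal sketch; ``net`` stands for any module whose parameters
    already carry gradients from the user's backward pass):

    .. code-block:: python

        opt = SGD(net.parameters(), lr=0.01, momentum=0.9, nesterov=True)
        ...  # forward, loss, backward
        opt.step()        # apply one SGD update to all parameters
        opt.clear_grad()  # reset gradients before the next iteration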
  19. """

    def __init__(
        self,
        params: Union[Iterable[Parameter], dict],
        lr: float,
        momentum: float = 0.0,
        nesterov: bool = False,
        weight_decay: float = 0.0,
    ):
        assert lr >= 0.0, "Invalid learning rate: {}".format(lr)
        assert momentum >= 0.0, "Invalid momentum value: {}".format(momentum)
        assert weight_decay >= 0.0, "Invalid weight_decay value: {}".format(
            weight_decay
        )
        assert (
            not nesterov or momentum > 0.0
        ), "Nesterov momentum requires a positive momentum"

        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self.nesterov = nesterov
        # parameter updates skip automatic scalar-to-tensor conversion;
        # see the explicit conversions in ``_updates``
        self._disable_type_convert = True

    def _create_state(self, param_group):
        # a momentum buffer per parameter is only needed when momentum is enabled
        if param_group["momentum"] != 0.0:
            for param in param_group["params"]:
                self._add_state(param, "momentum_buffer")

    def _updates(self, param_group):
        lr = param_group["lr"]
        weight_decay = param_group["weight_decay"]
        momentum = param_group["momentum"]

        # since `convert_inputs` is disabled for parameter updates, scalars
        # must be explicitly converted to tensors
        _lr = tensor(lr, dtype="float32")
        _weight_decay = tensor(weight_decay, dtype="float32")
        _momentum = tensor(momentum, dtype="float32")

        inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0"))
        if inplace_mode:
            # constants for ``_inplace_add_(dest, delta, alpha, beta)``,
            # which computes ``dest = alpha * dest + beta * delta`` in place
            _neg_lr = tensor(-lr, dtype="float32")
            c1 = tensor(1.0)

        for param in param_group["params"]:
            if param.grad is None:
                continue

            grad = param.grad
            if weight_decay != 0.0:
                grad = grad + param * _weight_decay

            if inplace_mode:
                if momentum != 0.0:
                    v = self._state[param]["momentum_buffer"]
                    # v = momentum * v + grad, updated in place
                    _inplace_add_(v, grad, alpha=_momentum, beta=c1)
                    if self.nesterov:
                        grad = grad + v * _momentum
                    else:
                        grad = v
                # param = param - lr * grad, updated in place
                _inplace_add_(param, grad, alpha=c1, beta=_neg_lr)
                continue

            if momentum != 0.0:
                v = self._state[param]["momentum_buffer"]
                # v = momentum * v + grad
                v *= _momentum
                v += grad
                if self.nesterov:
                    grad = grad + v * _momentum
                else:
                    grad = v
            param -= _lr * grad
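

if __name__ == "__main__":
    # Usage sketch (illustration only, not part of the library): runs one
    # training step with this optimizer. The MEGENGINE_INPLACE_UPDATE switch
    # read in ``_updates`` above is controlled through the environment.
    # Assumes an installed ``megengine`` package; since this module uses
    # relative imports, the sketch is meant to be copied into a standalone
    # script. ``Linear`` and ``GradManager`` are public MegEngine APIs.
    import numpy as np

    import megengine as mge
    import megengine.functional as F
    from megengine.autodiff import GradManager
    from megengine.module import Linear
    from megengine.optimizer import SGD as _SGD

    os.environ.setdefault("MEGENGINE_INPLACE_UPDATE", "0")  # "1" enables the in-place path

    net = Linear(4, 2)
    opt = _SGD(net.parameters(), lr=0.1, momentum=0.9, nesterov=True)
    gm = GradManager().attach(net.parameters())

    x = mge.tensor(np.random.randn(8, 4).astype("float32"))
    y = mge.tensor(np.random.randn(8, 2).astype("float32"))

    with gm:
        loss = F.mean((net(x) - y) ** 2)  # simple squared-error loss
        gm.backward(loss)

    opt.step()        # applies the update implemented in ``_updates``
    opt.clear_grad()  # resets gradients for the next iteration
    print("loss:", loss.item())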