# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import os
from typing import Iterable, Union

from ..functional.inplace import _inplace_add_
from ..tensor import Parameter, tensor
from .optimizer import Optimizer


class SGD(Optimizer):
    r"""Implements stochastic gradient descent.

    Nesterov momentum is based on the formula from
    `"On the importance of initialization and momentum in deep learning"
    <http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf>`_ .

    Args:
        params: iterable of parameters to optimize or dicts defining
            parameter groups.
        lr: learning rate.
        momentum: momentum factor. Default: 0.0
        nesterov: enables Nesterov momentum. Default: False
        weight_decay: weight decay (L2 penalty). Default: 0.0
    """

    def __init__(
        self,
        params: Union[Iterable[Parameter], dict],
        lr: float,
        momentum: float = 0.0,
        nesterov: bool = False,
        weight_decay: float = 0.0,
    ):
        assert lr >= 0.0, "Invalid learning rate: {}".format(lr)
        assert momentum >= 0.0, "Invalid momentum value: {}".format(momentum)
        assert weight_decay >= 0.0, "Invalid weight_decay value: {}".format(
            weight_decay
        )
        assert (
            not nesterov or momentum > 0.0
        ), "Nesterov momentum requires a non-zero momentum"

        defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self.nesterov = nesterov
        # Parameter updates run with `convert_inputs` disabled, so scalar
        # hyperparameters are converted to tensors explicitly in `_updates`.
        self._disable_type_convert = True

    def _create_state(self, param_group):
        # The momentum buffer is only allocated when momentum is actually used.
        if param_group["momentum"] != 0.0:
            for param in param_group["params"]:
                self._add_state(param, "momentum_buffer")

    def _updates(self, param_group):
        lr = param_group["lr"]
        weight_decay = param_group["weight_decay"]
        momentum = param_group["momentum"]

        # Since `convert_inputs` is disabled for parameter updates,
        # scalars must be explicitly converted to tensors.
        _lr = tensor(lr, dtype="float32")
        _weight_decay = tensor(weight_decay, dtype="float32")
        _momentum = tensor(momentum, dtype="float32")
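
        # Setting MEGENGINE_INPLACE_UPDATE=1 (e.g. running as
        # ``MEGENGINE_INPLACE_UPDATE=1 python train.py``) selects the
        # ``_inplace_add_`` path below, which mutates ``v`` and ``param``
        # directly rather than allocating new result tensors.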
        inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0"))
        if inplace_mode:
            _neg_lr = tensor(-lr, dtype="float32")
            c1 = tensor(1.0)

        for param in param_group["params"]:
            if param.grad is None:
                continue

            grad = param.grad
            if weight_decay != 0.0:
                grad = grad + param * _weight_decay

            if inplace_mode:
                if momentum != 0.0:
                    v = self._state[param]["momentum_buffer"]
                    # v = momentum * v + grad, computed in place
                    _inplace_add_(v, grad, alpha=_momentum, beta=c1)
                    if self.nesterov:
                        grad = grad + v * _momentum
                    else:
                        grad = v
                # param = param - lr * grad, computed in place
                _inplace_add_(param, grad, alpha=c1, beta=_neg_lr)
                continue

            if momentum != 0.0:
                v = self._state[param]["momentum_buffer"]
                v *= _momentum
                v += grad
                if self.nesterov:
                    grad = grad + v * _momentum
                else:
                    grad = v
            param -= _lr * grad
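
# A minimal usage sketch (``net``, ``data`` and ``label`` are assumed
# placeholders for a module, an input batch and its targets):
#
#   import megengine.functional as F
#   from megengine.autodiff import GradManager
#   from megengine.optimizer import SGD
#
#   opt = SGD(net.parameters(), lr=0.01, momentum=0.9, nesterov=True,
#             weight_decay=1e-4)
#   gm = GradManager().attach(net.parameters())
#   with gm:
#       loss = F.nn.cross_entropy(net(data), label)
#       gm.backward(loss)
#   opt.step().clear_grad()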