You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

sgd.py 3.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import os
  10. from typing import Iterable, Union
  11. from ..functional.inplace import _inplace_add_
  12. from ..tensor import Parameter, tensor
  13. from .optimizer import Optimizer
  14. class SGD(Optimizer):
  15. r"""
  16. Implements stochastic gradient descent.
  17. Nesterov momentum is based on the formula from
  18. `"On the importance of initialization and momentum in deep learning" <http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf>`_ .
  19. :param params: iterable of parameters to optimize or dicts defining
  20. parameter groups.
  21. :param lr: learning rate.
  22. :param momentum: momentum factor. Default: 0.0
  23. :param weight_decay: weight decay (L2 penalty). Default: 0.0
  24. """
  25. def __init__(
  26. self,
  27. params: Union[Iterable[Parameter], dict],
  28. lr: float,
  29. momentum: float = 0.0,
  30. weight_decay: float = 0.0,
  31. ):
  32. assert lr >= 0.0, "Invalid learning rate: {}".format(lr)
  33. assert momentum >= 0.0, "Invalid momentum value: {}".format(momentum)
  34. assert weight_decay >= 0.0, "Invalid weight_decay value: {}".format(
  35. weight_decay
  36. )
  37. defaults = dict(lr=lr, momentum=momentum, weight_decay=weight_decay)
  38. super().__init__(params, defaults)
  39. def _create_state(self, param_group):
  40. if param_group["momentum"] != 0.0:
  41. for param in param_group["params"]:
  42. self._add_state(param, "momentum_buffer")
  43. def _updates(self, param_group):
  44. lr = param_group["lr"]
  45. weight_decay = param_group["weight_decay"]
  46. momentum = param_group["momentum"]
  47. # since `conver_inputs` is disabled for param updates,
  48. # scalar should be explicitly tansforred to tensor
  49. _lr = tensor(lr)
  50. _weight_decay = tensor(weight_decay)
  51. _momentum = tensor(momentum)
  52. inplace_mode = int(os.getenv("MEGENGINE_INPLACE_UPDATE", "0"))
  53. if inplace_mode:
  54. _neg_lr = tensor(-lr)
  55. c1 = tensor([1.0])
  56. for param in param_group["params"]:
  57. if param.grad is None:
  58. continue
  59. grad = param.grad
  60. if weight_decay != 0.0:
  61. grad = grad + param * _weight_decay
  62. if inplace_mode:
  63. if momentum:
  64. v = self._state[param]["momentum_buffer"]
  65. _inplace_add_(v, grad, alpha=_momentum, beta=c1)
  66. _inplace_add_(param, v, alpha=c1, beta=_neg_lr)
  67. else:
  68. _inplace_add_(param, grad, alpha=c1, beta=_neg_lr)
  69. continue
  70. if momentum:
  71. v = self._state[param]["momentum_buffer"]
  72. # v = v * _momentum + grad
  73. v *= _momentum
  74. v += grad
  75. param -= _lr * v
  76. else:
  77. param -= _lr * grad

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。