
adagrad.py 2.9 kB

# -*- coding: utf-8 -*-
from typing import Iterable, Union

import numpy as np

from ..tensor import Parameter, tensor
from .optimizer import Optimizer


class Adagrad(Optimizer):
    r"""Implements Adagrad algorithm.

    It has been proposed in `"Adaptive Subgradient Methods for Online Learning
    and Stochastic Optimization" <http://jmlr.org/papers/v12/duchi11a.html>`_.

    Args:
        params: iterable of parameters to optimize or dicts defining
            parameter groups.
        lr: coefficient that scales delta before it is applied
            to the parameters. Default: 1e-2
        lr_decay: learning rate decay. Default: 0
        eps: term added to the denominator to improve
            numerical stability. Default: 1e-10
        weight_decay: weight decay (L2 penalty). Default: 0
    """

    def __init__(
        self,
        params: Union[Iterable[Parameter], dict],
        lr: float = 1e-2,
        lr_decay: float = 0.0,
        eps: float = 1e-10,
        weight_decay: float = 0.0,
    ):
        assert lr >= 0.0, "Invalid learning rate: {}".format(lr)
        assert lr_decay >= 0, "Invalid learning rate decay: {}".format(lr_decay)
        assert eps >= 0.0, "Invalid epsilon value: {}".format(eps)
        assert weight_decay >= 0.0, "Invalid weight_decay value: {}".format(
            weight_decay
        )

        defaults = dict(lr=lr, lr_decay=lr_decay, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)
        self._disable_type_convert = True

    def _create_state(self, param_group):
        for param in param_group["params"]:
            self._add_state(param, "square_avg")
            self._add_state(param, "step", initializer=0.0)

    def _updates(self, param_group):
        lr = param_group["lr"]
        lr_decay = param_group["lr_decay"]
        weight_decay = param_group["weight_decay"]
        eps = param_group["eps"]

        def make_scalar(val):
            return tensor(val, dtype="float32")

        # since `convert_inputs` is disabled for param updates,
        # scalars must be explicitly converted to tensors
        _lr, _lr_decay = map(make_scalar, (lr, lr_decay))
        _weight_decay = make_scalar(weight_decay)
        _eps = make_scalar(eps)

        c1, c2, c05 = map(make_scalar, (1.0, 2.0, 0.5))

        for param in param_group["params"]:
            if param.grad is None:
                continue

            states = self._state[param]
            step = states["step"]
            step += c1

            grad = param.grad
            if weight_decay != 0.0:
                # L2 penalty folded into the gradient
                grad = grad + param * _weight_decay

            # accumulate squared gradients and scale the update by their root
            square_avg = states["square_avg"]
            square_avg += grad ** c2
            delta = grad / (square_avg + _eps) ** c05

            # learning rate decayed by the number of steps taken so far
            clr = _lr / (c1 + (step - c1) * _lr_decay)

            param -= clr * delta
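
Read directly from `_updates` above, each parameter is updated per step t as (note that eps is added inside the square root, matching this code rather than the more common formulation that adds it outside):

    square_avg_t = square_avg_{t-1} + g_t ** 2
    param_t = param_{t-1} - lr / (1 + (t - 1) * lr_decay) * g_t / (square_avg_t + eps) ** 0.5

where g_t is the gradient of param, replaced by g_t + weight_decay * param_{t-1} when weight_decay is non-zero.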
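
A minimal usage sketch follows. It assumes this file sits in a MegEngine-style package; the `megengine` import paths, the `Linear` module, and the `GradManager` training loop below are illustrative assumptions, not taken from this file.

import megengine as mge
from megengine.autodiff import GradManager
from megengine.module import Linear
from megengine.optimizer import Adagrad

net = Linear(4, 1)  # toy model, assumed for illustration
opt = Adagrad(net.parameters(), lr=1e-2, lr_decay=1e-4, weight_decay=0.0)
gm = GradManager().attach(net.parameters())

x = mge.tensor([[0.1, 0.2, 0.3, 0.4]])
y = mge.tensor([[1.0]])

with gm:
    loss = ((net(x) - y) ** 2).mean()  # plain squared error
    gm.backward(loss)                  # populates param.grad for attached parameters
opt.step()        # applies the Adagrad update implemented in _updates
opt.clear_grad()  # resets gradients before the next iteration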