
lamb.py 5.5 kB

# Copyright (c) 2020 Ross Wightman
# This file has been modified by Megvii ("Megvii Modifications").
# All Megvii Modifications are Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
"""LAMB optimizer

References: https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lamb.py
"""
import os
from typing import Iterable, Tuple, Union

from megengine.core._imperative_rt.core2 import apply
from megengine.core.ops.builtin import LAMBUpdate

from .. import Parameter, tensor
from ..functional import sum
from ..functional.inplace import _inplace_add_
from .optimizer import Optimizer


class LAMB(Optimizer):
    r"""Implements the LAMB algorithm.

    LAMB is proposed in `"Large Batch Optimization for Deep Learning: Training BERT in 76 minutes"
    <https://arxiv.org/abs/1904.00962>`_.

    Args:
        params: iterable of parameters to optimize or dicts defining parameter groups.
        lr: learning rate.
        betas: coefficients used for computing running averages of the gradient and its square.
            Default: ``(0.9, 0.999)``
        eps: term added to the denominator to improve numerical stability. Default: ``1e-8``
        bias_correction: enables bias correction by ``1 - beta ** step``. Default: ``True``
        weight_decay: weight decay (L2 penalty). Default: ``0.0``
        always_adapt: apply the layer-wise adaptive lr even to parameters with ``0.0``
            weight decay. Default: ``False``
    """

    def __init__(
        self,
        params: Union[Iterable[Parameter], dict],
        lr: float,
        betas: Tuple[float, float] = (0.9, 0.999),
        eps: float = 1e-8,
        bias_correction: bool = True,
        weight_decay: float = 0.0,
        always_adapt: bool = False,
    ):
        if lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(lr=lr, weight_decay=weight_decay, betas=betas, eps=eps)
        super().__init__(params, defaults)
        self.bias_correction = bias_correction
        self.always_adapt = always_adapt
        self._disable_type_convert = True

    def _create_state(self, param_group):
        for param in param_group["params"]:
            self._add_state(param, "exp_avg")
            self._add_state(param, "exp_avg_sq")
            self._add_state(param, "step", initializer=0.0, dtype="float32")

    def _updates(self, param_group):
        lr = param_group["lr"]
        weight_decay = param_group["weight_decay"]
        eps = param_group["eps"]
        beta0, beta1 = param_group["betas"]
        # since `convert_inputs` is disabled for param updates,
        # scalars must be explicitly converted to tensors
        c1 = tensor(1.0)
        for param in param_group["params"]:
            if param.grad is None:
                continue
            grad = param.grad
            states = self._state[param]
            step, exp_avg, exp_avg_sq = (
                states["step"],
                states["exp_avg"],
                states["exp_avg_sq"],
            )
            step += c1
            op = LAMBUpdate(
                beta0,
                beta1,
                int(step),
                lr,
                weight_decay,
                eps,
                self.bias_correction,
                self.always_adapt,
            )
            new_exp_avg, new_exp_avg_sq, new_param = apply(
                op, exp_avg, exp_avg_sq, param, grad
            )
            param._reset(new_param)
            exp_avg._reset(new_exp_avg)
            exp_avg_sq._reset(new_exp_avg_sq)
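

# ---------------------------------------------------------------------------
# Illustrative reference (not part of the original file): a plain NumPy sketch
# of the update that the fused ``LAMBUpdate`` op above is assumed to compute,
# following the LAMB rule from https://arxiv.org/abs/1904.00962. The
# authoritative semantics live in MegEngine's C++ kernel; this helper is not
# used by the optimizer classes in this module.
# ---------------------------------------------------------------------------
def _lamb_update_reference(
    param, grad, exp_avg, exp_avg_sq, step, lr, beta0, beta1, eps,
    weight_decay, bias_correction, always_adapt,
):
    import numpy as np  # local import: documentation aid only

    # running averages of the gradient and its elementwise square
    exp_avg = beta0 * exp_avg + (1.0 - beta0) * grad
    exp_avg_sq = beta1 * exp_avg_sq + (1.0 - beta1) * grad * grad
    if bias_correction:
        m = exp_avg / (1.0 - beta0 ** step)
        v = exp_avg_sq / (1.0 - beta1 ** step)
    else:
        m, v = exp_avg, exp_avg_sq
    # Adam-style direction plus weight decay on the parameter
    update = m / (np.sqrt(v) + eps) + weight_decay * param
    # layer-wise trust ratio: rescale the step by ||w|| / ||update||
    if weight_decay != 0.0 or always_adapt:
        w_norm = np.linalg.norm(param)
        u_norm = np.linalg.norm(update)
        trust_ratio = w_norm / u_norm if w_norm > 0 and u_norm > 0 else 1.0
    else:
        trust_ratio = 1.0
    new_param = param - lr * trust_ratio * update
    return new_param, exp_avg, exp_avg_sq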


class LAMBFp16(LAMB):
    r"""LAMB for ``float16`` parameters: optimizer states and a master copy of
    each parameter are kept in ``float32`` to avoid precision loss.
    """

    def _create_state(self, param_group):
        for param in param_group["params"]:
            self._add_state(param, "exp_avg", dtype="float32")
            self._add_state(param, "exp_avg_sq", dtype="float32")
            self._add_state(param, "step", initializer=0.0, dtype="float32")
            self._state[param]["param_fp32"] = param.astype("float32")

    def _updates(self, param_group):
        lr = param_group["lr"]
        weight_decay = param_group["weight_decay"]
        eps = param_group["eps"]
        beta0, beta1 = param_group["betas"]
        c1 = tensor(1.0)
        for param in param_group["params"]:
            if param.grad is None:
                continue
            grad = param.grad
            states = self._state[param]
            step, exp_avg, exp_avg_sq = (
                states["step"],
                states["exp_avg"],
                states["exp_avg_sq"],
            )
            step += c1
            fp32_param = states["param_fp32"]
            op = LAMBUpdate(
                beta0,
                beta1,
                step,
                lr,
                weight_decay,
                eps,
                self.bias_correction,
                self.always_adapt,
            )
            new_exp_avg, new_exp_avg_sq, new_param = apply(
                op, exp_avg, exp_avg_sq, fp32_param, grad
            )
            # update the fp32 master weight, then write it back to the fp16 parameter
            fp32_param._reset(new_param)
            param._reset(new_param.astype("float16"))
            exp_avg._reset(new_exp_avg)
            exp_avg_sq._reset(new_exp_avg_sq)
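
A minimal usage sketch (not part of the file above), assuming the class is exported as megengine.optimizer.LAMB and trained with MegEngine's GradManager; the model, data, and hyper-parameters are placeholders:

import numpy as np

import megengine as mge
import megengine.functional as F
import megengine.module as M
from megengine.autodiff import GradManager
from megengine.optimizer import LAMB  # assumed export path for the class above

model = M.Linear(8, 2)
gm = GradManager()
gm.attach(model.parameters())
opt = LAMB(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.01)

x = mge.tensor(np.random.randn(4, 8).astype("float32"))
y = mge.tensor(np.random.randn(4, 2).astype("float32"))

for _ in range(3):
    with gm:
        loss = F.mean((model(x) - y) ** 2)
        gm.backward(loss)  # fills param.grad for every attached parameter
    opt.step()       # one fused LAMBUpdate per parameter
    opt.clear_grad()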