
grad_scaler.py 6.7 kB

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable, List, Union

import numpy as np

from ..autodiff import GradManager
from ..functional import full_like
from ..functional.math import _check_non_finite
from ..tensor import Tensor


class GradScaler:
    r"""A helper class that performs grad scaling to prevent data overflow in
    :class:`~.autocast` mode.

    Args:
        init_scale: Initial scale factor.
        growth_factor: Factor that the scale is multiplied by in the actual
            :meth:`update` stage. If growth_factor is 0, scale_factor will not update.
        backoff_factor: Factor that the scale is multiplied by when encountering
            overflow grads.
        growth_interval: The interval between two scale update stages.

    Example:

        .. code-block::

            gm = GradManager()
            opt = ...
            scaler = GradScaler()

            gm.attach(model.parameters())

            @autocast()
            def train_step(image, label):
                with gm:
                    logits = model(image)
                    loss = F.nn.cross_entropy(logits, label)
                    scaler.backward(gm, loss)
                opt.step().clear_grad()
                return loss

        For more flexible usage, ``scaler.backward`` can be split into three lines:

        .. code-block::

            @autocast()
            def train_step(image, label):
                with gm:
                    logits = model(image)
                    loss = F.nn.cross_entropy(logits, label)
                    gm.backward(loss, dy=megengine.tensor(scaler.scale_factor))
                scaler.unscale(gm.attached_tensors())
                scaler.update()
                opt.step().clear_grad()
                return loss

        This is useful when grads need to be accumulated over multiple batches.
    """

    def __init__(
        self,
        init_scale: float = 2.0 ** 4,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
    ):
        self.scale_factor = float(init_scale)
        self.growth_factor = float(growth_factor)
        self.backoff_factor = float(backoff_factor)
        self.growth_interval = growth_interval
        self._growth_tracker = 0
        self._found_non_finite = False

    def backward(
        self,
        gm: GradManager,
        y: Union[Tensor, List[Tensor]] = None,
        dy: Union[Tensor, List[Tensor]] = None,
        *,
        unscale_grad: bool = True,
        update_scale: bool = "if_unscale_grad"
    ):
        r"""A wrapper of GradManager's :meth:`~.GradManager.backward`, used to scale
        ``y``'s grad and unscale parameters' grads.

        Args:
            gm: The GradManager to be wrapped.
            y: Same as GradManager backward's ``y``.
            dy: Same as GradManager backward's ``dy``. Will be multiplied
                by ``scale_factor``.
            unscale_grad: Whether to do :meth:`unscale` at the same time. Could be
                ``False`` if grads need to be accumulated.
            update_scale: Whether to do :meth:`update` after unscaling. Will be
                ignored if ``unscale_grad`` is ``False``.
        """
        # These checks should be consistent with GradManager's
        if y is None:
            ys = []
        elif isinstance(y, (tuple, list)):
            ys = y
        else:
            ys = [y]
        if dy is None:
            dys = [full_like(y, self.scale_factor) for y in ys]
        elif isinstance(dy, (tuple, list)):
            dys = [dy_ * self.scale_factor for dy_ in dy]
        else:
            dys = [dy * self.scale_factor]
        gm.backward(y=ys, dy=dys)
        if unscale_grad:
            self.unscale(gm.attached_tensors())
            if update_scale:
                self.update()

    def unscale(self, grad_tensors: Iterable[Tensor]):
        r"""Unscale all ``grad_tensors``'s grads.

        Args:
            grad_tensors: Tensors whose grads need unscaling. Should be all tensors
                that are affected by the ``target`` tensor in GradManager's backward.
        """
        if self.growth_interval == 0:
            # use float64 for better precision
            inv_scale = Tensor(1.0 / self.scale_factor)
            for tensor in grad_tensors:
                if tensor is None or getattr(tensor, "grad", None) is None:
                    continue
                tensor.grad *= inv_scale
            return self
        # to support tracing, _check_gradients should be applied to every grad.
        if self._check_gradients(
            [x.grad for x in grad_tensors], 1.0 / self.scale_factor
        ):
            self._found_non_finite = True
            # drop this step's grads when a non-finite value is found
            for tensor in grad_tensors:
                if tensor is None or getattr(tensor, "grad", None) is None:
                    continue
                tensor.grad = None
        return self

    def _check_gradients(self, grad, scale):
        return _check_non_finite(grad, scale)

    def update(self, new_scale: float = None):
        r"""Update the scale factor according to whether an overflow grad was encountered.
        If ``new_scale`` is provided, the internal update mechanism will be ignored.
        """
        if self.growth_interval == 0:
            return
        if new_scale is not None:
            self.scale_factor = float(new_scale)
        else:
            if self._found_non_finite:
                # back off when grads overflowed in this interval
                self.scale_factor *= self.backoff_factor
                self._growth_tracker = 0
            else:
                self._growth_tracker += 1
                if self._growth_tracker >= self.growth_interval:
                    # grow the scale after growth_interval consecutive finite steps
                    self.scale_factor *= self.growth_factor
                    self._growth_tracker = 0
        self._found_non_finite = False

    def state_dict(self):
        return {
            "scale_factor": self.scale_factor,
            "growth_factor": self.growth_factor,
            "backoff_factor": self.backoff_factor,
            "growth_interval": self.growth_interval,
            "_growth_tracker": self._growth_tracker,
        }

    def load_state_dict(self, state):
        self.scale_factor = state["scale_factor"]
        self.growth_factor = state["growth_factor"]
        self.backoff_factor = state["backoff_factor"]
        self.growth_interval = state["growth_interval"]
        self._growth_tracker = state["_growth_tracker"]
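Below is a minimal sketch of how this class can be combined with grad accumulation over multiple batches, following the "flexible usage" pattern from the class docstring. It assumes a typical MegEngine training setup in which ``model``, ``opt`` (an optimizer), and a data iterable ``loader`` already exist; those names and the ``accum_steps`` constant are illustrative, and the import paths assume ``GradScaler`` and ``autocast`` are exposed from ``megengine.amp``.

import megengine as mge
import megengine.functional as F
from megengine.amp import GradScaler, autocast
from megengine.autodiff import GradManager

# Assumed to be defined elsewhere: model, opt (optimizer), loader (data iterable).
gm = GradManager()
gm.attach(model.parameters())
scaler = GradScaler()
accum_steps = 4  # number of batches to accumulate before an optimizer step

@autocast()
def train_step(image, label, do_step):
    with gm:
        logits = model(image)
        loss = F.nn.cross_entropy(logits, label)
        # scale the incoming grad so small low-precision grads do not underflow
        gm.backward(loss, dy=mge.tensor(scaler.scale_factor))
    if do_step:
        # unscale once after the last accumulated batch, then update the scale factor
        scaler.unscale(gm.attached_tensors())
        scaler.update()
        opt.step().clear_grad()
    return loss

for i, (image, label) in enumerate(loader):
    train_step(image, label, do_step=(i + 1) % accum_steps == 0)

Because every micro-batch's grad is scaled by the same ``scale_factor``, the accumulated grad is also scaled by that factor, so a single ``unscale`` call before the optimizer step recovers the correct sum.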