# grad_scaler.py
from typing import Iterable, List, Union

import numpy as np

from ..autodiff import GradManager
from ..functional import full_like
from ..functional.math import _check_non_finite
from ..tensor import Tensor


class GradScaler:
    r"""A helper class that performs gradient scaling to prevent data overflow in
    :class:`~.autocast` mode.

    Args:
        init_scale: Initial scale factor.
        growth_factor: Factor that the scale is multiplied by in the
            :meth:`update` stage when no overflow grad has been seen for
            ``growth_interval`` consecutive update steps.
        backoff_factor: Factor that the scale is multiplied by when an overflow
            (non-finite) grad is encountered.
        growth_interval: The interval between two scale update stages. If 0, the
            scale factor is never updated.

    Example:

        .. code-block::

            gm = GradManager()
            opt = ...
            scaler = GradScaler()

            gm.attach(model.parameters())

            @autocast()
            def train_step(image, label):
                with gm:
                    logits = model(image)
                    loss = F.nn.cross_entropy(logits, label)
                    scaler.backward(gm, loss)
                opt.step().clear_grad()
                return loss

        For more flexible usage, ``scaler.backward`` can be split into three lines:

        .. code-block::

            @autocast()
            def train_step(image, label):
                with gm:
                    logits = model(image)
                    loss = F.nn.cross_entropy(logits, label)
                    gm.backward(loss, dy=megengine.tensor(scaler.scale_factor))
                scaler.unscale(gm.attached_tensors())
                scaler.update()
                opt.step().clear_grad()
                return loss

        This is useful when gradients need to be accumulated over multiple batches.
    """

    def __init__(
        self,
        init_scale: float = 2.0 ** 4,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
    ):
        self.scale_factor = float(init_scale)
        self.growth_factor = float(growth_factor)
        self.backoff_factor = float(backoff_factor)
        self.growth_interval = growth_interval

        # number of consecutive update() calls without an overflow grad
        self._growth_tracker = 0
        # set by unscale() when a non-finite grad is detected
        self._found_non_finite = False

    def backward(
        self,
        gm: GradManager,
        y: Union[Tensor, List[Tensor]] = None,
        dy: Union[Tensor, List[Tensor]] = None,
        *,
        unscale_grad: bool = True,
        update_scale: bool = "if_unscale_grad"
    ):
        r"""A wrapper of GradManager's :meth:`~.GradManager.backward` that scales
        ``y``'s grad and unscales the parameters' grads.

        Args:
            gm: The GradManager to be wrapped.
            y: Same as GradManager backward's ``y``.
            dy: Same as GradManager backward's ``dy``. Will be multiplied
                by ``scale_factor``.
            unscale_grad: Whether to call :meth:`unscale` at the same time. Can be
                ``False`` if gradients need to be accumulated across steps.
            update_scale: Whether to call :meth:`update` after unscaling. Ignored
                if ``unscale_grad`` is ``False``.
        """
        # These checks should be consistent with GradManager's
        if y is None:
            ys = []
        elif isinstance(y, (tuple, list)):
            ys = y
        else:
            ys = [y]
        if dy is None:
            dys = [full_like(y, self.scale_factor) for y in ys]
        elif isinstance(dy, (tuple, list)):
            dys = [dy_ * self.scale_factor for dy_ in dy]
        else:
            dys = [dy * self.scale_factor]

        gm.backward(y=ys, dy=dys)

        if unscale_grad:
            self.unscale(gm.attached_tensors())
            if update_scale:
                self.update()

    def unscale(self, grad_tensors: Iterable[Tensor]):
        r"""Unscale the grads of all ``grad_tensors``.

        Args:
            grad_tensors: Tensors whose grads need to be unscaled. Should be all
                tensors that are affected by the ``target`` tensor in GradManager's
                backward.
        """
        if self.growth_interval == 0:
            # static scale: compute 1 / scale in float64 for better precision
            inv_scale = Tensor(1.0 / self.scale_factor)
            for tensor in grad_tensors:
                if tensor is None or getattr(tensor, "grad", None) is None:
                    continue
                tensor.grad *= inv_scale
            return self
        # to support tracing, _check_gradients should be applied to every grad.
        if self._check_gradients(
            [x.grad for x in grad_tensors], 1.0 / self.scale_factor
        ):
            self._found_non_finite = True
            # non-finite grads found: drop this step's grads; update() will
            # back off the scale factor
            for tensor in grad_tensors:
                if tensor is None or getattr(tensor, "grad", None) is None:
                    continue
                tensor.grad = None
        return self

    def _check_gradients(self, grads, scale):
        if len(grads) == 0:
            return False
        rst = _check_non_finite(grads, scale)
        rst = rst.numpy()
        return rst

    def update(self, new_scale: float = None):
        r"""Update the scale factor according to whether an overflow grad was
        encountered. If ``new_scale`` is provided, the internal update mechanism
        is bypassed.
        """
        if self.growth_interval == 0:
            return

        if new_scale is not None:
            self.scale_factor = float(new_scale)
        else:
            if self._found_non_finite:
                self.scale_factor *= self.backoff_factor
                self._growth_tracker = 0
            else:
                self._growth_tracker += 1
                if self._growth_tracker >= self.growth_interval:
                    self.scale_factor *= self.growth_factor
                    self._growth_tracker = 0

        self._found_non_finite = False

    def state_dict(self):
        return {
            "scale_factor": self.scale_factor,
            "growth_factor": self.growth_factor,
            "backoff_factor": self.backoff_factor,
            "growth_interval": self.growth_interval,
            "_growth_tracker": self._growth_tracker,
        }

    def load_state_dict(self, state):
        self.scale_factor = state["scale_factor"]
        self.growth_factor = state["growth_factor"]
        self.backoff_factor = state["backoff_factor"]
        self.growth_interval = state["growth_interval"]
        self._growth_tracker = state["_growth_tracker"]
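

# ---------------------------------------------------------------------------
# Usage sketch (not part of the module above): one mixed-precision training
# step with GradScaler, followed by checkpointing its state via state_dict /
# load_state_dict. The toy model, optimizer, and random data are illustrative
# assumptions only; the step structure mirrors the example in the class
# docstring and is written as user code importing the public API.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import numpy as np

    import megengine as mge
    import megengine.functional as F
    import megengine.module as M
    from megengine.amp import GradScaler, autocast
    from megengine.autodiff import GradManager
    from megengine.optimizer import SGD

    model = M.Linear(8, 2)  # toy model, for illustration only
    opt = SGD(model.parameters(), lr=0.1)
    gm = GradManager()
    gm.attach(model.parameters())
    scaler = GradScaler()

    @autocast()
    def train_step(image, label):
        with gm:
            logits = model(image)
            loss = F.nn.cross_entropy(logits, label)
            # scale the loss grad, run backward, then unscale parameter grads
            scaler.backward(gm, loss)
        opt.step().clear_grad()
        return loss

    image = mge.tensor(np.random.randn(4, 8).astype("float32"))
    label = mge.tensor(np.random.randint(0, 2, size=(4,)).astype("int32"))
    train_step(image, label)

    # the scaler state can be saved and restored alongside model/optimizer state
    checkpoint = scaler.state_dict()
    restored = GradScaler()
    restored.load_state_dict(checkpoint)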