
grad_scaler.py 6.5 kB

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable, List, Union

import numpy as np

from ..autodiff import GradManager
from ..functional import full_like
from ..functional.math import _has_inf
from ..tensor import Tensor


class GradScaler:
    r"""A helper class that performs grad scaling to prevent data overflow in
    :class:`~.autocast` mode.

    :param init_scale: Initial scale factor.
    :param growth_factor: Factor that the scale is multiplied by in the actual
        :meth:`update` stage. If growth_factor is 0, scale_factor will not be updated.
    :param backoff_factor: Factor that the scale is multiplied by when an overflow
        grad is encountered.
    :param growth_interval: The interval between two scale update stages.

    Example::

        gm = GradManager()
        opt = ...
        scaler = GradScaler()

        gm.attach(model.parameters())

        @autocast()
        def train_step(image, label):
            with gm:
                logits = model(image)
                loss = F.nn.cross_entropy(logits, label)
                scaler.backward(gm, loss)
            opt.step().clear_grad()
            return loss

    For more flexible usage, ``scaler.backward`` can be split into three lines:

    .. code-block::

        @autocast()
        def train_step(image, label):
            with gm:
                logits = model(image)
                loss = F.nn.cross_entropy(logits, label)
                gm.backward(loss, dy=megengine.tensor(scaler.scale_factor))
            scaler.unscale(gm.attached_tensors())
            scaler.update()
            opt.step().clear_grad()
            return loss

    This is useful when grads need to be accumulated over multiple batches.
    """

    def __init__(
        self,
        init_scale: float = 2.0 ** 4,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
    ):
        self.scale_factor = float(init_scale)
        self.growth_factor = float(growth_factor)
        self.backoff_factor = float(backoff_factor)
        self.growth_interval = growth_interval
        self._growth_tracker = 0
        self._found_inf = False

    def backward(
        self,
        gm: GradManager,
        y: Union[Tensor, List[Tensor]] = None,
        dy: Union[Tensor, List[Tensor]] = None,
        *,
        unscale_grad: bool = True,
        update_scale: bool = "if_unscale_grad"
    ):
        r"""A wrapper of GradManager's :meth:`~.GradManager.backward`, used to scale
        ``y``'s grad and unscale parameters' grads.

        :param gm: The GradManager to be wrapped.
        :param y: Same as GradManager backward's ``y``.
        :param dy: Same as GradManager backward's ``dy``. Will be multiplied
            by ``scale_factor``.
        :param unscale_grad: Whether to call :meth:`unscale` at the same time. Can be
            ``False`` if grads need to be accumulated.
        :param update_scale: Same as :meth:`unscale`'s ``update``. Will be ignored
            if ``unscale_grad`` is ``False``.
        """
        # These checks should be consistent with GradManager's
        if y is None:
            ys = []
        elif isinstance(y, (tuple, list)):
            ys = y
        else:
            ys = [y]
        if dy is None:
            dys = [full_like(y, self.scale_factor) for y in ys]
        elif isinstance(dy, (tuple, list)):
            dys = [dy_ * self.scale_factor for dy_ in dy]
        else:
            dys = [dy * self.scale_factor]
        gm.backward(y=ys, dy=dys)
        if unscale_grad:
            self.unscale(gm.attached_tensors())
            if update_scale:
                self.update()

    def unscale(self, grad_tensors: Iterable[Tensor]):
        r"""Unscale the grads of all ``grad_tensors``.

        :param grad_tensors: Tensors whose grads need to be unscaled. Should be all
            tensors that are affected by the ``target`` tensor in GradManager's backward.
        """
        # use float64 for better precision
        inv_scale = Tensor(1.0 / self.scale_factor)
        for tensor in grad_tensors:
            if tensor is None or getattr(tensor, "grad", None) is None:
                continue
            # to support tracing, _check_gradients should be applied to every grad.
            if self._check_gradients(tensor.grad):
                self._found_inf = True
            tensor.grad *= inv_scale
        if self._found_inf:
            for tensor in grad_tensors:
                if tensor is None or getattr(tensor, "grad", None) is None:
                    continue
                tensor.grad = None
        return self

    def _check_gradients(self, grad):
        if self.growth_interval == 0:
            return False
        return _has_inf(grad)

    def update(self, new_scale: float = None):
        r"""Update the scale factor according to whether an overflow grad was encountered.
        If ``new_scale`` is provided, the internal update mechanism is ignored."""
        if self.growth_interval == 0:
            return
        if new_scale is not None:
            self.scale_factor = float(new_scale)
        else:
            if self._found_inf:
                self.scale_factor *= self.backoff_factor
                self._growth_tracker = 0
            else:
                self._growth_tracker += 1
                if self._growth_tracker >= self.growth_interval:
                    self.scale_factor *= self.growth_factor
                    self._growth_tracker = 0
        self._found_inf = False

    def state_dict(self):
        return {
            "scale_factor": self.scale_factor,
            "growth_factor": self.growth_factor,
            "backoff_factor": self.backoff_factor,
            "growth_interval": self.growth_interval,
            "_growth_tracker": self._growth_tracker,
        }

    def load_state_dict(self, state):
        self.scale_factor = state["scale_factor"]
        self.growth_factor = state["growth_factor"]
        self.backoff_factor = state["backoff_factor"]
        self.growth_interval = state["growth_interval"]
        self._growth_tracker = state["_growth_tracker"]
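
The docstring above notes that splitting ``scaler.backward`` into its three steps is useful when grads are accumulated over several batches. The following is a minimal sketch of that pattern, not part of this file: the ``megengine.amp`` import path is assumed to expose ``GradScaler`` and ``autocast`` as in recent MegEngine releases, and ``model``, ``opt``, and ``accum_steps`` are placeholder names for an existing module, optimizer, and accumulation length::

    import megengine
    import megengine.functional as F
    from megengine.amp import GradScaler, autocast   # import path assumed
    from megengine.autodiff import GradManager

    gm = GradManager()
    gm.attach(model.parameters())                    # `model` and `opt` assumed to exist
    scaler = GradScaler()
    accum_steps = 4                                  # number of micro-batches to accumulate

    @autocast()
    def train_step(image, label, step):
        with gm:
            logits = model(image)
            loss = F.nn.cross_entropy(logits, label) / accum_steps
            # scale the incoming grad so low-precision grads do not underflow
            gm.backward(loss, dy=megengine.tensor(scaler.scale_factor))
        if (step + 1) % accum_steps == 0:
            # unscale, update the scale factor and step only on the last micro-batch
            scaler.unscale(gm.attached_tensors())
            scaler.update()
            opt.step().clear_grad()
        return loss

Grads stay scaled while they accumulate across the intermediate micro-batches; ``unscale`` divides them by ``scale_factor`` once before the optimizer step, which is the multi-batch accumulation case the docstring refers to.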

The MegEngine installation package already bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build to choose between. To run GPU programs, make sure the machine has a GPU and that the driver is properly installed. If you would like to try deep-learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
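
As a quick sanity check that the GPU environment is usable after installation, something like the following can be run; it is a sketch assuming the ``megengine.is_cuda_available`` and ``megengine.set_default_device`` helpers behave as in recent MegEngine releases::

    import megengine

    # pick the GPU when the bundled CUDA runtime can see a device, otherwise fall back to CPU
    if megengine.is_cuda_available():
        megengine.set_default_device("gpu0")
    else:
        megengine.set_default_device("cpu0")

    x = megengine.tensor([1.0, 2.0, 3.0])
    print(x.device, (x * 2).numpy())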