
grad_scaler.py 6.7 kB

# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Iterable, List, Union

import numpy as np

from ..autodiff import GradManager
from ..functional import full_like
from ..functional.math import _check_non_finite
from ..tensor import Tensor


class GradScaler:
    r"""A helper class that performs grad scaling to prevent data overflow in
    :class:`~.autocast` mode.

    Args:
        init_scale: Initial scale factor.
        growth_factor: Factor that the scale is multiplied by in the actual
            :meth:`update` stage. If growth_factor is 0, scale_factor will not update.
        backoff_factor: Factor that the scale is multiplied by when encountering
            overflow grads.
        growth_interval: The interval between two scale update stages.

    Example:

        .. code-block::

            gm = GradManager()
            opt = ...
            scaler = GradScaler()

            gm.attach(model.parameters())

            @autocast()
            def train_step(image, label):
                with gm:
                    logits = model(image)
                    loss = F.nn.cross_entropy(logits, label)
                    scaler.backward(gm, loss)
                opt.step().clear_grad()
                return loss

        If more flexible usage is needed, ``scaler.backward`` can be split into
        three lines:

        .. code-block::

            @autocast()
            def train_step(image, label):
                with gm:
                    logits = model(image)
                    loss = F.nn.cross_entropy(logits, label)
                    gm.backward(loss, dy=megengine.tensor(scaler.scale_factor))
                scaler.unscale(gm.attached_tensors())
                scaler.update()
                opt.step().clear_grad()
                return loss

        This is useful when grads need to be accumulated over multiple batches.
    """

    def __init__(
        self,
        init_scale: float = 2.0 ** 4,
        growth_factor: float = 2.0,
        backoff_factor: float = 0.5,
        growth_interval: int = 2000,
    ):
        self.scale_factor = float(init_scale)
        self.growth_factor = float(growth_factor)
        self.backoff_factor = float(backoff_factor)
        self.growth_interval = growth_interval
        self._growth_tracker = 0
        self._found_non_finite = False

    def backward(
        self,
        gm: GradManager,
        y: Union[Tensor, List[Tensor]] = None,
        dy: Union[Tensor, List[Tensor]] = None,
        *,
        unscale_grad: bool = True,
        update_scale: bool = "if_unscale_grad"
    ):
        r"""A wrapper of GradManager's :meth:`~.GradManager.backward`, used to scale
        ``y``'s grad and unscale parameters' grads.

        Args:
            gm: The GradManager to be wrapped.
            y: Same as GradManager backward's ``y``.
            dy: Same as GradManager backward's ``dy``. Will be multiplied
                by ``scale_factor``.
            unscale_grad: Whether to do :meth:`unscale` at the same time. Could be
                ``False`` if grads need to be accumulated.
            update_scale: Whether to call :meth:`update` after unscaling. Will be
                ignored if ``unscale_grad`` is ``False``.
        """
        # These checks should be consistent with GradManager's
        if y is None:
            ys = []
        elif isinstance(y, (tuple, list)):
            ys = y
        else:
            ys = [y]
        if dy is None:
            # default dy is the scale factor itself, broadcast to each y's shape
            dys = [full_like(y, self.scale_factor) for y in ys]
        elif isinstance(dy, (tuple, list)):
            dys = [dy_ * self.scale_factor for dy_ in dy]
        else:
            dys = [dy * self.scale_factor]
        gm.backward(y=ys, dy=dys)
        if unscale_grad:
            self.unscale(gm.attached_tensors())
            if update_scale:
                self.update()

    def unscale(self, grad_tensors: Iterable[Tensor]):
        r"""Unscale the grads of all ``grad_tensors``.

        Args:
            grad_tensors: Tensors whose grads need to be unscaled. Should be all
                tensors that are affected by the ``target`` tensor in GradManager's
                backward.
        """
        # to support tracing, _check_gradients should be applied to every grad.
        if self._check_gradients([x.grad for x in grad_tensors]):
            self._found_non_finite = True
        if self._found_non_finite:
            # drop all grads when any non-finite value was found
            for tensor in grad_tensors:
                if tensor is None or getattr(tensor, "grad", None) is None:
                    continue
                tensor.grad = None
        else:
            # compute 1.0 / scale_factor in Python float (float64) for better precision
            inv_scale = Tensor(1.0 / self.scale_factor)
            for tensor in grad_tensors:
                if tensor is None or getattr(tensor, "grad", None) is None:
                    continue
                tensor.grad *= inv_scale
        return self

    def _check_gradients(self, grad):
        if self.growth_interval == 0:
            return False
        return _check_non_finite(grad)

    def update(self, new_scale: float = None):
        r"""Update the scale factor according to whether overflow grads were
        encountered. If ``new_scale`` is provided, the internal update mechanism
        is bypassed.
        """
        if self.growth_interval == 0:
            return
        if new_scale is not None:
            self.scale_factor = float(new_scale)
        else:
            if self._found_non_finite:
                self.scale_factor *= self.backoff_factor
                self._growth_tracker = 0
            else:
                self._growth_tracker += 1
                if self._growth_tracker >= self.growth_interval:
                    self.scale_factor *= self.growth_factor
                    self._growth_tracker = 0
        self._found_non_finite = False

    def state_dict(self):
        return {
            "scale_factor": self.scale_factor,
            "growth_factor": self.growth_factor,
            "backoff_factor": self.backoff_factor,
            "growth_interval": self.growth_interval,
            "_growth_tracker": self._growth_tracker,
        }

    def load_state_dict(self, state):
        self.scale_factor = state["scale_factor"]
        self.growth_factor = state["growth_factor"]
        self.backoff_factor = state["backoff_factor"]
        self.growth_interval = state["growth_interval"]
        self._growth_tracker = state["_growth_tracker"]
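
For reference, a minimal end-to-end sketch of how this class is typically used follows. This is a sketch only: it assumes the class is exported as ``megengine.amp.GradScaler``, and the one-layer ``Net`` model, the SGD hyperparameters, and the random input data are hypothetical, made up for illustration.

    import numpy as np

    import megengine as mge
    import megengine.functional as F
    import megengine.module as M
    from megengine.amp import GradScaler, autocast  # assumed export path
    from megengine.autodiff import GradManager
    from megengine.optimizer import SGD

    # hypothetical toy model, just for illustration
    class Net(M.Module):
        def __init__(self):
            super().__init__()
            self.fc = M.Linear(32, 10)

        def forward(self, x):
            return self.fc(x)

    model = Net()
    gm = GradManager()
    gm.attach(model.parameters())
    opt = SGD(model.parameters(), lr=0.01)
    scaler = GradScaler()

    @autocast()
    def train_step(image, label):
        with gm:
            logits = model(image)
            loss = F.nn.cross_entropy(logits, label)
            # scales dy by scale_factor, then unscales grads and updates the scale
            scaler.backward(gm, loss)
        opt.step().clear_grad()
        return loss

    image = mge.tensor(np.random.randn(8, 32).astype("float32"))
    label = mge.tensor(np.random.randint(0, 10, size=(8,)).astype("int32"))
    loss = train_step(image, label)

    # the scaler state is plain floats/ints, so it can be checkpointed
    # alongside the model and optimizer state
    ckpt = {
        "model": model.state_dict(),
        "opt": opt.state_dict(),
        "scaler": scaler.state_dict(),
    }
    scaler.load_state_dict(ckpt["scaler"])

Note that ``state_dict`` deliberately omits ``scale_factor``-unrelated runtime state such as ``_found_non_finite``, which is reset every :meth:`update` call anyway.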

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. If you want to run GPU programs, make sure the machine has a GPU installed along with its driver. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
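
To double-check that the bundled CUDA runtime can actually see your GPU before launching training, a check of the following form should work (a sketch, assuming ``megengine.is_cuda_available`` and ``megengine.get_default_device`` are present in your installed version):

    import megengine as mge

    # report whether a CUDA-capable device is visible to MegEngine
    if mge.is_cuda_available():
        print("GPU available, default device:", mge.get_default_device())
    else:
        print("no GPU detected; MegEngine will run on CPU")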