You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

autocast.py 2.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. import functools
  2. from ..core import _config
  3. from ..core.tensor import amp
  4. class autocast:
  5. r"""A class to control autocast mode for amp as a context manager or a decorator.
  6. Args:
  7. enabled: Whether autocast mode is enabled.
  8. low_prec_dtype: Set amp autocast mode's lower precision dtype. It will change
  9. the target dtype in tensor casting for better speed and memory. Default: float16.
  10. high_prec_dtype: Set amp autocast mode's higher precision dtype. It will
  11. change the target dtype in tensor casting for better precision. Default: float32.
  12. Examples:
  13. .. code-block::
  14. # used as decorator
  15. @autocast()
  16. def train_step(image, label):
  17. with gm:
  18. logits = model(image)
  19. loss = F.nn.cross_entropy(logits, label)
  20. gm.backward(loss)
  21. opt.step().clear_grad()
  22. return loss
  23. # used as context manager
  24. def train_step(image, label):
  25. with autocast():
  26. with gm:
  27. logits = model(image)
  28. loss = F.nn.cross_entropy(logits, label)
  29. gm.backward(loss)
  30. opt.step().clear_grad()
  31. return loss
  32. """
  33. def __init__(
  34. self,
  35. enabled: bool = True,
  36. low_prec_dtype: str = "float16",
  37. high_prec_dtype: str = "float32",
  38. ):
  39. self.enabled = enabled
  40. self.high_prec_dtype = high_prec_dtype
  41. self.low_prec_dtype = low_prec_dtype
  42. self._origin_enabled = None
  43. self._origin_high = None
  44. self._origin_low = None
  45. self._origin_configs = None
  46. def __enter__(self):
  47. self._origin_enabled = amp._enabled
  48. amp._enabled = self.enabled
  49. amp._set_amp_dtype_autocast(self.enabled)
  50. if not self.enabled:
  51. return
  52. self._origin_high = amp._get_amp_high_prec_dtype()
  53. self._origin_low = amp._get_amp_low_prec_dtype()
  54. amp._set_amp_high_prec_dtype(self.high_prec_dtype)
  55. amp._set_amp_low_prec_dtype(self.low_prec_dtype)
  56. self._origin_configs = _config._reset_execution_config(compute_mode="float32")
  57. def __exit__(self, *args):
  58. amp._enabled = self._origin_enabled
  59. amp._set_amp_dtype_autocast(self._origin_enabled)
  60. if not self.enabled:
  61. return
  62. amp._set_amp_high_prec_dtype(self._origin_high)
  63. amp._set_amp_low_prec_dtype(self._origin_low)
  64. def __call__(self, func):
  65. @functools.wraps(func)
  66. def wrapper(*args, **kwargs):
  67. if not self.enabled:
  68. return func(*args, **kwargs)
  69. with self:
  70. return func(*args, **kwargs)
  71. return wrapper