You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

autocast.py 2.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. import functools
  2. from ..core import _config
  3. from ..core.tensor import amp
  4. class autocast:
  5. r"""A class to control autocast mode for amp as a context manager or a decorator.
  6. Args:
  7. enabled: Whether autocast mode is enabled.
  8. low_prec_dtype: Set amp autocast mode's lower precision dtype. It will change
  9. the target dtype in tensor casting for better speed and memory. Default: float16.
  10. high_prec_dtype: Set amp autocast mode's higher precision dtype. It will
  11. change the target dtype in tensor casting for better precision. Default: float32.
  12. Examples:
  13. .. code-block::
  14. # used as decorator
  15. @autocast()
  16. def train_step(image, label):
  17. with gm:
  18. logits = model(image)
  19. loss = F.nn.cross_entropy(logits, label)
  20. gm.backward(loss)
  21. opt.step().clear_grad()
  22. return loss
  23. # used as context manager
  24. def train_step(image, label):
  25. with autocast():
  26. with gm:
  27. logits = model(image)
  28. loss = F.nn.cross_entropy(logits, label)
  29. gm.backward(loss)
  30. opt.step().clear_grad()
  31. return loss
  32. """
  33. def __init__(
  34. self,
  35. enabled: bool = True,
  36. low_prec_dtype: str = "float16",
  37. high_prec_dtype: str = "float32",
  38. ):
  39. self.enabled = enabled
  40. self.high_prec_dtype = high_prec_dtype
  41. self.low_prec_dtype = low_prec_dtype
  42. self._origin_enabled = None
  43. self._origin_high = None
  44. self._origin_low = None
  45. self._origin_configs = None
  46. def __enter__(self):
  47. if self.enabled:
  48. self._origin_enabled = amp._enabled
  49. self._origin_high = amp._get_amp_high_prec_dtype()
  50. self._origin_low = amp._get_amp_low_prec_dtype()
  51. amp._enabled = self.enabled
  52. amp._set_amp_dtype_autocast(self.enabled)
  53. amp._set_amp_high_prec_dtype(self.high_prec_dtype)
  54. amp._set_amp_low_prec_dtype(self.low_prec_dtype)
  55. self._origin_configs = _config._reset_execution_config(
  56. compute_mode="float32"
  57. )
  58. def __exit__(self, *args):
  59. if self.enabled:
  60. amp._enabled = self._origin_enabled
  61. amp._set_amp_dtype_autocast(self._origin_enabled)
  62. amp._set_amp_high_prec_dtype(self._origin_high)
  63. amp._set_amp_low_prec_dtype(self._origin_low)
  64. _config._reset_execution_config(*self._origin_configs)
  65. def __call__(self, func):
  66. @functools.wraps(func)
  67. def wrapper(*args, **kwargs):
  68. with self:
  69. return func(*args, **kwargs)
  70. return wrapper