You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

autocast.py 2.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import functools
  9. from ..core.tensor import amp
  10. class autocast:
  11. r"""A class to control autocast mode for amp as a context manager or a decorator.
  12. Args:
  13. enabled: Whether autocast mode is enabled.
  14. low_prec_dtype: Set amp autocast mode's lower precision dtype. It will change
  15. the target dtype in tensor casting for better speed and memory. Default: float16.
  16. high_prec_dtype: Set amp autocast mode's higher precision dtype. It will
  17. change the target dtype in tensor casting for better precision. Default: float32.
  18. Examples:
  19. .. code-block::
  20. # used as decorator
  21. @autocast()
  22. def train_step(image, label):
  23. with gm:
  24. logits = model(image)
  25. loss = F.nn.cross_entropy(logits, label)
  26. gm.backward(loss)
  27. opt.step().clear_grad()
  28. return loss
  29. # used as context manager
  30. def train_step(image, label):
  31. with autocast():
  32. with gm:
  33. logits = model(image)
  34. loss = F.nn.cross_entropy(logits, label)
  35. gm.backward(loss)
  36. opt.step().clear_grad()
  37. return loss
  38. """
  39. def __init__(
  40. self,
  41. enabled: bool = True,
  42. low_prec_dtype: str = "float16",
  43. high_prec_dtype: str = "float32",
  44. ):
  45. self.enabled = enabled
  46. self.high_prec_dtype = high_prec_dtype
  47. self.low_prec_dtype = low_prec_dtype
  48. self._origin_enabled = None
  49. self._origin_high = None
  50. self._origin_low = None
  51. def __enter__(self):
  52. self._origin_enabled = amp._enabled
  53. self._origin_high = amp._get_amp_high_prec_dtype()
  54. self._origin_low = amp._get_amp_low_prec_dtype()
  55. amp._enabled = self.enabled
  56. amp._set_amp_dtype_autocast(self.enabled)
  57. amp._set_amp_high_prec_dtype(self.high_prec_dtype)
  58. amp._set_amp_low_prec_dtype(self.low_prec_dtype)
  59. def __exit__(self, *args):
  60. amp._enabled = self._origin_enabled
  61. amp._set_amp_dtype_autocast(self._origin_enabled)
  62. amp._set_amp_high_prec_dtype(self._origin_high)
  63. amp._set_amp_low_prec_dtype(self._origin_low)
  64. def __call__(self, func):
  65. @functools.wraps(func)
  66. def wrapper(*args, **kwargs):
  67. with self:
  68. return func(*args, **kwargs)
  69. return wrapper