You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; can include dashes ('-'); and can be up to 35 characters long.

autocast.py 2.7 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import functools
  9. from ..core.tensor import amp
  10. class autocast:
  11. r"""A class to control autocast mode for amp as a context manager or a decorator.
  12. Args:
  13. enabled: Whether autocast mode is enabled.
  14. low_prec_dtype: Set amp autocast mode's lower precision dtype. It will change
  15. the target dtype in tensor casting for better speed and memory. Default: float16.
  16. high_prec_dtype: Set amp autocast mode's higher precision dtype. It will
  17. change the target dtype in tensor casting for better precision. Default: float32.
  18. Examples:
  19. .. code-block::
  20. # used as decorator
  21. @autocast()
  22. def train_step(image, label):
  23. with gm:
  24. logits = model(image)
  25. loss = F.nn.cross_entropy(logits, label)
  26. gm.backward(loss)
  27. opt.step().clear_grad()
  28. return loss
  29. # used as context manager
  30. def train_step(image, label):
  31. with autocast():
  32. with gm:
  33. logits = model(image)
  34. loss = F.nn.cross_entropy(logits, label)
  35. gm.backward(loss)
  36. opt.step().clear_grad()
  37. return loss
  38. """
  39. def __init__(
  40. self,
  41. enabled: bool = True,
  42. low_prec_dtype: str = "float16",
  43. high_prec_dtype: str = "float32",
  44. ):
  45. self.enabled = enabled
  46. self.high_prec_dtype = high_prec_dtype
  47. self.low_prec_dtype = low_prec_dtype
  48. self._origin_enabled = None
  49. self._origin_high = None
  50. self._origin_low = None
  51. def __enter__(self):
  52. self._origin_enabled, amp._enabled = amp._enabled, self.enabled
  53. self._origin_high = amp._high_prec_dtype
  54. amp._high_prec_dtype = self.high_prec_dtype
  55. self._origin_low = amp._low_prec_dtype
  56. amp._low_prec_dtype = self.low_prec_dtype
  57. def __exit__(self, *args):
  58. amp._enabled = self._origin_enabled
  59. amp._high_prec_dtype = self._origin_high
  60. amp._low_prec_dtype = self._origin_low
  61. def __call__(self, func):
  62. @functools.wraps(func)
  63. def wrapper(*args, **kwargs):
  64. with self:
  65. return func(*args, **kwargs)
  66. return wrapper

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台