
grad_manager.py 5.5 kB

from collections import defaultdict
from contextlib import contextmanager
from typing import Callable

from ..core.autodiff.grad import Grad
from ..tensor import tensor
from ..utils.future import Future

backwarding_grad_manager = None


def get_backwarding_grad_manager():
    return backwarding_grad_manager


class GradManager:
    r"""GradManager manages auto differentiation and all resources required to perform it.

    Our auto differentiation framework requires that the user explicitly indicates when
    the forward operations start and when all resources should be released. A typical usage of
    GradManager is as follows:

    .. code-block::

        gm = GradManager()
        gm.attach(model.parameters())

        with gm:
            # forward operations
            ...
            # backward gradients
            gm.backward(loss)

    You can also use the ``record()`` and ``release()`` methods instead of a ``with`` context:

    .. code-block::

        gm = GradManager()
        gm.attach(model.parameters())

        gm.record()

        # forward operations
        ...
        # backward gradients
        gm.backward(loss)

        gm.release()

    Typically, in data parallel, we would like to average the gradients across
    processes. Users will finally get the averaged gradients if an "AllReduce"
    callback is registered as follows:

    .. code-block::

        import megengine.distributed as dist

        gm = GradManager()
        gm.attach(model.parameters(), callback=dist.make_allreduce_cb("MEAN"))
    """

    def __init__(self):
        self._call_back_dict = defaultdict(list)
        self._param_dict = dict()
        self._recording = False
        self._grad = None
        self._after_backward_callback = []
        self._gradients = dict()

    def attach(self, params, callbacks=None):
        r"""Registers parameters that gradients should be calculated with respect to.

        Callback functions should have a signature like this:

        .. code-block::

            def cb(param: Tensor, grad: Tensor) -> Tensor:
                # do something
                return grad

        :param params: registered parameters
        :param callbacks: list of callback functions
        """
        if callbacks is None:
            callbacks = []
        if isinstance(callbacks, Callable):
            callbacks = [callbacks]
        for p in params:
            self._param_dict[id(p)] = p
            for cb in callbacks:
                self._call_back_dict[id(p)].append(cb)
        return self

    def _register_after_backward_callback(self, callback):
        self._after_backward_callback.append(callback)
        return self

    def backward(self, ys, dys=None):
        r"""Performs back-propagation and computes gradients.

        :param ys: outputs of forward operators, e.g., the loss tensor
        :param dys: derivatives of ys
        """
        from ..functional import ones_like

        global backwarding_grad_manager

        cache = backwarding_grad_manager
        backwarding_grad_manager = self
        if not self._recording:
            raise RuntimeError(
                "no computation history. "
                "did you forget record() or "
                "call a method that clears the history?"
            )
        assert self._grad is not None
        if not isinstance(ys, (tuple, list)):
            ys = [ys]
        if dys is None:
            dys = [ones_like(y) for y in ys]
        if not isinstance(dys, (tuple, list)):
            dys = [dys]
        try:
            self._grad(ys, dys)
            for callback in self._after_backward_callback:
                callback()
            for p, grad in self._gradients.items():
                if isinstance(grad, Future):
                    grad = grad.get()
                param = self._param_dict[p]
                if param.grad is None:
                    param.grad = grad
                else:
                    param.grad += grad
        finally:
            self._stop_record()
            backwarding_grad_manager = cache

    def record(self):
        r"""Starts recording forward operations.
        """
        if self._recording:
            raise RuntimeError("already recording")
        grad = Grad()
        self._recording = True
        self._grad = grad
        for param_id in self._param_dict.keys():
            param_wrapper = self._param_dict[param_id]
            callbacks = self._call_back_dict[param_id]

            def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
                ret = grad
                for cb in callbacks:
                    ret = cb(param, ret)
                gm._gradients[id(p)] = ret

            grad.wrt(param_wrapper, callback=callback)
        grad.__enter__()

    def release(self):
        r"""Stops recording and releases resources for gradients calculation.
        """
        if not self._recording:
            raise RuntimeError("not recording")
        self._stop_record()

    def _stop_record(self):
        if self._grad is not None:
            self._grad.__exit__(None, None, None)
        self._recording = False
        self._grad = None
        self._gradients = dict()

    def __enter__(self):
        self.record()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._stop_record()

The MegEngine installation package bundles the CUDA environment needed to run code on GPUs, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU and that its driver is installed. If you would like to try deep learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
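A quick way to confirm that the bundled CUDA runtime actually sees a usable GPU is to query it from Python. The sketch below assumes megengine.is_cuda_available() as exposed by recent MegEngine releases; if your version does not export it, consult the device utilities in your installed release.

    import megengine

    # is_cuda_available() is assumed here (present in recent MegEngine releases);
    # it reports whether a CUDA-capable GPU with a working driver was detected.
    if megengine.is_cuda_available():
        print("GPU detected: CUDA kernels can be used")
    else:
        print("no usable GPU found: running on CPU only")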