
grad_manager.py 5.7 kB

from collections import defaultdict
from contextlib import contextmanager
from typing import Callable

from ..core.autodiff.grad import Grad
from ..logger import get_logger
from ..tensor import Tensor
from ..utils.future import Future

logger = get_logger(__name__)

backwarding_grad_manager = None


def get_backwarding_grad_manager():
    return backwarding_grad_manager
class GradManager:
    r"""
    GradManager manages auto differentiation and all resources required to perform it.

    Our auto differentiation framework requires that the user explicitly indicates when
    the forward operations start and when all resources should be released. A typical
    usage of GradManager is as follows:

    .. code-block::

        gm = GradManager()
        gm.attach(model.parameters())

        with gm:
            # forward operations
            ...
            # backward gradients
            gm.backward(loss)

    You can also use the ``record()`` and ``release()`` methods instead of the ``with`` context:

    .. code-block::

        gm = GradManager()
        gm.attach(model.parameters())

        gm.record()
        # forward operations
        ...
        # backward gradients
        gm.backward(loss)
        gm.release()

    Typically, in data parallel training, we would like to average the gradients across
    processes. Users will then get the averaged gradients if an "AllReduce"
    callback is registered as follows:

    .. code-block::

        import megengine.distributed as dist

        gm = GradManager()
        gm.attach(model.parameters(), callbacks=dist.make_allreduce_cb("MEAN"))
    """
    def __init__(self):
        self._call_back_dict = defaultdict(list)
        self._param_dict = dict()
        self._recording = False
        self._grad = None
        self._after_backward_callback = []
        self._gradients = dict()
    def attach(self, params: list, callbacks=None):
        r"""
        Registers parameters that gradients should be calculated with respect to.

        Callback functions should have a signature like this:

        .. code-block::

            def cb(param: Tensor, grad: Tensor) -> Tensor:
                # do something
                return grad

        :param params: parameters to be registered
        :param callbacks: list of callback functions
        """
        if callbacks is None:
            callbacks = []
        if isinstance(callbacks, Callable):
            callbacks = [callbacks]
        if isinstance(params, Tensor):
            params = [params]
        for p in params:
            self._param_dict[id(p)] = p
            for cb in callbacks:
                self._call_back_dict[id(p)].append(cb)
        if self._grad is not None:
            for p in params:
                self._record_param(id(p))
        return self
    def _register_after_backward_callback(self, callback):
        self._after_backward_callback.append(callback)
        return self
    def backward(self, ys=None, dys=None):
        r"""
        Performs back-propagation and computes gradients.

        :param ys: outputs of forward operators, e.g., the loss tensor
        :param dys: derivatives of ys
        """
        from ..functional import ones_like

        global backwarding_grad_manager

        cache = backwarding_grad_manager
        backwarding_grad_manager = self
        if not self._recording:
            raise RuntimeError(
                "no computation history. "
                "did you forget record() or "
                "call a method that clears the history?"
            )
        assert self._grad is not None
        if ys is None:
            ys = []
        if not isinstance(ys, (tuple, list)):
            ys = [ys]
        if dys is None:
            dys = [ones_like(y) for y in ys]
        if not isinstance(dys, (tuple, list)):
            dys = [dys]
        try:
            self._grad(ys, dys)
            for callback in self._after_backward_callback:
                callback()
            for p, grad in self._gradients.items():
                if isinstance(grad, Future):
                    grad = grad.get()
                param = self._param_dict[p]
                if param.grad is None:
                    param.grad = grad
                else:
                    param.grad += grad
        finally:
            self.release()
            backwarding_grad_manager = cache
    def record(self):
        r"""
        Starts recording forward operations.
        """
        if self._recording:
            raise RuntimeError("already recording")
        grad = Grad()
        self._recording = True
        self._grad = grad
        for param_id in self._param_dict.keys():
            self._record_param(param_id)
        grad.__enter__()
    def _record_param(self, param_id):
        param_wrapper = self._param_dict[param_id]
        callbacks = self._call_back_dict[param_id]

        def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
            ret = grad
            for cb in callbacks:
                ret = cb(param, ret)
            gm._gradients[id(p)] = ret

        # NOTE: wrt() overrides the previously registered callback when called several times
        self._grad.wrt(param_wrapper, callback=callback)
    def release(self):
        r"""
        Stops recording and releases resources required for gradient calculation.
        """
        if self._grad is not None:
            self._grad.__exit__(None, None, None)
        self._grad = None
        self._recording = False
        self._gradients = dict()

    def __enter__(self):
        self.record()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.release()
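For reference, here is a minimal end-to-end sketch of driving GradManager through one training step. It is not part of grad_manager.py; the toy model, optimizer, and loss below (M.Linear, optim.SGD, a squared error reduced with F.mean) are illustrative assumptions and may need adjusting to the installed MegEngine version.

# Usage sketch (assumption: a MegEngine build exposing the modules imported below).
import numpy as np

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
from megengine.autodiff import GradManager

model = M.Linear(4, 1)                        # toy model
opt = optim.SGD(model.parameters(), lr=0.01)

gm = GradManager()
gm.attach(model.parameters())                 # register parameters once, up front

data = mge.tensor(np.random.randn(8, 4).astype("float32"))
label = mge.tensor(np.random.randn(8, 1).astype("float32"))

with gm:                                      # __enter__ calls record()
    pred = model(data)                        # forward operations are recorded
    loss = F.mean((pred - label) ** 2)        # simple squared-error loss
    gm.backward(loss)                         # accumulates into each param.grad, then releases

opt.step()                                    # apply the accumulated gradients
opt.clear_grad()                              # reset param.grad for the next iteration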

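The callbacks argument of attach() receives each (param, grad) pair and must return the gradient that will actually be accumulated, as the docstring above shows. A short sketch of a custom callback follows; halving the gradient is only an illustrative stand-in for real uses such as the distributed AllReduce callback or gradient clipping, and the toy model is the same assumption as in the previous snippet.

# Callback sketch (assumption: same toy M.Linear model as above).
import megengine.module as M
from megengine import Tensor
from megengine.autodiff import GradManager


def halve_grad(param: Tensor, grad: Tensor) -> Tensor:
    # Transform the gradient before GradManager stores it in _gradients and
    # backward() accumulates it into param.grad.
    return grad * 0.5


model = M.Linear(4, 1)
gm = GradManager()
# attach() also accepts a single callable and wraps it into a list internally.
gm.attach(model.parameters(), callbacks=halve_grad)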
The MegEngine installation package bundles the CUDA environment needed to run code on GPU, so there is no separate CPU/GPU build. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.