grad_manager.py 6.3 kB

from collections import defaultdict
from contextlib import contextmanager
from typing import Callable

from ..core.autodiff.grad import Grad
from ..logger import get_logger
from ..tensor import Tensor
from ..utils.future import Future

logger = get_logger(__name__)

backwarding_grad_manager = None


def get_backwarding_grad_manager():
    return backwarding_grad_manager


class GradManager:
    r"""
    GradManager manages auto differentiation and all resources required to perform it.

    Our auto differentiation framework requires that the user explicitly indicates when
    the forward operations start and when all resources should be released. A typical
    usage of GradManager is as follows:

    .. code-block::

        gm = GradManager()
        gm.attach(model.parameters())

        with gm:
            # forward operations
            ...
            # backward gradients
            gm.backward(loss)

    You can also use the ``record()`` and ``release()`` methods instead of the ``with`` context:

    .. code-block::

        gm = GradManager()
        gm.attach(model.parameters())

        gm.record()

        # forward operations
        ...
        # backward gradients
        gm.backward(loss)

        gm.release()

    Typically, in data parallel, we would like to average the gradients across
    processes. Users will finally get the averaged gradients if an "AllReduce"
    callback is registered as follows:

    .. code-block::

        import megengine.distributed as dist

        gm = GradManager()
        gm.attach(model.parameters(), callbacks=dist.make_allreduce_cb("MEAN"))
    """

    def __init__(self):
        self._call_back_dict = defaultdict(list)
        self._param_dict = dict()
        self._recording = False
        self._grad = None
        self._after_backward_callback = []
        self._gradients = dict()

    def attach(self, params: list, callbacks=None):
        r"""
        Registers parameters that gradients should be calculated with respect to.
        Callback functions should have a signature like this:

        .. code-block::

            def cb(param: Tensor, grad: Tensor) -> Tensor:
                # do something
                return grad

        :param params: parameters to be registered
        :param callbacks: list of callback functions
        """
        if callbacks is None:
            callbacks = []
        if isinstance(callbacks, Callable):
            callbacks = [callbacks]
        if isinstance(params, Tensor):
            params = [params]
        for p in params:
            self._param_dict[id(p)] = p
            for cb in callbacks:
                self._call_back_dict[id(p)].append(cb)
        if self._grad is not None:
            for p in params:
                self._record_param(id(p))
        return self

    def detach(self, params: list):
        r"""
        Removes specific registered parameters and their callback functions.

        :param params: registered parameters
        """
        if isinstance(params, Tensor):
            params = [params]
        for idx, param in enumerate(params):
            if id(param) in self._param_dict:
                self._param_dict.pop(id(param))
                self._call_back_dict.pop(id(param))
            else:
                logger.warning("param with index {} is not attached.".format(idx))

    def _register_after_backward_callback(self, callback):
        self._after_backward_callback.append(callback)
        return self

    def backward(self, ys=None, dys=None):
        r"""
        Performs back-propagation and computes gradients.

        :param ys: outputs of forward operators, e.g., the loss tensor
        :param dys: derivatives of ys
        """
        from ..functional import ones_like

        global backwarding_grad_manager

        cache = backwarding_grad_manager
        backwarding_grad_manager = self
        if not self._recording:
            raise RuntimeError(
                "no computation history. "
                "did you forget record() or "
                "call a method that clears the history?"
            )
        assert self._grad is not None
        if ys is None:
            ys = []
        if not isinstance(ys, (tuple, list)):
            ys = [ys]
        if dys is None:
            dys = [ones_like(y) for y in ys]
        if not isinstance(dys, (tuple, list)):
            dys = [dys]
        try:
            self._grad(ys, dys)
            for callback in self._after_backward_callback:
                callback()
            for p, grad in self._gradients.items():
                if isinstance(grad, Future):
                    grad = grad.get()
                param = self._param_dict[p]
                if param.grad is None:
                    param.grad = grad
                else:
                    param.grad += grad
        finally:
            self.release()
            backwarding_grad_manager = cache

    def record(self):
        r"""
        Starts recording forward operations.
        """
        if self._recording:
            raise RuntimeError("already recording")
        grad = Grad()
        self._recording = True
        self._grad = grad
        for param_id in self._param_dict.keys():
            self._record_param(param_id)
        grad.__enter__()

    def _record_param(self, param_id):
        param_wrapper = self._param_dict[param_id]
        callbacks = self._call_back_dict[param_id]

        def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
            ret = grad
            for cb in callbacks:
                ret = cb(param, ret)
            gm._gradients[id(p)] = ret

        # NOTE: Grad.wrt overrides the previous callback when called several times
        self._grad.wrt(param_wrapper, callback=callback)

    def release(self):
        r"""
        Stops recording and releases resources for gradient calculation.
        """
        if self._grad is not None:
            self._grad.__exit__(None, None, None)
        self._grad = None
        self._recording = False
        self._gradients = dict()

    def __enter__(self):
        self.record()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.release()
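For reference, below is a minimal end-to-end sketch of driving GradManager from user code. It assumes the public import path megengine.autodiff.GradManager and the megengine.functional.loss.square_loss helper; the gradient-scaling callback is purely illustrative, not part of this module.

import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager  # assumed public alias of the class above

# Two parameters of a toy linear model.
w = mge.Parameter([1.0])
b = mge.Parameter([0.0])

gm = GradManager()
# attach() registers the parameters; the callback scales each gradient by 0.5
# before it is accumulated into param.grad (illustrative only).
gm.attach([w, b], callbacks=lambda param, grad: grad * 0.5)

x = mge.tensor([2.0])
y = mge.tensor([5.0])

with gm:  # record() on enter, release() on exit
    pred = w * x + b  # forward operations are recorded
    loss = F.loss.square_loss(pred, y)
    gm.backward(loss)  # populates w.grad and b.grad

print(w.grad, b.grad)

Because backward() calls release() in its finally block and release() is idempotent, it is safe to call backward() inside the with block, exactly as the class docstring shows.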

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. If you want to run GPU programs, make sure the machine has a GPU device and that the driver is installed. If you would like to try deep-learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
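As a quick sanity check before running GPU code, something like the following sketch can be used; it assumes megengine.is_cuda_available() and the "gpu0"/"cpu0" device strings behave as in recent MegEngine releases.

import megengine as mge

# Report whether the CUDA environment bundled with the wheel can see a GPU.
if mge.is_cuda_available():
    x = mge.tensor([1.0, 2.0, 3.0], device="gpu0")  # place data on the first GPU
    print("GPU available, tensor on", x.device)
else:
    x = mge.tensor([1.0, 2.0, 3.0], device="cpu0")  # fall back to CPU
    print("no usable GPU found, tensor on", x.device)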