
grad_manager.py 6.5 kB

from collections import defaultdict
from contextlib import contextmanager
from typing import Callable

from ..core.autodiff.grad import Grad
from ..logger import get_logger
from ..tensor import Tensor
from ..utils.future import Future

logger = get_logger(__name__)

backwarding_grad_manager = None


def get_backwarding_grad_manager():
    return backwarding_grad_manager


class GradManager:
    r"""
    GradManager manages auto differentiation and all resources required to perform it.

    Our auto differentiation framework requires that the user explicitly indicates when
    the forward operations start and when all resources should be released. A typical usage of
    GradManager is as follows:

    .. code-block::

        gm = GradManager()
        gm.attach(model.parameters())

        with gm:
            # forward operations
            ...
            # backward gradients
            gm.backward(loss)

    You can also use the ``record()`` and ``release()`` methods instead of the ``with`` context:

    .. code-block::

        gm = GradManager()
        gm.attach(model.parameters())

        gm.record()

        # forward operations
        ...
        # backward gradients
        gm.backward(loss)

        gm.release()

    Typically, in data parallel, we would like to average the gradients across
    processes. Users will finally get the averaged gradients if an "AllReduce"
    callback is registered as follows:

    .. code-block::

        import megengine.distributed as dist

        gm = GradManager()
        gm.attach(model.parameters(), callback=dist.make_allreduce_cb("MEAN"))
    """

    def __init__(self):
        self._call_back_dict = defaultdict(list)
        self._param_dict = dict()
        self._recording = False
        self._grad = None
        self._after_backward_callback = []
        self._gradients = dict()

    def attach(self, params: list, callbacks=None):
        r"""
        Registers parameters that gradients should be calculated with respect to.
        Callback functions should have a signature like this:

        .. code-block::

            def cb(param: Tensor, grad: Tensor) -> Tensor:
                # do something
                return grad

        :param params: parameters to be registered
        :param callbacks: list of callback functions
        """
        if callbacks is None:
            callbacks = []
        if isinstance(callbacks, Callable):
            callbacks = [callbacks]
        if isinstance(params, Tensor):
            params = [params]
        for p in params:
            self._param_dict[id(p)] = p
            for cb in callbacks:
                self._call_back_dict[id(p)].append(cb)
        if self._grad is not None:
            for p in params:
                self._record_param(id(p))
        return self

    def detach(self, params: list):
        r"""
        Removes specific registered parameters and callback functions.

        :param params: registered parameters
        """
        if isinstance(params, Tensor):
            params = [params]
        for idx, param in enumerate(params):
            if id(param) in self._param_dict:
                self._param_dict.pop(id(param))
                self._call_back_dict.pop(id(param))
            else:
                logger.warning("param with index {} is not attached.".format(idx))

    def clear_grad(self):
        r"""
        For advanced usage: sets the grad attribute to None for registered parameters.
        This can be convenient when there is more than one Optimizer.
        """
        for param in self._param_dict.values():
            param.grad = None

    def _register_after_backward_callback(self, callback):
        self._after_backward_callback.append(callback)
        return self

    def backward(self, ys, dys=None):
        r"""
        Performs back-propagation and computes gradients.

        :param ys: outputs of forward operators, e.g., the loss tensor
        :param dys: derivatives of ys
        """
        from ..functional import ones_like

        global backwarding_grad_manager

        cache = backwarding_grad_manager
        backwarding_grad_manager = self
        if not self._recording:
            raise RuntimeError(
                "no computation history. "
                "did you forget record() or "
                "call a method that clears the history?"
            )
        assert self._grad is not None
        if not isinstance(ys, (tuple, list)):
            ys = [ys]
        if dys is None:
            dys = [ones_like(y) for y in ys]
        if not isinstance(dys, (tuple, list)):
            dys = [dys]
        try:
            self._grad(ys, dys)
            for callback in self._after_backward_callback:
                callback()
            for p, grad in self._gradients.items():
                if isinstance(grad, Future):
                    grad = grad.get()
                param = self._param_dict[p]
                if param.grad is None:
                    param.grad = grad
                else:
                    param.grad += grad
        finally:
            self.release()
            backwarding_grad_manager = cache

    def record(self):
        r"""
        Starts recording forward operations.
        """
        if self._recording:
            raise RuntimeError("already recording")
        grad = Grad()
        self._recording = True
        self._grad = grad
        for param_id in self._param_dict.keys():
            self._record_param(param_id)
        grad.__enter__()

    def _record_param(self, param_id):
        param_wrapper = self._param_dict[param_id]
        callbacks = self._call_back_dict[param_id]

        def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
            ret = grad
            for cb in callbacks:
                ret = cb(param, ret)
            gm._gradients[id(p)] = ret

        # NOTE: overrides the previous wrt callback when called several times
        self._grad.wrt(param_wrapper, callback=callback)

    def release(self):
        r"""
        Stops recording and releases resources for gradients calculation.
        """
        if self._grad is not None:
            self._grad.__exit__(None, None, None)
        self._grad = None
        self._recording = False
        self._gradients = dict()

    def __enter__(self):
        self.record()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.release()
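
For context, the following is a minimal end-to-end sketch of the workflow described in the class docstring above. It assumes the class is exposed as ``megengine.autodiff.GradManager`` and that ``megengine.module.Linear``, ``megengine.optimizer.SGD``, and ``Optimizer.clear_grad()`` are available as in recent MegEngine releases; the layer shape, data, and learning rate are illustrative only and not part of this file.

    import numpy as np
    import megengine as mge
    import megengine.module as M
    import megengine.optimizer as optim
    from megengine.autodiff import GradManager  # assumed public import path

    net = M.Linear(4, 1)                       # toy model, illustrative only
    gm = GradManager().attach(net.parameters())
    opt = optim.SGD(net.parameters(), lr=0.01)

    x = mge.tensor(np.random.randn(8, 4).astype("float32"))
    y = mge.tensor(np.random.randn(8, 1).astype("float32"))

    for _ in range(10):
        with gm:                               # record() on enter, release() on exit
            pred = net(x)
            loss = ((pred - y) ** 2).mean()    # simple squared-error loss
            gm.backward(loss)                  # fills p.grad for attached parameters
        opt.step()                             # applies the accumulated gradients
        opt.clear_grad()                       # resets p.grad before the next step

Note that ``backward`` already calls ``release`` in its ``finally`` block; the extra ``release`` triggered by ``__exit__`` is harmless because ``release`` only invokes ``Grad.__exit__`` when ``self._grad`` is not ``None``.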

The MegEngine installation package already bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU and GPU build to choose from. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you would like to try deep learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
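
A runtime check along these lines can be used to pick a device; this is a minimal sketch assuming the top-level helpers ``megengine.is_cuda_available`` and ``megengine.set_default_device`` are present in the installed version:

    import megengine as mge

    # Prefer a GPU when CUDA is usable, otherwise fall back to CPU.
    if mge.is_cuda_available():
        mge.set_default_device("gpu0")  # "gpu0" is an illustrative device name
    else:
        mge.set_default_device("cpu0")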