
grad_manager.py 6.4 kB

from collections import defaultdict
from contextlib import contextmanager
from typing import Callable

from ..core.autodiff.grad import Grad
from ..logger import get_logger
from ..tensor import Tensor
from ..utils.future import Future

logger = get_logger(__name__)

backwarding_grad_manager = None


def get_backwarding_grad_manager():
    return backwarding_grad_manager

class GradManager:
    r"""GradManager manages auto differentiation and all resources required to perform it.

    Our auto differentiation framework requires that the user explicitly indicates when
    the forward operations start and when all resources should be released. A typical usage
    of GradManager is as follows:

    .. code-block::

        gm = GradManager()
        gm.attach(model.parameters())

        with gm:
            # forward operations
            ...
            # backward gradients
            gm.backward(loss)

    You can also use the ``record()`` and ``release()`` methods instead of the ``with`` context:

    .. code-block::

        gm = GradManager()
        gm.attach(model.parameters())

        gm.record()

        # forward operations
        ...
        # backward gradients
        gm.backward(loss)

        gm.release()

    Typically, in data parallel training, we would like to average the gradients across
    processes, and users will get the averaged gradients if an "AllReduce"
    callback is registered as follows:

    .. code-block::

        import megengine.distributed as dist

        gm = GradManager()
        gm.attach(model.parameters(), callbacks=dist.make_allreduce_cb("MEAN"))
    """

    def __init__(self):
        self._call_back_dict = defaultdict(list)
        self._param_dict = dict()
        self._recording = False
        self._grad = None
        self._after_backward_callback = []
        self._gradients = dict()

    def attach(self, params: list, callbacks=None):
        r"""Registers parameters that gradients should be calculated with respect to.

        Callback functions should have a signature like this:

        .. code-block::

            def cb(param: Tensor, grad: Tensor) -> Tensor:
                # do something
                return grad

        :param params: parameters to be registered
        :param callbacks: list of callback functions
        """
        if callbacks is None:
            callbacks = []
        if isinstance(callbacks, Callable):
            callbacks = [callbacks]
        if isinstance(params, Tensor):
            params = [params]
        for p in params:
            self._param_dict[id(p)] = p
            for cb in callbacks:
                self._call_back_dict[id(p)].append(cb)
        if self._grad is not None:
            # already recording: start tracking the newly attached parameters as well
            for p in params:
                self._record_param(id(p))
        return self

    def detach(self, params: list):
        r"""Removes specific registered parameters and their callback functions.

        :param params: registered parameters
        """
        if isinstance(params, Tensor):
            params = [params]
        for idx, param in enumerate(params):
            if id(param) in self._param_dict:
                self._param_dict.pop(id(param))
                # use a default so parameters attached without callbacks can be detached safely
                self._call_back_dict.pop(id(param), None)
            else:
                logger.warning("param with index {} is not attached.".format(idx))

    def clear_grad(self):
        r"""
        For advanced usage: sets the grad attribute to None for all registered parameters.

        It can be more convenient when there is more than one Optimizer.
        """
        for param in self._param_dict.values():
            param.grad = None

    def _register_after_backward_callback(self, callback):
        self._after_backward_callback.append(callback)
        return self

    def backward(self, ys, dys=None):
        r"""Performs back-propagation and computes gradients.

        :param ys: outputs of forward operators, e.g., the loss tensor
        :param dys: derivatives of ys
        """
        from ..functional import ones_like

        global backwarding_grad_manager

        cache = backwarding_grad_manager
        backwarding_grad_manager = self
        if not self._recording:
            raise RuntimeError(
                "no computation history. "
                "did you forget record() or "
                "call a method that clears the history?"
            )
        assert self._grad is not None
        if not isinstance(ys, (tuple, list)):
            ys = [ys]
        if dys is None:
            # default derivative is d(y)/d(y) = 1 for every output
            dys = [ones_like(y) for y in ys]
        if not isinstance(dys, (tuple, list)):
            dys = [dys]
        try:
            self._grad(ys, dys)
            for callback in self._after_backward_callback:
                callback()
            for p, grad in self._gradients.items():
                if isinstance(grad, Future):
                    # asynchronous gradients (e.g. from distributed callbacks) are resolved here
                    grad = grad.get()
                param = self._param_dict[p]
                if param.grad is None:
                    param.grad = grad
                else:
                    # accumulate with gradients left from previous backward passes
                    param.grad += grad
        finally:
            self.release()
            backwarding_grad_manager = cache

    def record(self):
        r"""Starts recording forward operations.
        """
        if self._recording:
            raise RuntimeError("already recording")
        grad = Grad()
        self._recording = True
        self._grad = grad
        for param_id in self._param_dict.keys():
            self._record_param(param_id)
        grad.__enter__()

    def _record_param(self, param_id):
        param_wrapper = self._param_dict[param_id]
        callbacks = self._call_back_dict[param_id]

        def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
            ret = grad
            # run user callbacks (e.g. AllReduce in data parallel) before storing the gradient
            for cb in callbacks:
                ret = cb(param, ret)
            gm._gradients[id(p)] = ret

        # NOTE: wrt() overrides the previous callback when called several times
        self._grad.wrt(param_wrapper, callback=callback)

    def release(self):
        r"""Stops recording and releases resources for gradient calculation.
        """
        if self._grad is not None:
            self._grad.__exit__(None, None, None)
            self._grad = None
        self._recording = False
        self._gradients = dict()

    def __enter__(self):
        self.record()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.release()
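For reference, here is a minimal end-to-end sketch of how GradManager is typically combined with an optimizer in a training step. The toy ``M.Linear`` model, the random data, and the squared-error loss below are illustrative placeholders, not part of this module:

    import numpy as np

    import megengine as mge
    import megengine.module as M
    import megengine.optimizer as optim
    from megengine.autodiff import GradManager

    # a toy single-layer model; any megengine.module.Module works the same way
    model = M.Linear(4, 2)

    gm = GradManager()
    gm.attach(model.parameters())               # register trainable parameters
    opt = optim.SGD(model.parameters(), lr=0.01)

    data = mge.tensor(np.random.randn(8, 4).astype("float32"))
    target = mge.tensor(np.random.randn(8, 2).astype("float32"))

    with gm:                                    # record() on enter, release() on exit
        pred = model(data)                      # forward operations are recorded
        loss = ((pred - target) ** 2).mean()    # simple squared-error loss
        gm.backward(loss)                       # fills .grad on attached parameters

    opt.step()                                  # apply the gradients
    opt.clear_grad()                            # reset grads before the next iteration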

The MegEngine package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. If you want to run GPU programs, make sure the machine has a GPU device and the driver is properly installed. If you would like to try deep learning development on a cloud GPU platform, you are welcome to visit the MegStudio platform.
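A quick way to check from Python whether the GPU path is actually usable is sketched below; it assumes the ``is_cuda_available``, ``set_default_device`` and ``get_default_device`` helpers exported at the top level of megengine, and ``"gpu0"`` simply names the first GPU:

    import megengine as mge

    if mge.is_cuda_available():
        mge.set_default_device("gpu0")   # place new tensors and ops on the first GPU
    else:
        mge.set_default_device("cpu0")   # no usable GPU/driver: fall back to CPU
    print(mge.get_default_device())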