You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

grad_manager.py 3.2 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. from collections import defaultdict
  2. from contextlib import contextmanager
  3. from ..core.autodiff.grad import Grad
  4. from ..distributed.util import Future
  5. from ..tensor import tensor
  6. backwarding_grad_manager = None
  7. def get_backwarding_grad_manager():
  8. return backwarding_grad_manager
  9. class GradManager:
  10. def __init__(self):
  11. self._call_back_dict = defaultdict(list)
  12. self._param_dict = dict()
  13. self._recording = False
  14. self._grad = None
  15. self._after_backward_callback = []
  16. self._gradients = dict()
  17. def register(self, params, callbacks=[]):
  18. for p in params:
  19. self._param_dict[id(p)] = p
  20. for cb in callbacks:
  21. self._call_back_dict[id(p)].append(cb)
  22. def register_after_backward_callback(self, callback):
  23. self._after_backward_callback.append(callback)
  24. return self
  25. def backward(self, ys, dys=None):
  26. global backwarding_grad_manager
  27. cache = backwarding_grad_manager
  28. backwarding_grad_manager = self
  29. if not self._recording:
  30. raise RuntimeError(
  31. "no computation history. "
  32. "did you forget record() or "
  33. "call a method that clears the history?"
  34. )
  35. assert self._grad is not None
  36. if not isinstance(ys, (tuple, list)):
  37. ys = [ys]
  38. if dys is None:
  39. dys = [tensor(1.0) for y in ys]
  40. if not isinstance(dys, (tuple, list)):
  41. dys = [dys]
  42. try:
  43. self._grad(ys, dys)
  44. for callback in self._after_backward_callback:
  45. callback()
  46. for p, grad in self._gradients.items():
  47. if isinstance(grad, Future):
  48. grad = grad.get()
  49. param = self._param_dict[p]
  50. if getattr(param, "grad", None) is None:
  51. param.grad = grad
  52. else:
  53. param.grad += grad
  54. finally:
  55. self._grad = None
  56. self._gradients = dict()
  57. backwarding_grad_manager = cache
  58. def record(self):
  59. @contextmanager
  60. def recorder():
  61. grad = Grad()
  62. if self._recording:
  63. raise RuntimeError("already recording!")
  64. try:
  65. self._recording = True
  66. self._grad = grad
  67. for param_id in self._param_dict.keys():
  68. param_wrapper = self._param_dict[param_id]
  69. callbacks = self._call_back_dict[param_id]
  70. def callback(
  71. param, grad, callbacks=callbacks, p=param_wrapper, gm=self
  72. ):
  73. ret = grad
  74. for cb in callbacks:
  75. ret = cb(param, ret)
  76. gm._gradients[id(p)] = ret
  77. grad.wrt(param_wrapper, callback=callback)
  78. with grad:
  79. yield
  80. finally:
  81. self._recording = False
  82. self._grad = None
  83. self._gradients = dict()
  84. return recorder()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台