You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-') and can be up to 35 characters long.

grad_manager.py 3.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. from collections import defaultdict
  2. from contextlib import contextmanager
  3. from ..core.autodiff.grad import Grad
  4. from ..distributed.util import Future
  5. from ..tensor import tensor
# The GradManager whose backward() is currently executing, or None when no
# backward pass is in flight.  Saved and restored by GradManager.backward().
backwarding_grad_manager = None


def get_backwarding_grad_manager():
    """Return the GradManager currently running backward(), or None."""
    return backwarding_grad_manager
  9. class GradManager:
  10. def __init__(self):
  11. self._call_back_dict = defaultdict(list)
  12. self._param_dict = dict()
  13. self._recording = False
  14. self._grad = None
  15. self._after_backward_callback = []
  16. self._gradients = dict()
  17. def register(self, params, callbacks=[]):
  18. for p in params:
  19. self._param_dict[id(p)] = p
  20. for cb in callbacks:
  21. self._call_back_dict[id(p)].append(cb)
  22. def register_after_backward_callback(self, callback):
  23. self._after_backward_callback.append(callback)
  24. def backward(self, ys, dys=None):
  25. global backwarding_grad_manager
  26. cache = backwarding_grad_manager
  27. backwarding_grad_manager = self
  28. if not self._recording:
  29. raise RuntimeError(
  30. "no computation history. "
  31. "did you forget record() or "
  32. "call a method that clears the history?"
  33. )
  34. assert self._grad is not None
  35. if not isinstance(ys, (tuple, list)):
  36. ys = [ys]
  37. if dys is None:
  38. dys = [tensor(1.0) for y in ys]
  39. if not isinstance(dys, (tuple, list)):
  40. dys = [dys]
  41. try:
  42. self._grad(ys, dys)
  43. for callback in self._after_backward_callback:
  44. callback()
  45. for p, grad in self._gradients.items():
  46. if isinstance(grad, Future):
  47. grad = grad.get()
  48. param = self._param_dict[p]
  49. if getattr(param, "grad", None) is None:
  50. param.grad = grad
  51. else:
  52. param.grad += grad
  53. finally:
  54. self._grad = None
  55. self._gradients = dict()
  56. backwarding_grad_manager = cache
  57. def record(self):
  58. @contextmanager
  59. def recorder():
  60. grad = Grad()
  61. if self._recording:
  62. raise RuntimeError("already recording!")
  63. try:
  64. self._recording = True
  65. self._grad = grad
  66. for param_id in self._param_dict.keys():
  67. param_wrapper = self._param_dict[param_id]
  68. callbacks = self._call_back_dict[param_id]
  69. def callback(
  70. param, grad, callbacks=callbacks, p=param_wrapper, gm=self
  71. ):
  72. ret = grad
  73. for cb in callbacks:
  74. ret = cb(param, ret)
  75. gm._gradients[id(p)] = ret
  76. grad.wrt(param_wrapper, callback=callback)
  77. with grad:
  78. yield
  79. finally:
  80. self._recording = False
  81. self._grad = None
  82. self._gradients = dict()
  83. return recorder()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台