
grad_manager.py

from collections import defaultdict
from contextlib import contextmanager

from ..core.autodiff.grad import Grad
from ..tensor import tensor
from ..utils.future import Future

# The GradManager currently running its backward pass, if any.
backwarding_grad_manager = None


def get_backwarding_grad_manager():
    return backwarding_grad_manager


class GradManager:
    def __init__(self):
        self._call_back_dict = defaultdict(list)  # param id -> gradient callbacks
        self._param_dict = dict()  # param id -> parameter tensor
        self._recording = False
        self._grad = None
        self._after_backward_callback = []
        self._gradients = dict()  # param id -> computed gradient (or Future)

    def register(self, params, callbacks=[]):
        # Track parameters whose gradients should be computed, along with
        # optional callbacks that transform each gradient before it is stored.
        for p in params:
            self._param_dict[id(p)] = p
            for cb in callbacks:
                self._call_back_dict[id(p)].append(cb)
        return self

    def _register_after_backward_callback(self, callback):
        self._after_backward_callback.append(callback)
        return self

    def backward(self, ys, dys=None):
        global backwarding_grad_manager
        cache = backwarding_grad_manager
        backwarding_grad_manager = self
        if not self._recording:
            raise RuntimeError(
                "no computation history. "
                "did you forget record() or "
                "call a method that clears the history?"
            )
        assert self._grad is not None
        if not isinstance(ys, (tuple, list)):
            ys = [ys]
        if dys is None:
            # Default to dy/dy = 1, broadcast to each output's shape.
            dys = [tensor(1.0).broadcast(y.shape) for y in ys]
        if not isinstance(dys, (tuple, list)):
            dys = [dys]
        try:
            self._grad(ys, dys)
            for callback in self._after_backward_callback:
                callback()
            # Accumulate the computed gradients into each parameter's .grad.
            for p, grad in self._gradients.items():
                if isinstance(grad, Future):
                    grad = grad.get()
                param = self._param_dict[p]
                if getattr(param, "grad", None) is None:
                    param.grad = grad
                else:
                    param.grad += grad
        finally:
            # Always clear the recorded history, even if backward failed.
            self._grad = None
            self._gradients = dict()
            backwarding_grad_manager = cache

    def record(self):
        @contextmanager
        def recorder():
            grad = Grad()
            if self._recording:
                raise RuntimeError("already recording!")
            try:
                self._recording = True
                self._grad = grad
                for param_id in self._param_dict.keys():
                    param_wrapper = self._param_dict[param_id]
                    callbacks = self._call_back_dict[param_id]

                    # Default arguments bind the loop variables, so each
                    # closure sees its own parameter and callback list.
                    def callback(
                        param, grad, callbacks=callbacks, p=param_wrapper, gm=self
                    ):
                        ret = grad
                        for cb in callbacks:
                            ret = cb(param, ret)
                        gm._gradients[id(p)] = ret

                    grad.wrt(param_wrapper, callback=callback)
                with grad:
                    yield
            finally:
                self._recording = False
                self._grad = None
                self._gradients = dict()

        return recorder()
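
For orientation, here is a minimal usage sketch of this class. The model, data, and loss function are hypothetical stand-ins; only the GradManager calls come from the file above. Note that backward() must run inside the record() context, since exiting the context clears the recorded history:

    gm = GradManager()
    gm.register(model.parameters())  # `model` is a hypothetical module

    with gm.record():
        loss = loss_fn(model(data), label)  # hypothetical forward pass
        gm.backward(loss)  # needs the live history; record() clears it on exit

    # Each registered parameter p now carries its gradient in p.grad.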

The MegEngine installation package bundles the CUDA environment needed to run code on a GPU, so there is no separate CPU or GPU build. To run GPU programs, make sure the machine has a GPU and that the driver is installed. If you want to try deep learning development on a cloud GPU computing platform, you are welcome to visit the MegStudio platform.
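
A quick runtime check is the easiest way to confirm that the bundled CUDA environment can actually see a GPU. This is a minimal sketch assuming megengine.is_cuda_available(), which recent MegEngine releases provide:

    import megengine

    # Reports whether MegEngine's bundled CUDA runtime found a usable GPU.
    if megengine.is_cuda_available():
        print("GPU detected; programs can run on CUDA.")
    else:
        print("No usable GPU; check the hardware and driver installation.")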