You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

grad_manager.py 3.2 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103
  1. from collections import defaultdict
  2. from contextlib import contextmanager
  3. from typing import Callable
  4. from ..core.autodiff.grad import Grad
  5. from ..tensor import tensor
  6. from ..utils.future import Future
  7. backwarding_grad_manager = None
  8. def get_backwarding_grad_manager():
  9. return backwarding_grad_manager
  10. class GradManager:
  11. def __init__(self):
  12. self._call_back_dict = defaultdict(list)
  13. self._param_dict = dict()
  14. self._recording = False
  15. self._grad = None
  16. self._after_backward_callback = []
  17. self._gradients = dict()
  18. def register(self, params, callbacks=None):
  19. if callbacks is None:
  20. callbacks = []
  21. if isinstance(callbacks, Callable):
  22. callbacks = [callbacks]
  23. for p in params:
  24. self._param_dict[id(p)] = p
  25. for cb in callbacks:
  26. self._call_back_dict[id(p)].append(cb)
  27. return self
  28. def _register_after_backward_callback(self, callback):
  29. self._after_backward_callback.append(callback)
  30. return self
  31. def backward(self, ys, dys=None):
  32. global backwarding_grad_manager
  33. cache = backwarding_grad_manager
  34. backwarding_grad_manager = self
  35. if not self._recording:
  36. raise RuntimeError(
  37. "no computation history. "
  38. "did you forget record() or "
  39. "call a method that clears the history?"
  40. )
  41. assert self._grad is not None
  42. if not isinstance(ys, (tuple, list)):
  43. ys = [ys]
  44. if dys is None:
  45. dys = [tensor(1.0).broadcast(y.shape) for y in ys]
  46. if not isinstance(dys, (tuple, list)):
  47. dys = [dys]
  48. try:
  49. self._grad(ys, dys)
  50. for callback in self._after_backward_callback:
  51. callback()
  52. for p, grad in self._gradients.items():
  53. if isinstance(grad, Future):
  54. grad = grad.get()
  55. param = self._param_dict[p]
  56. if getattr(param, "grad", None) is None:
  57. param.grad = grad
  58. else:
  59. param.grad += grad
  60. finally:
  61. self._stop_record()
  62. backwarding_grad_manager = cache
  63. def __enter__(self):
  64. if self._recording:
  65. return self
  66. grad = Grad()
  67. self._recording = True
  68. self._grad = grad
  69. for param_id in self._param_dict.keys():
  70. param_wrapper = self._param_dict[param_id]
  71. callbacks = self._call_back_dict[param_id]
  72. def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
  73. ret = grad
  74. for cb in callbacks:
  75. ret = cb(param, ret)
  76. gm._gradients[id(p)] = ret
  77. grad.wrt(param_wrapper, callback=callback)
  78. grad.__enter__()
  79. return self
  80. def __exit__(self, exc_type, exc_val, exc_tb):
  81. self._stop_record()
  82. record = __enter__
  83. def _stop_record(self):
  84. if self._grad is not None:
  85. self._grad.__exit__(None, None, None)
  86. self._recording = False
  87. self._grad = None
  88. self._gradients = dict()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台