- from collections import defaultdict
- from contextlib import contextmanager
-
- from ..core.autodiff.grad import Grad
- from ..distributed.util import Future
- from ..tensor import tensor
-
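- # Holds the GradManager currently executing backward(); restored to its
- # previous value afterwards, so code running during the pass can look the
- # active manager up via get_backwarding_grad_manager().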
- backwarding_grad_manager = None
-
-
- def get_backwarding_grad_manager():
-     return backwarding_grad_manager
-
-
- class GradManager:
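-     """Records computation on registered parameters so that a later
-     ``backward`` call can compute gradients and accumulate them into each
-     parameter's ``grad`` attribute."""
-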
-     def __init__(self):
-         self._call_back_dict = defaultdict(list)  # param id -> gradient callbacks
-         self._param_dict = dict()  # param id -> parameter tensor
-         self._recording = False
-         self._grad = None  # active Grad instance while recording
-         self._after_backward_callback = []  # hooks run after each backward pass
-         self._gradients = dict()  # param id -> gradient (possibly a Future)
-
-     def register(self, params, callbacks=None):
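-         """Track ``params``; each callback in ``callbacks`` is applied, in
-         order, to a parameter's gradient before it is stored."""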
-         # A None default avoids the shared-mutable-default-argument pitfall.
-         callbacks = callbacks or []
-         for p in params:
-             self._param_dict[id(p)] = p
-             for cb in callbacks:
-                 self._call_back_dict[id(p)].append(cb)
-
-     def register_after_backward_callback(self, callback):
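-         """Queue ``callback`` to run once after each backward pass; returns
-         ``self`` so calls can be chained."""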
-         self._after_backward_callback.append(callback)
-         return self
-
-     def backward(self, ys, dys=None):
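-         """Back-propagate from outputs ``ys`` with gradients ``dys``
-         (defaulting to ones) and accumulate the results into the ``grad``
-         attribute of every registered parameter."""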
-         # Validate state and normalize inputs before publishing self as the
-         # active manager, so an early exception cannot leave it published.
-         if not self._recording:
-             raise RuntimeError(
-                 "no computation history; "
-                 "did you forget to call record(), "
-                 "or was the history already cleared?"
-             )
-         assert self._grad is not None
-         if not isinstance(ys, (tuple, list)):
-             ys = [ys]
-         if dys is None:
-             dys = [tensor(1.0) for _ in ys]
-         if not isinstance(dys, (tuple, list)):
-             dys = [dys]
-         global backwarding_grad_manager
-         cache = backwarding_grad_manager
-         backwarding_grad_manager = self
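-         # Run the recorded graph; the per-parameter callbacks installed by
-         # record() deposit each result into self._gradients.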
-         try:
-             self._grad(ys, dys)
-             for callback in self._after_backward_callback:
-                 callback()
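-             # Gradients may arrive asynchronously as Futures; resolve them,
-             # then accumulate into (or initialize) each param.grad.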
-             for p, grad in self._gradients.items():
-                 if isinstance(grad, Future):
-                     grad = grad.get()
-                 param = self._param_dict[p]
-                 if getattr(param, "grad", None) is None:
-                     param.grad = grad
-                 else:
-                     param.grad += grad
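-         # Clean up whether or not the pass succeeded.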
-         finally:
-             self._grad = None
-             self._gradients = dict()
-             backwarding_grad_manager = cache
-
-     def record(self):
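-         """Return a context manager; computation executed inside it is
-         recorded and can then be differentiated by ``backward``."""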
-         @contextmanager
-         def recorder():
-             if self._recording:
-                 raise RuntimeError("already recording!")
-             grad = Grad()
-             try:
-                 self._recording = True
-                 self._grad = grad
-                 for param_id, param_wrapper in self._param_dict.items():
-                     callbacks = self._call_back_dict[param_id]
-
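-                     # Default arguments freeze the current loop bindings so
-                     # each closure keeps its own parameter and callback list.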
-                     def callback(
-                         param, grad, callbacks=callbacks, p=param_wrapper, gm=self
-                     ):
-                         ret = grad
-                         for cb in callbacks:
-                             ret = cb(param, ret)
-                         gm._gradients[id(p)] = ret
-
-                     grad.wrt(param_wrapper, callback=callback)
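-                 # Ops executed while the Grad context is open are recorded.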
-                 with grad:
-                     yield
-             finally:
-                 self._recording = False
-                 self._grad = None
-                 self._gradients = dict()
-
-         return recorder()
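
A minimal usage sketch of the manager removed above (model, data, and loss_fn
are placeholder names, not part of this module). Note that backward() must run
while the record() context is still open, since exiting the context clears the
recorded history:

    gm = GradManager()
    gm.register(model.parameters())  # placeholder: any iterable of parameters

    with gm.record():
        loss = loss_fn(model(data))  # forward pass is recorded by Grad
        gm.backward(loss)            # must happen before the context exits
    # every registered parameter now carries an accumulated .grad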