from contextlib import contextmanager

from ..core.autodiff.grad import Grad
from ..tensor import tensor


class GradManager:
    """Record a computation and back-propagate through it to registered params.

    Intended usage, as implied by the methods below::

        gm = GradManager()
        gm.register(params, callback)
        with gm.record():
            y = f(params)
            gm.backward(y)

    Each ``record()`` block creates a fresh ``Grad`` object and ``backward``
    consumes it, so at most one ``backward`` call is allowed per ``record()``
    block.
    """

    def __init__(self):
        # [params, callback] pairs added by ``register``; replayed into every
        # new ``Grad`` created by ``record``.
        self._call_back_pair = []
        # True only while inside a ``record()`` block.
        self._recording = False
        # The ``Grad`` for the current ``record()`` block, or None when there
        # is no recorded history (outside ``record`` or after ``backward``).
        self._grad = None

    def register(self, params, callback=None):
        """Register ``params`` to be differentiated in every later ``record()``.

        :param params: an iterable of parameters. It is materialized into a
            list here so that one-shot iterables (e.g. generators) are not
            exhausted by the first ``record()`` and silently empty afterwards.
        :param callback: forwarded as ``callback=`` to ``Grad.wrt`` together
            with these params (presumably invoked with the computed gradients;
            confirm against ``Grad.wrt``).
        """
        self._call_back_pair.append([list(params), callback])

    def backward(self, ys, dys=None):
        """Back-propagate from ``ys`` through the history recorded so far.

        :param ys: a single output or a tuple/list of outputs.
        :param dys: the initial gradient(s) for ``ys``: a single value or a
            tuple/list. When None, each output is seeded with ``1.0``
            broadcast to that output's shape.
        :raises RuntimeError: if called outside ``record()``, or if the
            history was already consumed by an earlier ``backward`` call.

        The recorded history is cleared after the call, even if the
        underlying ``Grad`` raises.
        """
        # After a previous backward, ``_grad`` is None while ``_recording``
        # is still True. Report that with the same informative error instead
        # of a bare assert, which would also be stripped under ``python -O``.
        if not self._recording or self._grad is None:
            raise RuntimeError(
                "no computation history. "
                "did you forget record() or "
                "call a method that clears the history?"
            )
        if not isinstance(ys, (tuple, list)):
            ys = [ys]
        if dys is None:
            # Seed with 1.0 rather than 1 so the seed gradient is floating
            # point; an integer literal would presumably infer an integer
            # dtype (confirm against ``tensor``'s dtype inference).
            dys = [tensor(1.0).broadcast(y.shape) for y in ys]
        if not isinstance(dys, (tuple, list)):
            dys = [dys]
        try:
            self._grad(ys, dys)
        finally:
            # The history is single-use; drop it whether or not backward
            # succeeded.
            self._grad = None

    def record(self):
        """Return a context manager that records operations for ``backward``.

        On entry, a new ``Grad`` is created, every registered
        ``(params, callback)`` pair is passed to ``Grad.wrt``, and the ``Grad``
        itself is entered as a context manager for the duration of the block.
        On exit, whether normal or exceptional, the recording flag and the
        history are cleared.

        :raises RuntimeError: on entering the block if this manager is
            already recording (nested ``record()`` is not supported).
        """

        @contextmanager
        def recorder():
            # Check first so a rejected nested record() does not construct
            # an unused Grad.
            if self._recording:
                raise RuntimeError("already recording!")
            grad = Grad()
            try:
                self._recording = True
                self._grad = grad
                for params, callback in self._call_back_pair:
                    grad.wrt(*params, callback=callback)
                with grad:
                    yield
            finally:
                self._recording = False
                self._grad = None

        return recorder()