from collections import defaultdict
from typing import Callable

from ..core.autodiff.grad import Grad
from ..tensor import tensor
from ..utils.future import Future

# The GradManager whose backward() is currently executing, or None.
backwarding_grad_manager = None


def get_backwarding_grad_manager():
    return backwarding_grad_manager
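

# Hedged usage note (an assumption drawn from how the global above is exposed,
# not something defined in this file): code running inside backward(), such as
# a distributed gradient hook, can recover the active manager like this:
#
#     gm = get_backwarding_grad_manager()
#     if gm is not None:
#         gm._register_after_backward_callback(lambda: None)  # hypothetical hook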


class GradManager:
    def __init__(self):
        self._call_back_dict = defaultdict(list)  # id(param) -> grad callbacks
        self._param_dict = dict()  # id(param) -> param
        self._recording = False
        self._grad = None
        self._after_backward_callback = []
        self._gradients = dict()  # id(param) -> gradient produced by backward()

    def attach(self, params, callbacks=None):
        """Register parameters whose gradients should be computed.

        ``callbacks`` may be a single callable or a list of callables; each is
        invoked as ``cb(param, grad)`` during backward and must return the
        (possibly transformed) gradient.
        """
        if callbacks is None:
            callbacks = []
        if isinstance(callbacks, Callable):
            callbacks = [callbacks]
        for p in params:
            self._param_dict[id(p)] = p
            for cb in callbacks:
                self._call_back_dict[id(p)].append(cb)
        return self
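
    # A hedged sketch of attach() usage, assuming a hypothetical ``net`` with
    # a ``parameters()`` iterable and a simple scaling callback (both names
    # are illustrations, not part of this module):
    #
    #     def scale_grad(param, grad):
    #         return grad * 0.5
    #
    #     gm = GradManager().attach(net.parameters(), callbacks=scale_grad)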

    def _register_after_backward_callback(self, callback):
        """Schedule ``callback()`` to run once after the next backward pass."""
        self._after_backward_callback.append(callback)
        return self

    def backward(self, ys, dys=None):
        """Compute gradients of ``ys`` w.r.t. the attached parameters and
        accumulate them into each ``param.grad``.

        If ``dys`` is omitted, a tensor of ones shaped like each ``y`` is used.
        """
        # Check before touching the global so an early error cannot leave
        # ``backwarding_grad_manager`` pointing at this manager.
        if not self._recording:
            raise RuntimeError(
                "no computation history. "
                "did you forget record() or "
                "call a method that clears the history?"
            )
        assert self._grad is not None
        global backwarding_grad_manager
        cache = backwarding_grad_manager
        backwarding_grad_manager = self
        if not isinstance(ys, (tuple, list)):
            ys = [ys]
        if dys is None:
            dys = [tensor(1.0).broadcast(y.shape) for y in ys]
        if not isinstance(dys, (tuple, list)):
            dys = [dys]
        try:
            self._grad(ys, dys)
            for callback in self._after_backward_callback:
                callback()
            for p, grad in self._gradients.items():
                if isinstance(grad, Future):
                    grad = grad.get()
                param = self._param_dict[p]
                if param.grad is None:
                    param.grad = grad
                else:
                    param.grad += grad
        finally:
            self._stop_record()
            backwarding_grad_manager = cache
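
    # Hedged note on semantics (consistent with the code above; example names
    # are illustrative): gradients accumulate across calls, and backward()
    # always clears the recorded history, so record() must precede every call:
    #
    #     gm.record()
    #     loss = compute_loss()  # hypothetical forward pass
    #     gm.backward(loss)      # param.grad populated; recording stopped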

    def record(self):
        """Start recording computation history for the attached parameters."""
        if self._recording:
            raise RuntimeError("already recording")
        grad = Grad()
        self._recording = True
        self._grad = grad
        for param_id in self._param_dict.keys():
            param_wrapper = self._param_dict[param_id]
            callbacks = self._call_back_dict[param_id]

            # Default arguments bind the loop variables so that each closure
            # sees its own parameter and callback list.
            def callback(param, grad, callbacks=callbacks, p=param_wrapper, gm=self):
                ret = grad
                for cb in callbacks:
                    ret = cb(param, ret)
                gm._gradients[id(p)] = ret

            grad.wrt(param_wrapper, callback=callback)
        grad.__enter__()
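
    # Hedged usage sketch: record()/release() are the manual counterparts of
    # the context-manager form defined below; release() abandons a recorded
    # forward pass without producing gradients:
    #
    #     gm.record()
    #     ...          # forward computation
    #     gm.release() # discard history, no backward pass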

    def release(self):
        """Stop recording and discard the computation history."""
        if not self._recording:
            raise RuntimeError("not recording")
        self._stop_record()

    def _stop_record(self):
        # Safe to call when not recording; used by backward() and __exit__().
        if self._grad is not None:
            self._grad.__exit__(None, None, None)
        self._recording = False
        self._grad = None
        self._gradients = dict()

    def __enter__(self):
        self.record()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._stop_record()
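

# A minimal training-step sketch, assuming hypothetical ``model``, ``loss_fn``,
# ``optimizer`` and ``dataset`` objects (none are defined in this module):
#
#     gm = GradManager().attach(model.parameters())
#     for inputs, targets in dataset:
#         with gm:  # record() on entry, _stop_record() on exit
#             loss = loss_fn(model(inputs), targets)
#             gm.backward(loss)  # accumulates into each param.grad
#         optimizer.step()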