- import weakref
- from collections import defaultdict
- from contextlib import contextmanager
- from typing import Callable
-
- from ..core.autodiff.grad import Grad
- from ..logger import get_logger
- from ..tensor import Tensor
- from ..utils.future import Future
-
- logger = get_logger(__name__)
-
- backwarding_grad_manager = None
-
-
- def get_backwarding_grad_manager():
- return backwarding_grad_manager
-
-
- class AttachSpec:
- __slots__ = "tensor", "callbacks"
-
-
- class GradManager:
- r"""
- GradManager manages auto differentiation and all resources required to perform it.
-
- Our auto differentiation framework requires the user to explicitly indicate when
- the forward operations start and when all resources should be released. A typical
- usage of GradManager is as follows:
-
- .. code-block::
-
- gm = GradManager()
- gm.attach(model.parameters())
- with gm:
- # forward operations
- ...
- # backward gradients
- gm.backward(loss)
-
- You can also use the ``record()`` and ``release()`` methods instead of the ``with`` context:
-
- .. code-block::
-
- gm = GradManager()
- gm.attach(model.parameters())
-
- gm.record()
-
- # forward operations
- ...
- # backward gradients
- gm.backward(loss)
-
- gm.release()
-
- Typically, in data parallel training, we want to average the gradients across
- processes. Users get the averaged gradients automatically if an "AllReduce"
- callback is registered as follows:
-
- .. code-block::
-
- import megengine.distributed as dist
-
- gm = GradManager()
- gm.attach(model.parameters(), callbacks=dist.make_allreduce_cb("MEAN"))
-
- """
-
- def __init__(self):
- self._attach_specs = {} # id(Tensor) -> AttachSpec
- self._recording = False
- self._grad = None
- self._after_backward_callback = []
- self._gradients = {}
-
- def attach(self, tensors: list, callbacks=None):
- r"""
- Registers parameters with respect to which gradients should be calculated.
- Callback functions should have a signature like this:
-
- .. code-block::
-
- def cb(param: Tensor, grad: Tensor) -> Tensor:
- # do something
- return grad
-
- :param tensors: tensors (typically model parameters) to be registered
- :param callbacks: list of callback functions
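-
- For example, the sketch below registers model parameters with a purely
- illustrative callback (``scale_grad`` is not part of the library) that
- halves every gradient before it is accumulated:
-
- .. code-block::
-
- def scale_grad(param, grad):
- # hypothetical callback: scale the incoming gradient by 0.5
- return grad * 0.5
-
- gm = GradManager()
- gm.attach(model.parameters(), callbacks=scale_grad)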
- """
- if callbacks is None:
- callbacks = []
- if isinstance(callbacks, Callable):
- callbacks = [callbacks]
- if isinstance(tensors, Tensor):
- tensors = [tensors]
-
- def make_spec(tensor):
- selfref = weakref.ref(self)
- key = id(tensor)
-
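- # the deleter runs when the attached tensor is garbage collected and
- # removes its spec; the manager is held through a weak reference so this
- # callback does not keep it alive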
- def deleter(_):
- self = selfref()
- if self is not None:
- del self._attach_specs[key]
-
- spec = AttachSpec()
- spec.tensor = weakref.ref(tensor, deleter)
- spec.callbacks = []
- return spec
-
- for x in tensors:
- spec = self._attach_specs.get(id(x))
- new_attach = spec is None
- if spec is None:
- spec = make_spec(x)
- self._attach_specs[id(x)] = spec
- spec.callbacks.extend(callbacks)
- if new_attach and self._recording:
- self._do_record(spec)
- return self
-
- def _register_after_backward_callback(self, callback):
- self._after_backward_callback.append(callback)
- return self
-
- def backward(self, ys=None, dys=None):
- r"""
- Performs back-propagation and computes gradients.
-
- :param ys: outputs of forward operators, e.g., the loss tensor
- :param dys: derivatives of ys
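-
- If ``dys`` is not given, a tensor of ones with the same shape as the
- corresponding ``ys`` element is used. For example (the forward code below
- is only a placeholder):
-
- .. code-block::
-
- with gm:
- loss = ... # forward computation
- gm.backward(loss) # same as gm.backward(loss, ones_like(loss))
-
- After ``backward`` returns, the gradient of each attached tensor has been
- accumulated into its ``grad`` attribute and the recording is released.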
- """
- from ..functional import ones_like
-
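- # expose this manager via get_backwarding_grad_manager() while backward runs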
- global backwarding_grad_manager
- cache = backwarding_grad_manager
- backwarding_grad_manager = self
- if not self._recording:
- raise RuntimeError(
- "no computation history. "
- "Did you forget to call record(), or "
- "did you call a method that clears the history?"
- )
- assert self._grad is not None
- if ys is None:
- ys = []
- if not isinstance(ys, (tuple, list)):
- ys = [ys]
- if dys is None:
- dys = [ones_like(y) for y in ys]
- if not isinstance(dys, (tuple, list)):
- dys = [dys]
- try:
- self._grad(ys, dys)
- for callback in self._after_backward_callback:
- callback()
- for id_, grad in self._gradients.items():
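- # a callback may have returned a Future; resolve it to a concrete gradient first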
- if isinstance(grad, Future):
- grad = grad.get()
- spec = self._attach_specs.get(id_)
- tensor = spec and spec.tensor()
- if tensor is not None:
- if tensor.grad is None:
- tensor.grad = grad
- else:
- tensor.grad += grad
- finally:
- self.release()
- backwarding_grad_manager = cache
-
- def record(self):
- r"""
- Starts recording forward operations.
- """
- if self._recording:
- raise RuntimeError("already recording")
- grad = Grad()
- self._recording = True
- self._grad = grad
- for spec in self._attach_specs.values():
- self._do_record(spec)
- grad.__enter__()
-
- def _do_record(self, spec):
- tensor = spec.tensor()
- if tensor is None:
- return
-
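- # apply user callbacks in registration order, then store the final gradient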
- def callback(_, grad, callbacks=spec.callbacks):
- for cb in callbacks:
- grad = cb(tensor, grad)
- self._gradients[id(tensor)] = grad
-
- # NOTE: calling wrt several times on the same tensor overrides the previous callback
- self._grad.wrt(tensor, callback=callback)
-
- def release(self):
- r"""
- Stops recording and releases resources for gradients calculation.
- """
- if self._grad is not None:
- self._grad.__exit__(None, None, None)
- self._grad = None
- self._recording = False
- self._gradients = dict()
-
- def __enter__(self):
- self.record()
- return self
-
- def __exit__(self, exc_type, exc_val, exc_tb):
- self.release()