diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py
index be806cab..2a0f6906 100644
--- a/imperative/python/megengine/autodiff/grad_manager.py
+++ b/imperative/python/megengine/autodiff/grad_manager.py
@@ -14,6 +14,51 @@ def get_backwarding_grad_manager():
 
 
 class GradManager:
+    r"""GradManager manages auto differentiation and all resources required to perform it.
+
+    Our auto differentiation framework requires the user to explicitly indicate when
+    forward operations start and when all resources should be released. A typical usage of
+    GradManager is as follows:
+
+    .. code-block::
+
+        gm = GradManager()
+        gm.attach(model.parameters())
+        with gm:
+            # forward operations
+            ...
+            # backward gradients
+            gm.backward(loss)
+
+    You can also use the `record()` and `release()` methods instead of a `with` context:
+
+    .. code-block::
+
+        gm = GradManager()
+        gm.attach(model.parameters())
+
+        gm.record()
+
+        # forward operations
+        ...
+        # backward gradients
+        gm.backward(loss)
+
+        gm.release()
+
+    Typically, in data-parallel training, we would like to average the gradients across
+    processes. Users will get the averaged gradients if an "AllReduce"
+    callback is registered as follows:
+
+    .. code-block::
+
+        import megengine.distributed as dist
+
+        gm = GradManager()
+        gm.attach(model.parameters(), callbacks=dist.make_allreduce_cb("MEAN"))
+
+    """
+
     def __init__(self):
         self._call_back_dict = defaultdict(list)
         self._param_dict = dict()
@@ -23,6 +68,18 @@ class GradManager:
         self._gradients = dict()
 
     def attach(self, params, callbacks=None):
+        r"""Registers parameters with respect to which gradients should be calculated.
+        Callback functions should have a signature like this:
+
+        .. code-block::
+
+            def cb(param: Tensor, grad: Tensor) -> Tensor:
+                # do something
+                return grad
+
+        :param params: parameters to be registered
+        :param callbacks: list of callback functions
+        """
         if callbacks is None:
             callbacks = []
         if isinstance(callbacks, Callable):
@@ -38,6 +95,11 @@ class GradManager:
         return self
 
     def backward(self, ys, dys=None):
+        r"""Performs back-propagation and computes gradients.
+
+        :param ys: outputs of forward operators, e.g., the loss tensor
+        :param dys: derivatives of ys
+        """
         global backwarding_grad_manager
         cache = backwarding_grad_manager
         backwarding_grad_manager = self
@@ -71,6 +133,8 @@ class GradManager:
         backwarding_grad_manager = cache
 
     def record(self):
+        r"""Starts recording forward operations.
+        """
         if self._recording:
             raise RuntimeError("already recording")
         grad = Grad()
@@ -90,6 +154,8 @@ class GradManager:
         grad.__enter__()
 
     def release(self):
+        r"""Stops recording and releases resources required for gradient calculation.
+        """
         if not self._recording:
             raise RuntimeError("not recording")
         self._stop_record()
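
A minimal end-to-end sketch of the workflow described by the new docstrings, for reference. It assumes the public MegEngine API (`megengine`, `megengine.module`, `megengine.functional`, `megengine.optimizer`, `megengine.autodiff.GradManager`); the toy `Linear` model, random data, and hyperparameters are illustrative only and are not part of this change.

```python
import numpy as np

import megengine as mge
import megengine.functional as F
import megengine.module as M
import megengine.optimizer as optim
from megengine.autodiff import GradManager

model = M.Linear(4, 1)                       # toy model (illustrative)
opt = optim.SGD(model.parameters(), lr=0.1)  # any optimizer works here

gm = GradManager()
gm.attach(model.parameters())                # parameters to differentiate with respect to

x = mge.tensor(np.random.randn(8, 4).astype("float32"))
y = mge.tensor(np.zeros((8, 1), dtype="float32"))

with gm:                                     # equivalent to gm.record() ... gm.release()
    pred = model(x)                          # forward operations are recorded here
    loss = F.mean((pred - y) ** 2)           # simple squared-error loss
    gm.backward(loss)                        # computes gradients for attached parameters

opt.step()                                   # apply the gradients
opt.clear_grad()
```

The `with gm:` block is interchangeable with explicit `gm.record()` and `gm.release()` calls, as the class docstring notes; in data-parallel training the same loop works unchanged once `dist.make_allreduce_cb("MEAN")` is passed to `attach`.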