docs(mge/imperative): add docstring for GradManager

GitOrigin-RevId: 4c326206b8
tags/v1.0.0-rc1
Megvii Engine Team, 4 years ago
commit e6715910c5
1 changed file with 66 additions and 0 deletions
  1. imperative/python/megengine/autodiff/grad_manager.py  +66  -0

imperative/python/megengine/autodiff/grad_manager.py

@@ -14,6 +14,51 @@ def get_backwarding_grad_manager():


class GradManager:
r"""GradManager manages auto differentiation and all resources required to perform it.

Our auto differentiation framework requires that the user explicitly indicate when
the forward operations start and when all resources should be released. A typical usage of
GradManager is as follows:

.. code-block:: python

gm = GradManager()
gm.attach(model.parameters())
with gm:
# forward operations
...
# backward gradients
gm.backward(loss)

You can also use the `record()` and `release()` methods instead of the `with` context:

.. code-block:: python

gm = GradManager()
gm.attach(model.parameters())

gm.record()

# forward operations
...
# backward gradients
gm.backward(loss)

gm.release()

Typically, in data-parallel training, we would like to average the gradients across
processes. Users will get the averaged gradients if an "AllReduce"
callback is registered as follows:

.. code-block:: python

import megengine.distributed as dist

gm = GradManager()
gm.attach(model.parameters(), callbacks=dist.make_allreduce_cb("MEAN"))

"""

def __init__(self):
self._call_back_dict = defaultdict(list)
self._param_dict = dict()
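
As a concrete illustration of the pattern described in the docstring above, a minimal
training step could look like the sketch below. The `model`, `loss_fn`, `data`, `label`,
and `dataset` names are placeholders, and the optimizer calls (`step()`, `clear_grad()`)
follow MegEngine's usual optimizer interface rather than anything added by this commit.

.. code-block:: python

    import megengine.optimizer as optim
    from megengine.autodiff import GradManager

    gm = GradManager()
    gm.attach(model.parameters())                # register trainable parameters
    opt = optim.SGD(model.parameters(), lr=0.01)

    for data, label in dataset:                  # placeholder data iterable
        with gm:                                 # start recording forward operations
            loss = loss_fn(model(data), label)
            gm.backward(loss)                    # compute gradients w.r.t. attached parameters
        opt.step()                               # apply the gradients
        opt.clear_grad()                         # reset gradients before the next step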
@@ -23,6 +68,18 @@ class GradManager:
self._gradients = dict()

def attach(self, params, callbacks=None):
r"""Registers parameters that gradients should be calculated with respect to.
Callback functions should have a signature like this:

.. code-block:: python

def cb(param: Tensor, grad: Tensor) -> Tensor:
# do something
return grad

:param params: registered parameters
:param callbacks: list of callback functions
"""
if callbacks is None:
callbacks = []
if isinstance(callbacks, Callable):
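
A callback matching the documented `cb(param, grad) -> Tensor` signature could, for
example, clip gradients elementwise before they are accumulated. The sketch below is an
illustration, not part of this commit; it assumes `megengine.functional.clip` and an
arbitrary threshold.

.. code-block:: python

    import megengine.functional as F

    def clip_cb(param, grad):
        # clip every gradient entry into [-1, 1] before it is accumulated
        return F.clip(grad, -1.0, 1.0)

    gm = GradManager()
    gm.attach(model.parameters(), callbacks=clip_cb)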
@@ -38,6 +95,11 @@ class GradManager:
return self

def backward(self, ys, dys=None):
r"""Performs back-propagation and computes gradients.

:param ys: outputs of forward operators, e.g., the loss tensor
:param dys: derivatives of ys
"""
global backwarding_grad_manager
cache = backwarding_grad_manager
backwarding_grad_manager = self
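
When `ys` is not a scalar, `dys` supplies the derivative to back-propagate for each
output. The sketch below passes an all-ones tensor, which is one common choice; the use
of `megengine.functional.ones_like` and the placeholder `model`/`data` names are
assumptions for illustration.

.. code-block:: python

    import megengine.functional as F

    gm = GradManager()
    gm.attach(model.parameters())
    with gm:
        y = model(data)                  # a non-scalar output
        gm.backward(y, F.ones_like(y))   # explicit derivative for each element of y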
@@ -71,6 +133,8 @@ class GradManager:
backwarding_grad_manager = cache

def record(self):
r"""Starts recording forward operations.
"""
if self._recording:
raise RuntimeError("already recording")
grad = Grad()
@@ -90,6 +154,8 @@ class GradManager:
grad.__enter__()

def release(self):
r"""Stops recording and releases resources for gradients calculation.
"""
if not self._recording:
raise RuntimeError("not recording")
self._stop_record()
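
Since resources are only freed when recording stops, one way to make sure cleanup happens
even if the forward pass raises is to wrap the `record()`/`release()` pattern from the
class docstring in a `try`/`finally` block. This is a usage suggestion, not something
required by the API; `model`, `loss_fn`, `data`, and `label` are placeholders.

.. code-block:: python

    gm = GradManager()
    gm.attach(model.parameters())

    gm.record()
    try:
        loss = loss_fn(model(data), label)   # forward operations
        gm.backward(loss)                    # backward gradients
    finally:
        gm.release()                         # always stop recording and free resources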

