|
|
@@ -14,6 +14,51 @@ def get_backwarding_grad_manager():

class GradManager:
    r"""GradManager manages auto differentiation and all resources required to perform it.

    Our auto differentiation framework requires that the user explicitly indicates when
    the forward operations start and when all resources should be released. A typical usage of
    GradManager is as follows:

    .. code-block::

        gm = GradManager()
        gm.attach(model.parameters())
        with gm:
            # forward operations
            ...
            # backward gradients
            gm.backward(loss)

    You can also use the `record()` and `release()` methods instead of the `with` context:

    .. code-block::

        gm = GradManager()
        gm.attach(model.parameters())

        gm.record()

        # forward operations
        ...
        # backward gradients
        gm.backward(loss)

        gm.release()

    Typically, in data parallel training, we would like to average the gradients across
    processes. Users will get the averaged gradients if an "AllReduce"
    callback is registered as follows:

    .. code-block::

        import megengine.distributed as dist

        gm = GradManager()
        gm.attach(model.parameters(), callbacks=dist.make_allreduce_cb("MEAN"))
    """

    def __init__(self):
        self._call_back_dict = defaultdict(list)
        self._param_dict = dict()

@@ -23,6 +68,18 @@ class GradManager:

        self._gradients = dict()

    def attach(self, params, callbacks=None):
        r"""Registers parameters that gradients should be calculated with respect to.

        Callback functions should have a signature like this:

        .. code-block::

            def cb(param: Tensor, grad: Tensor) -> Tensor:
                # do something
                return grad
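
        For example, a gradient-scaling callback could be attached like this
        (just a sketch; ``scale_grad`` is not part of this API):

        .. code-block::

            scale = 0.5  # assumed constant, for the sake of the sketch

            def scale_grad(param, grad):
                return grad * scale

            gm.attach(model.parameters(), callbacks=[scale_grad])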

        :param params: registered parameters
        :param callbacks: list of callback functions
        """
        if callbacks is None:
            callbacks = []
        if isinstance(callbacks, Callable):

@@ -38,6 +95,11 @@ class GradManager:

        return self

    def backward(self, ys, dys=None):
        r"""Performs back-propagation and computes gradients.
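
        For example, passing a custom derivative for an output tensor could look like
        this (a sketch; ``data``, ``y`` and ``dy`` are hypothetical tensors, with ``dy``
        matching the shape of ``y``):

        .. code-block::

            with gm:
                y = model(data)     # some non-scalar output
                gm.backward(y, dy)  # dy supplies the gradient with respect to y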

        :param ys: outputs of forward operators, e.g., the loss tensor
        :param dys: derivatives of ys
        """
        global backwarding_grad_manager
        cache = backwarding_grad_manager
        backwarding_grad_manager = self

@@ -71,6 +133,8 @@ class GradManager:

        backwarding_grad_manager = cache

    def record(self):
        r"""Starts recording forward operations.
        """
        if self._recording:
            raise RuntimeError("already recording")
        grad = Grad()

@@ -90,6 +154,8 @@ class GradManager:

        grad.__enter__()

    def release(self):
        r"""Stops recording and releases resources for gradient calculation.
        """
        if not self._recording:
            raise RuntimeError("not recording")
        self._stop_record()