diff --git a/imperative/python/megengine/autodiff/grad_manager.py b/imperative/python/megengine/autodiff/grad_manager.py
new file mode 100644
index 00000000..84a8f55f
--- /dev/null
+++ b/imperative/python/megengine/autodiff/grad_manager.py
@@ -0,0 +1,82 @@
+from contextlib import contextmanager
+
+from ..core.autodiff.grad import Grad
+from ..tensor import tensor
+
+
+class GradManager:
+    """Record a computation and back-propagate through it on demand.
+
+    Parameters are attached with :meth:`register`, the forward computation
+    runs inside the context returned by :meth:`record`, and
+    :meth:`backward` propagates gradients through the recorded history.
+    """
+
+    def __init__(self):
+        # [params, callback] pairs handed to ``Grad.wrt`` by record().
+        self._call_back_pair = []
+        self._recording = False
+        # Active ``Grad``; None outside record() and after backward().
+        self._grad = None
+
+    def register(self, params, callback=None):
+        """Remember ``params`` so that record() attaches them to its Grad.
+
+        :param params: iterable unpacked into ``Grad.wrt``.
+        :param callback: passed as ``callback=`` to ``Grad.wrt``.
+        """
+        self._call_back_pair.append([params, callback])
+
+    def backward(self, ys, dys=None):
+        """Back-propagate from ``ys`` through the recorded history.
+
+        The history is dropped when this call finishes, whether or not it
+        succeeded, so it can run once per record().
+
+        :param ys: an output, or a tuple/list of outputs.
+        :param dys: gradient(s) for ``ys``; defaults to ``tensor(1)``
+            broadcast to the shape of each output.
+        :raises RuntimeError: if there is no history to back-propagate.
+        """
+        # A previous backward() clears ``_grad`` while ``_recording`` stays
+        # set, so check both; ``assert`` would also vanish under ``-O``.
+        if not self._recording or self._grad is None:
+            raise RuntimeError(
+                "no computation history. "
+                "did you forget record() or "
+                "call a method that clears the history?"
+            )
+        if not isinstance(ys, (tuple, list)):
+            ys = [ys]
+        if dys is None:
+            dys = [tensor(1).broadcast(y.shape) for y in ys]
+        if not isinstance(dys, (tuple, list)):
+            dys = [dys]
+        try:
+            self._grad(ys, dys)
+        finally:
+            self._grad = None
+
+    def record(self):
+        """Return a context manager that records computation history.
+
+        :raises RuntimeError: when entered while already recording.
+        """
+
+        @contextmanager
+        def recorder():
+            grad = Grad()
+            if self._recording:
+                raise RuntimeError("already recording!")
+            try:
+                self._recording = True
+                self._grad = grad
+                for params, callbacks in self._call_back_pair:
+                    grad.wrt(*params, callback=callbacks)
+                with grad:
+                    yield
+            finally:
+                self._recording = False
+                self._grad = None
+
+        return recorder()
diff --git a/imperative/python/megengine/core/autodiff/grad.py b/imperative/python/megengine/core/autodiff/grad.py
index c30e4113..42d8ae9c 100644
--- a/imperative/python/megengine/core/autodiff/grad.py
+++ b/imperative/python/megengine/core/autodiff/grad.py
@@ -260,9 +260,17 @@ class Grad:
                     cache[v] = g
                 if last_written_to[v] == (seqno, i):
                     if v.callback:
-                        v.callback(
+                        grad = v.callback(
                             v.owner(), Wrapper(cache[v]) if Wrapper else cache[v]
                         )
+                        # A callback may return None (e.g. one used only for
+                        # its side effects); nothing to accumulate then.
+                        if grad is not None:
+                            owner = v.owner()
+                            if getattr(owner, "grad", None) is None:
+                                owner.grad = grad
+                            else:
+                                owner.grad += grad
                     if v.opnode is None:
                         # won't read by backward, mark consumed
                         cache[v] = None
diff --git a/imperative/python/megengine/distributed/__init__.py b/imperative/python/megengine/distributed/__init__.py
index 30e0766f..daf866af 100644
--- a/imperative/python/megengine/distributed/__init__.py
+++ b/imperative/python/megengine/distributed/__init__.py
@@ -19,7 +19,7 @@ from .group import (
     is_distributed,
     new_group,
 )
-from .helper import synchronized
+from .helper import bcast_params_, make_allreduce_cb, synchronized
 from .launcher import launcher
 from .server import Client, Server
 from .util import get_free_ports
diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py
index 2be3011f..8b787b7a 100644
--- a/imperative/python/megengine/distributed/helper.py
+++ b/imperative/python/megengine/distributed/helper.py
@@ -12,7 +12,8 @@ from typing import Callable
 
 from megengine.device import get_device_count
 
-from .group import group_barrier, is_distributed
+from .functional import all_reduce_sum, broadcast
+from .group import WORLD, group_barrier, is_distributed
 
 
 def synchronized(func: Callable):
@@ -42,3 +43,33 @@ def get_device_count_by_fork(device_type: str):
     p.start()
     p.join()
     return q.get()
+
+
+def bcast_params_(params, group):
+    """Reset each of ``params`` to the result of broadcasting it in ``group``."""
+    for p in params:
+        p._reset(broadcast(p, group))
+
+
+class AllreduceCallback:
+    """Gradient callback that all-reduces ``grad`` across ``group``.
+
+    :param reduce_method: the sum is divided by the group size only when
+        this is exactly ``"MEAN"``; any other value yields the plain sum.
+    :param group: group to reduce over, ``WORLD`` by default.
+    """
+
+    def __init__(self, reduce_method, group=WORLD):
+        self._reduce_method = reduce_method
+        self._group = group
+
+    def __call__(self, param, grad):
+        """Return the reduced ``grad``; ``param`` is not used."""
+        ret = all_reduce_sum(grad, self._group)
+        if self._reduce_method == "MEAN":
+            ret = ret / self._group.size
+        return ret
+
+
+# Factory-style alias for ``AllreduceCallback``.
+make_allreduce_cb = AllreduceCallback