Browse Source

refactor(mge/grad_manager): refactor gradmanager, add allreduce callback

GitOrigin-RevId: 086e2871e8
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
6070266766
4 changed files with 80 additions and 3 deletions
  1. +52
    -0
      imperative/python/megengine/autodiff/grad_manager.py
  2. +5
    -1
      imperative/python/megengine/core/autodiff/grad.py
  3. +1
    -1
      imperative/python/megengine/distributed/__init__.py
  4. +22
    -1
      imperative/python/megengine/distributed/helper.py

+ 52
- 0
imperative/python/megengine/autodiff/grad_manager.py View File

@@ -0,0 +1,52 @@
from contextlib import contextmanager

from ..core.autodiff.grad import Grad
from ..tensor import tensor


class GradManager:
def __init__(self):
self._call_back_pair = []
self._recording = False
self._grad = None

def register(self, params, callback=None):
self._call_back_pair.append([params, callback])

def backward(self, ys, dys=None):
if not self._recording:
raise RuntimeError(
"no computation history. "
"did you forget record() or "
"call a method that clears the history?"
)
assert self._grad is not None
if not isinstance(ys, (tuple, list)):
ys = [ys]
if dys is None:
dys = [tensor(1).broadcast(y.shape) for y in ys]
if not isinstance(dys, (tuple, list)):
dys = [dys]
try:
self._grad(ys, dys)
finally:
self._grad = None

def record(self):
@contextmanager
def recorder():
grad = Grad()
if self._recording:
raise RuntimeError("already recording!")
try:
self._recording = True
self._grad = grad
for params, callbacks in self._call_back_pair:
grad.wrt(*params, callback=callbacks)
with grad:
yield
finally:
self._recording = False
self._grad = None

return recorder()

+ 5
- 1
imperative/python/megengine/core/autodiff/grad.py View File

@@ -260,9 +260,13 @@ class Grad:
cache[v] = g
if last_written_to[v] == (seqno, i):
if v.callback:
v.callback(
grad = v.callback(
v.owner(), Wrapper(cache[v]) if Wrapper else cache[v]
)
if getattr(v.owner(), "grad", None) is None:
v.owner().grad = grad
else:
v.owner().grad += grad
if v.opnode is None:
# won't read by backward, mark consumed
cache[v] = None


+ 1
- 1
imperative/python/megengine/distributed/__init__.py View File

@@ -19,7 +19,7 @@ from .group import (
is_distributed,
new_group,
)
from .helper import synchronized
from .helper import bcast_params_, make_allreduce_cb, synchronized
from .launcher import launcher
from .server import Client, Server
from .util import get_free_ports

+ 22
- 1
imperative/python/megengine/distributed/helper.py View File

@@ -12,7 +12,8 @@ from typing import Callable

from megengine.device import get_device_count

from .group import group_barrier, is_distributed
from .functional import all_reduce_sum, broadcast
from .group import WORLD, group_barrier, is_distributed


def synchronized(func: Callable):
@@ -42,3 +43,23 @@ def get_device_count_by_fork(device_type: str):
p.start()
p.join()
return q.get()


def bcast_params_(params, group):
for p in params:
p._reset(broadcast(p, group))


class AllreduceCallback:
def __init__(self, reduce_method, group=WORLD):
self._reduce_method = reduce_method
self._group = group

def __call__(self, param, grad):
ret = all_reduce_sum(grad, self._group)
if self._reduce_method == "MEAN":
ret = ret / self._group.size
return ret


make_allreduce_cb = AllreduceCallback

Loading…
Cancel
Save