GitOrigin-RevId: 086e2871e8
tags/v1.0.0-rc1
@@ -0,0 +1,52 @@ | |||||
from contextlib import contextmanager | |||||
from ..core.autodiff.grad import Grad | |||||
from ..tensor import tensor | |||||
class GradManager:
    """Record a computation and back-propagate through it to registered parameters.

    Typical flow: attach parameters with :meth:`register`, run the forward
    computation inside ``with manager.record():``, then call :meth:`backward`
    inside that same ``with`` block.
    """

    def __init__(self):
        # One ``[params, callback]`` entry per ``register`` call.
        self._call_back_pair = []
        # True from entering ``record()`` until that context exits.
        self._recording = False
        # The live ``Grad`` object; cleared by ``backward`` and on context exit.
        self._grad = None

    def register(self, params, callback=None):
        """Remember ``params`` so that every later ``record()`` tracks them.

        :param params: iterable of parameters; unpacked into ``Grad.wrt``.
        :param callback: forwarded to ``Grad.wrt`` as its ``callback`` keyword.
        """
        # Materialise the iterable: ``record()`` may be entered many times and
        # a generator would be exhausted after the first recording.
        self._call_back_pair.append([tuple(params), callback])

    def backward(self, ys, dys=None):
        """Run the backward pass starting from ``ys``.

        :param ys: a single output or a tuple/list of outputs.
        :param dys: gradient(s) fed in for ``ys``. When omitted, one value per
            ``y`` is built with ``tensor(1).broadcast(y.shape)``.
        :raises RuntimeError: when called outside ``record()``, or when the
            recorded history was already consumed by an earlier ``backward``.
        """
        # ``_grad`` is dropped after the first backward while ``_recording``
        # stays set until the ``record()`` context exits, so both conditions
        # mean "no history" and must surface as the same RuntimeError (an
        # ``assert`` would vanish under ``python -O``).
        if not self._recording or self._grad is None:
            raise RuntimeError(
                "no computation history. "
                "did you forget record() or "
                "call a method that clears the history?"
            )
        if not isinstance(ys, (tuple, list)):
            ys = [ys]
        if dys is None:
            dys = [tensor(1).broadcast(y.shape) for y in ys]
        if not isinstance(dys, (tuple, list)):
            dys = [dys]
        try:
            self._grad(ys, dys)
        finally:
            # The history is single-use, whether or not the pass succeeded.
            self._grad = None

    def record(self):
        """Return a context manager that records computation for ``backward``.

        :raises RuntimeError: on entering the context while this manager is
            already recording.
        """

        @contextmanager
        def recorder():
            # Check before building a ``Grad`` and outside the ``try`` so that a
            # rejected nested record() neither allocates nor resets the state
            # of the recording that is still active.
            if self._recording:
                raise RuntimeError("already recording!")
            grad = Grad()
            try:
                self._recording = True
                self._grad = grad
                for params, callbacks in self._call_back_pair:
                    grad.wrt(*params, callback=callbacks)
                with grad:
                    yield
            finally:
                self._recording = False
                self._grad = None

        return recorder()
@@ -260,9 +260,13 @@ class Grad: | |||||
cache[v] = g | cache[v] = g | ||||
if last_written_to[v] == (seqno, i): | if last_written_to[v] == (seqno, i): | ||||
if v.callback: | if v.callback: | ||||
v.callback( | |||||
grad = v.callback( | |||||
v.owner(), Wrapper(cache[v]) if Wrapper else cache[v] | v.owner(), Wrapper(cache[v]) if Wrapper else cache[v] | ||||
) | ) | ||||
if getattr(v.owner(), "grad", None) is None: | |||||
v.owner().grad = grad | |||||
else: | |||||
v.owner().grad += grad | |||||
if v.opnode is None: | if v.opnode is None: | ||||
# won't read by backward, mark consumed | # won't read by backward, mark consumed | ||||
cache[v] = None | cache[v] = None | ||||
@@ -19,7 +19,7 @@ from .group import ( | |||||
is_distributed, | is_distributed, | ||||
new_group, | new_group, | ||||
) | ) | ||||
from .helper import synchronized | |||||
from .helper import bcast_params_, make_allreduce_cb, synchronized | |||||
from .launcher import launcher | from .launcher import launcher | ||||
from .server import Client, Server | from .server import Client, Server | ||||
from .util import get_free_ports | from .util import get_free_ports |
@@ -12,7 +12,8 @@ from typing import Callable | |||||
from megengine.device import get_device_count | from megengine.device import get_device_count | ||||
from .group import group_barrier, is_distributed | |||||
from .functional import all_reduce_sum, broadcast | |||||
from .group import WORLD, group_barrier, is_distributed | |||||
def synchronized(func: Callable): | def synchronized(func: Callable): | ||||
@@ -42,3 +43,23 @@ def get_device_count_by_fork(device_type: str): | |||||
p.start() | p.start() | ||||
p.join() | p.join() | ||||
return q.get() | return q.get() | ||||
def bcast_params_(params, group):
    """Replace each parameter's value with the result of broadcasting it.

    For every element of ``params``, ``broadcast(element, group)`` is computed
    and passed to that element's ``_reset`` method. Nothing is returned.
    """
    for param in params:
        synced = broadcast(param, group)
        param._reset(synced)
class AllreduceCallback:
    """Callable that all-reduces a gradient over a group.

    :param reduce_method: ``"SUM"`` or ``"MEAN"``, matched case-insensitively.
        With ``"MEAN"`` the summed result is divided by ``group.size``.
    :param group: group handed to ``all_reduce_sum``; defaults to ``WORLD``.
    :raises ValueError: if ``reduce_method`` is neither ``"SUM"`` nor ``"MEAN"``.
    """

    _REDUCE_METHODS = ("SUM", "MEAN")

    def __init__(self, reduce_method, group=WORLD):
        # Normalise and validate up front: a case-sensitive comparison at call
        # time would let "mean" or a typo silently fall back to a plain sum.
        method = str(reduce_method).upper()
        if method not in self._REDUCE_METHODS:
            raise ValueError(
                "reduce_method should be one of {}, got {!r}".format(
                    self._REDUCE_METHODS, reduce_method
                )
            )
        self._reduce_method = method
        self._group = group

    def __call__(self, param, grad):
        """Return ``grad`` summed over the group, averaged when method is MEAN.

        ``param`` is accepted to match the callback signature but is not used.
        """
        ret = all_reduce_sum(grad, self._group)
        if self._reduce_method == "MEAN":
            ret = ret / self._group.size
        return ret


# Factory-style alias for building the callback.
make_allreduce_cb = AllreduceCallback