|
|
@@ -12,7 +12,8 @@ from typing import Callable |
|
|
|
|
|
|
|
from megengine.device import get_device_count |
|
|
|
|
|
|
|
from .group import group_barrier, is_distributed |
|
|
|
from .functional import all_reduce_sum, broadcast |
|
|
|
from .group import WORLD, group_barrier, is_distributed |
|
|
|
|
|
|
|
|
|
|
|
def synchronized(func: Callable): |
|
|
@@ -42,3 +43,23 @@ def get_device_count_by_fork(device_type: str): |
|
|
|
p.start() |
|
|
|
p.join() |
|
|
|
return q.get() |
|
|
|
|
|
|
|
|
|
|
|
def bcast_params_(params, group):
    """Overwrite every parameter with its broadcast value, in place.

    Each item of ``params`` is passed to ``broadcast`` together with
    ``group``; the value that comes back is written into the same
    parameter object through its ``_reset`` method.

    :param params: iterable of parameters to synchronize.
    :param group: communication group handed to ``broadcast``.
    :return: ``None``; the parameters are modified in place.
    """
    for tensor in params:
        synced = broadcast(tensor, group)
        tensor._reset(synced)
|
|
|
|
|
|
|
|
|
|
|
class AllreduceCallback:
    """Callable that all-reduces a gradient over a communication group.

    The gradient is always summed across the group with
    ``all_reduce_sum``. When ``reduce_method`` is exactly the string
    ``"MEAN"`` the sum is additionally divided by ``group.size``; any
    other value leaves the summed gradient unchanged.

    :param reduce_method: reduction mode; only ``"MEAN"`` (exact match)
        enables averaging.
    :param group: communication group to reduce over, ``WORLD`` by default.
    """

    def __init__(self, reduce_method, group=WORLD):
        self._reduce_method, self._group = reduce_method, group

    def __call__(self, param, grad):
        """Return ``grad`` reduced across the group.

        :param param: accepted as part of the callback signature; not read
            here.
        :param grad: gradient to reduce.
        :return: the summed gradient, divided by the group size when the
            reduce method is ``"MEAN"``.
        """
        reduced = all_reduce_sum(grad, self._group)
        if self._reduce_method != "MEAN":
            return reduced
        # Build a new value rather than dividing in place, as the sum may be
        # referenced elsewhere.
        return reduced / self._group.size


# Factory-style alias: calling it constructs an AllreduceCallback.
make_allreduce_cb = AllreduceCallback