|
@@ -21,7 +21,7 @@ from ..functional.param_pack import get_offsets, pack_allreduce_split
 from ..functional.utils import copy
 from ..utils.future import Future
 from .functional import all_reduce_sum, broadcast
-from .group import WORLD, group_barrier, is_distributed
+from .group import WORLD, Group, group_barrier, is_distributed
 
 
 class TensorFuture(Future):
@@ -54,28 +54,43 @@ def synchronized(func: Callable):
     return wrapper
 
 
-def worker(queue, device_type):
+def _get_device_count_worker(queue, device_type):
     num = get_device_count(device_type)
     queue.put(num)
 
 
 def get_device_count_by_fork(device_type: str):
+    """Get device count in fork thread.
+    See https://stackoverflow.com/questions/22950047/cuda-initialization-error-after-fork
+    for more information.
+    """
     q = mp.Queue()
-    p = mp.Process(target=worker, args=(q, device_type))
+    p = mp.Process(target=_get_device_count_worker, args=(q, device_type))
     p.start()
     p.join()
     return q.get()
 
 
-def bcast_list_(params, group):
-    for p in params:
-        p._reset(broadcast(p, group))
+def bcast_list_(inps: list, group: Group = WORLD):
+    """Broadcast tensors between given group.
+
+    :param inps: input tensors.
+    :param group: communication group.
+    """
+    for inp in inps:
+        inp._reset(broadcast(inp, group))
 
 
 class AllreduceCallback:
-    def __init__(self, reduce_method, group=WORLD):
+    """Allreduce Callback with tensor fusion optimization.
+
+    :param reduce_method: the method to reduce gradients.
+    :param group: communication group.
+    """
+
+    def __init__(self, reduce_method: str, group: Group = WORLD):
         reduce_method = reduce_method.lower()
-        assert reduce_method in ["sum", "mean"]
+        assert reduce_method in ["sum", "mean"], "reduce_method should be sum or mean"
         self._reduce_method = reduce_method
         self._group = group
         self._marked_gm = WeakSet()
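A note on why get_device_count_by_fork probes in a child process: querying the device count may initialize the CUDA driver, and a process that has initialized CUDA cannot safely fork afterwards (see the Stack Overflow link in the docstring). Running the query in a throwaway child keeps the parent CUDA-free; only the integer result crosses back over a queue. The following is a minimal self-contained sketch of the same pattern, with a hypothetical _probe standing in for get_device_count:

import multiprocessing as mp


def _probe(queue):
    # Hypothetical stand-in for get_device_count(device_type); in the
    # real code this call may initialize the CUDA driver, which is why
    # it must run in a process we can throw away.
    queue.put(2)


def device_count_in_child() -> int:
    q = mp.Queue()
    p = mp.Process(target=_probe, args=(q,))
    p.start()
    p.join()  # the child exits right after putting its answer on the queue
    return q.get()


if __name__ == "__main__":
    print(device_count_in_child())  # the parent never touched CUDA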
|
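bcast_list_ mutates its inputs in place via _reset, which makes it suitable for synchronizing model state at startup. A hedged usage sketch follows, assuming bcast_list_ and WORLD are re-exported from megengine.distributed, that the process runs under a distributed launcher, and that net can be any Module (the toy Linear below is an illustration):

import megengine.distributed as dist
import megengine.module as M

net = M.Linear(4, 2)  # toy model for illustration

if dist.is_distributed():
    # Overwrite every rank's parameters in place with the values
    # broadcast from the group root, so all workers start from
    # identical weights.
    dist.bcast_list_(list(net.parameters()), dist.WORLD)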
|
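AllreduceCallback is meant to be attached to parameters by a gradient manager, so each gradient produced during backward is reduced across the group (with tensor fusion batching the communication). A hedged wiring sketch, assuming the class is importable from megengine.distributed.helper and that megengine.autodiff.GradManager accepts per-tensor callbacks as below:

import megengine.distributed as dist
import megengine.module as M
from megengine.autodiff import GradManager
from megengine.distributed.helper import AllreduceCallback

net = M.Linear(4, 2)  # toy model for illustration
gm = GradManager()

# "mean" averages gradients across the group (one of the two
# reduce_method options asserted above); "sum" leaves that scaling
# to the caller.
cb = AllreduceCallback("mean", dist.WORLD)
gm.attach(net.parameters(), callbacks=[cb])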