Browse Source

docs(mge/distributed): add distributed.helper docs

GitOrigin-RevId: 37c14aa11f
release-1.1
Megvii Engine Team 4 years ago
parent
commit
6c5cf25f4d
1 changed files with 23 additions and 8 deletions
  1. +23
    -8
      imperative/python/megengine/distributed/helper.py

+ 23
- 8
imperative/python/megengine/distributed/helper.py View File

@@ -21,7 +21,7 @@ from ..functional.param_pack import get_offsets, pack_allreduce_split
from ..functional.utils import copy
from ..utils.future import Future
from .functional import all_reduce_sum, broadcast
from .group import WORLD, group_barrier, is_distributed
from .group import WORLD, Group, group_barrier, is_distributed


class TensorFuture(Future):
@@ -54,28 +54,43 @@ def synchronized(func: Callable):
return wrapper


def worker(queue, device_type):
def _get_device_count_worker(queue, device_type):
num = get_device_count(device_type)
queue.put(num)


def get_device_count_by_fork(device_type: str):
"""Get device count in a child process.
See https://stackoverflow.com/questions/22950047/cuda-initialization-error-after-fork
for more information.
"""
q = mp.Queue()
p = mp.Process(target=worker, args=(q, device_type))
p = mp.Process(target=_get_device_count_worker, args=(q, device_type))
p.start()
p.join()
return q.get()


def bcast_list_(params, group):
for p in params:
p._reset(broadcast(p, group))
def bcast_list_(inps: list, group: Group = WORLD):
"""Broadcast tensors within the given group.

:param inps: input tensors.
:param group: communication group.
"""
for inp in inps:
inp._reset(broadcast(inp, group))


class AllreduceCallback:
def __init__(self, reduce_method, group=WORLD):
"""Allreduce Callback with tensor fusion optimization.

:param reduce_method: the method to reduce gradients.
:param group: communication group.
"""

def __init__(self, reduce_method: str, group: Group = WORLD):
reduce_method = reduce_method.lower()
assert reduce_method in ["sum", "mean"]
assert reduce_method in ["sum", "mean"], "reduce_method should be sum or mean"
self._reduce_method = reduce_method
self._group = group
self._marked_gm = WeakSet()


Loading…
Cancel
Save