Browse Source

docs(mge/distributed): add distributed.helper docs

GitOrigin-RevId: 37c14aa11f
release-1.1
Megvii Engine Team 4 years ago
parent
commit
6c5cf25f4d
1 changed file with 23 additions and 8 deletions
  1. +23
    -8
      imperative/python/megengine/distributed/helper.py

+ 23
- 8
imperative/python/megengine/distributed/helper.py View File

@@ -21,7 +21,7 @@ from ..functional.param_pack import get_offsets, pack_allreduce_split
from ..functional.utils import copy from ..functional.utils import copy
from ..utils.future import Future from ..utils.future import Future
from .functional import all_reduce_sum, broadcast from .functional import all_reduce_sum, broadcast
from .group import WORLD, group_barrier, is_distributed
from .group import WORLD, Group, group_barrier, is_distributed




class TensorFuture(Future): class TensorFuture(Future):
@@ -54,28 +54,43 @@ def synchronized(func: Callable):
return wrapper return wrapper




def _get_device_count_worker(queue, device_type):
    """Query the device count and report it back through *queue*.

    Intended to run in a forked subprocess (see ``get_device_count_by_fork``)
    so that device runtime initialization does not pollute the parent process.

    :param queue: a ``multiprocessing.Queue`` the result is pushed onto.
    :param device_type: device type string passed through to
        ``get_device_count`` (e.g. ``"gpu"``).
    """
    num = get_device_count(device_type)
    queue.put(num)




def get_device_count_by_fork(device_type: str):
    """Get device count in a forked subprocess.

    Querying the device count directly would initialize the device runtime
    (e.g. CUDA) in this process, which breaks later ``fork`` calls.  See
    https://stackoverflow.com/questions/22950047/cuda-initialization-error-after-fork
    for more information.

    :param device_type: device type string, e.g. ``"gpu"``.
    :return: number of devices reported by the child process.
    """
    q = mp.Queue()
    p = mp.Process(target=_get_device_count_worker, args=(q, device_type))
    p.start()
    # block until the child has pushed its answer; the queue read below
    # therefore cannot hang on an empty queue.
    p.join()
    return q.get()




def bcast_list_(inps: list, group: Group = WORLD):
    """Broadcast tensors between the given group, in place.

    Each tensor in *inps* is overwritten (via ``_reset``) with the value
    broadcast from the group's root rank.

    :param inps: input tensors.
    :param group: communication group; defaults to ``WORLD``.
    """
    for inp in inps:
        inp._reset(broadcast(inp, group))




class AllreduceCallback:
    """Allreduce callback with tensor fusion optimization.

    :param reduce_method: the method to reduce gradients, ``"sum"`` or
        ``"mean"``.
    :param group: communication group; defaults to ``WORLD``.
    """

    def __init__(self, reduce_method: str, group: Group = WORLD):
        # Normalize so callers may pass "SUM"/"Mean" etc.
        reduce_method = reduce_method.lower()
        assert reduce_method in ["sum", "mean"], "reduce_method should be sum or mean"
        self._reduce_method = reduce_method
        self._group = group
        # Weak references so tracked gradient managers can still be
        # garbage-collected.
        self._marked_gm = WeakSet()


Loading…
Cancel
Save