GitOrigin-RevId: 1dd5a02a51
release-1.5
@@ -1018,6 +1018,7 @@ endif()
 if(MGE_WITH_DISTRIBUTED)
   set(MEGRAY_WITH_NCCL ${MGE_WITH_CUDA} CACHE BOOL "Override MegRay option" FORCE)
+  set(MEGRAY_WITH_SHM ${MGE_WITH_CUDA} CACHE BOOL "Override MegRay option" FORCE)
   set(MEGRAY_WITH_RCCL ${MGE_WITH_ROCM} CACHE BOOL "Override MegRay option" FORCE)
   add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/MegRay)
 endif()
@@ -6,6 +6,9 @@
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+from mprop import mproperty
+
+from . import group
 from .group import (
     WORLD,
     Group,
@@ -19,7 +22,20 @@ from .group import (
     init_process_group,
     is_distributed,
     new_group,
+    override_backend,
 )
 from .helper import bcast_list_, make_allreduce_cb, synchronized
 from .launcher import launcher
 from .server import Client, Server
+
+
+@mproperty
+def backend(mod):
+    assert group._sd, "please call init_process_group first"
+    return group._sd.backend
+
+
+@backend.setter
+def backend(mod, val):
+    assert group._sd, "please call init_process_group first"
+    group._sd.backend = val
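
Note: with mprop, `megengine.distributed.backend` behaves like a module-level property backed by the process-group static data, so it can be read and reassigned after `init_process_group`. A minimal usage sketch (hypothetical worker body; the two-GPU launcher arguments are illustrative):

    import megengine.distributed as dist

    @dist.launcher(n_gpus=2)
    def worker():
        # reads the value stored by init_process_group ("auto" by default after this change)
        print(dist.backend)
        # the setter writes straight back into the shared static data
        dist.backend = "nccl"

    worker()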
@@ -14,9 +14,10 @@ from ..core._imperative_rt.core2 import apply
 from ..core.autodiff.grad import Function, _grad_manager_dict
 from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
 from ..core.tensor.utils import isscalar, setscalar
-from ..device import get_default_device
+from ..device import get_default_device, what_is_xpu
 from ..tensor import Tensor
-from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank
+from . import group
+from .group import WORLD, Group, get_client, get_mm_server_addr, get_rank

 __all__ = [
     "reduce_sum",
@@ -34,14 +35,30 @@ __all__ = [
 ]


+_device2backend = {
+    "gpu": "nccl",
+    "cuda": "nccl",
+    "rocm": "rccl",
+}
+
+
+def _backend():
+    if group._sd.backend == "auto":
+        return _device2backend[what_is_xpu()]
+    else:
+        return group._sd.backend
+
+
 def collective_comm(inp, mode, group, device):
     """Helper function for applying collective communication functions."""
     assert isinstance(group, Group)
     if group is None:
         return inp
+    if device is None:
+        device = ""
     addr, port = get_mm_server_addr()
     op = CollectiveComm(
-        key=group.key,
+        key=group.key + _backend(),
         nr_devices=group.size,
         rank=group.rank,
         is_root=(group.rank == 0),
@@ -50,7 +67,7 @@ def collective_comm(inp, mode, group, device):
         port=port,
         mode=mode,
         dtype=inp.dtype,
-        backend=get_backend(),
+        backend=_backend(),
         comp_node=device,
     )
     (result,) = apply(op, inp)
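
The `_backend()` helper resolves the effective backend per call: an explicit setting is used verbatim, while "auto" falls back to the device-type mapping; appending it to `group.key` keeps communicators created under different backends distinct. A standalone sketch of the resolution logic (the function name and asserts are illustrative, not part of the diff):

    _device2backend = {"gpu": "nccl", "cuda": "nccl", "rocm": "rccl"}

    def resolve_backend(configured: str, device_type: str) -> str:
        # "auto" defers to the detected physical device type; anything else is taken as-is
        return _device2backend[device_type] if configured == "auto" else configured

    assert resolve_backend("auto", "rocm") == "rccl"
    assert resolve_backend("ucx", "cuda") == "ucx"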
@@ -112,8 +129,8 @@ def _bcast_tracer_state(group, inp):
             g._refkeeper.append(inp)


-def _dummy_input(shape, dtype, device=""):
-    if device == "":
+def _dummy_input(shape, dtype, device=None):
+    if device is None:
         device = get_default_device()
     inp = Tensor(0, dtype=dtype, device=device)
     if len(shape) > 0:
@@ -122,14 +139,14 @@ def _dummy_input(shape, dtype, device=""):


 class _ReduceSum(Function):
-    def __init__(self, group=WORLD, device=""):
+    def __init__(self, group=WORLD, device=None):
         self.group = group
         self.out_device = device

     def forward(self, data):
         self.in_device = str(data.device)
         return collective_comm(
-            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device
+            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device,
         )

     def backward(self, grad):
@@ -139,7 +156,7 @@ class _ReduceSum(Function):


 def reduce_sum(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create reduce_sum operator for collective communication.
@@ -158,14 +175,14 @@ def reduce_sum(


 class _Broadcast(Function):
-    def __init__(self, group=WORLD, device=""):
+    def __init__(self, group=WORLD, device=None):
         self.group = group
         self.out_device = device

     def forward(self, data):
         self.in_device = str(data.device)
         return collective_comm(
-            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device
+            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device,
         )

     def backward(self, grad):
@@ -175,7 +192,7 @@ class _Broadcast(Function):


 def broadcast(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create broadcast operator for collective communication.
@@ -197,14 +214,14 @@ def broadcast(


 def _bcast_param(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None
 ) -> Tensor:
     mode = CollectiveComm.Mode.BROADCAST
     return collective_comm(inp, mode, group, device)


 def all_gather(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create all_gather operator for collective communication.
@@ -218,7 +235,7 @@ def all_gather(


 def reduce_scatter_sum(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create reduce_scatter_sum operator for collective communication.
@@ -232,7 +249,7 @@ def reduce_scatter_sum(


 def all_reduce_sum(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create all_reduce_sum operator for collective communication.
@@ -246,7 +263,7 @@ def all_reduce_sum(


 def all_reduce_max(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create all_reduce_max operator for collective communication.
@@ -260,7 +277,7 @@ def all_reduce_max(


 def all_reduce_min(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create all_reduce_min operator for collective communication.
@@ -274,7 +291,7 @@ def all_reduce_min(


 class _Gather(Function):
-    def __init__(self, group=WORLD, device=""):
+    def __init__(self, group=WORLD, device=None):
         self.group = group
         self.out_device = device

@@ -291,7 +308,7 @@ class _Gather(Function):


 def gather(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create gather operator for collective communication.
@@ -311,7 +328,7 @@ def gather(


 class _Scatter(Function):
-    def __init__(self, group=WORLD, device=""):
+    def __init__(self, group=WORLD, device=None):
         self.group = group
         self.out_device = device

@@ -328,7 +345,7 @@ class _Scatter(Function):


 def scatter(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create scatter operator for collective communication.
@@ -350,7 +367,7 @@ def scatter(


 def all_to_all(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
 ) -> Tensor:
     """
     Create all_to_all operator for collective communication.
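
All of the collective wrappers above now default `device` to `None` instead of an empty string; `collective_comm` maps `None` back to the empty string expected by the underlying op. A hypothetical two-process sketch using one of them:

    import megengine as mge
    import megengine.distributed as dist
    from megengine.distributed.functional import all_reduce_sum

    @dist.launcher(n_gpus=2)
    def worker():
        x = mge.tensor([dist.get_rank() + 1.0])
        y = all_reduce_sum(x)    # device=None -> result stays on the current device
        print(y.numpy())         # [3.] on both ranks

    worker()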
@@ -407,7 +424,7 @@ class _RemoteRecv(Function):
         remote_send(grad, self.op.rank_from)


-def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
+def remote_send(inp: Tensor, dest_rank: int):
     """
     Send a Tensor to a remote process.

@@ -423,13 +440,13 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
     op.key = group.key
     op.addr, op.port = get_mm_server_addr()
     op.rank_to = dest_rank
-    op.backend = get_backend()
+    op.backend = _backend()
     (out,) = apply(_RemoteSend(op), inp)
     _save_output_for_autodiff(inp, out)


-def remote_recv(src_rank: int, device: Optional[str] = None, inp=None,) -> Tensor:
+def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor:
     """
     Receive a Tensor from a remote process.

@@ -459,7 +476,7 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None,) -> Tensor:
     op.dtype = dtype
     op.addr, op.port = get_mm_server_addr()
     op.rank_from = src_rank
-    op.backend = get_backend()
+    op.backend = _backend()
     (ret,) = apply(_RemoteRecv(op), inp)

     if _isscalar:
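
`remote_send` no longer advertises a return value, and both point-to-point ops now pick their backend through `_backend()`. A hypothetical two-process sketch of the pair:

    import megengine as mge
    import megengine.distributed as dist
    from megengine.distributed.functional import remote_recv, remote_send

    @dist.launcher(n_gpus=2)
    def worker():
        if dist.get_rank() == 0:
            remote_send(mge.tensor([1.0, 2.0]), dest_rank=1)  # returns None after this change
        else:
            x = remote_recv(src_rank=0)
            print(x.numpy())

    worker()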
@@ -7,8 +7,11 @@
 # software distributed under the License is distributed on an
 # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import time
+from contextlib import contextmanager
 from typing import List, Optional, Tuple

+from mprop import mproperty
+
 from ..device import set_default_device, what_is_xpu
 from ..random import seed
 from .server import Client, Server
@@ -26,6 +29,7 @@ class StaticData:
     backend = None
     next_stream = None
     device_type = None
+    machine_ranks = None


 _sd = None
@@ -55,6 +59,7 @@ class Group:
         self.proc_ranks = proc_ranks
         self.stream = _sd.next_stream
         _sd.next_stream += 1
+        self.is_single_machine_cache = None

     def check(self, proc_ranks):
         assert _sd is not None, "please call init_process_group first"
@@ -83,17 +88,23 @@ class Group:
         assert len(self.proc_ranks) > 0, "invalid group"
         return "{}{}:{}".format(_sd.device_type, _sd.device, self.stream)

+    @property
+    def is_single_machine(self):
+        if self.is_single_machine_cache is not None:
+            return self.is_single_machine_cache
+        assert _sd is not None, "please call init_process_group first"
+        for rank in self.proc_ranks:
+            if rank not in _sd.machine_ranks:
+                self.is_single_machine_cache = False
+                return False
+        self.is_single_machine_cache = True
+        return True
+
+
 WORLD = Group([])

-_device2backend = {
-    "gpu": "nccl",
-    "cuda": "nccl",
-    "rocm": "rccl",
-}
-
-_backends = {"nccl", "rccl", "ucx"}
+_devices = {"gpu", "cuda", "rocm"}
+_backends = {"nccl", "rccl", "ucx", "auto"}


 def init_process_group(
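
`is_single_machine` walks the group's ranks once, checks them against the ranks recorded for this machine, and caches the verdict on the instance. The check is equivalent to this small sketch (standalone names, not the actual methods):

    def is_single_machine(proc_ranks, machine_ranks):
        # True only if every rank of the group lives in this machine's rank list
        return all(rank in machine_ranks for rank in proc_ranks)

    assert is_single_machine([0, 1], machine_ranks=[0, 1, 2, 3])
    assert not is_single_machine([2, 4], machine_ranks=[0, 1, 2, 3])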
@@ -102,7 +113,7 @@ def init_process_group(
     world_size: int,
     rank: int,
     device: int,
-    backend: Optional[str] = None,
+    backend: Optional[str] = "auto",
     device_type: str = "xpu",
 ) -> None:
     """
@@ -113,10 +124,9 @@ def init_process_group(
     :param world_size: total number of processes participating in the job.
     :param rank: rank of the current process.
     :param device: the GPU device id to bind this process to.
-    :param backend: communicator backend, currently support 'nccl' and 'ucx'.
+    :param backend: communicator backend, currently support 'nccl' and 'shm'.
     """
     physical_device_type = what_is_xpu() if device_type == "xpu" else device_type
-    backend = _device2backend[physical_device_type] if backend is None else backend
     if not isinstance(master_ip, str):
         raise TypeError("Expect type str but got {}".format(type(master_ip)))
     if not isinstance(port, int):
@@ -131,7 +141,7 @@ def init_process_group(
         raise ValueError(
             "backend should be one of {} but got {}".format(_backends, backend)
         )
-    if physical_device_type not in _device2backend:
+    if physical_device_type not in _devices:
         raise ValueError(
             "{} is not a valid distributed device type".format(device_type)
         )
@@ -161,6 +171,30 @@ def init_process_group(
     seed(int(time.time()) + rank)


+def _set_machine_ranks(ranks) -> None:
+    global _sd
+    assert _sd is not None
+
+    _sd.machine_ranks = ranks
+
+
+@contextmanager
+def override_backend(new_backend: str):
+    """
+    Override distributed backend
+
+    :param new_backend: communicator backend set in this context.
+    """
+    global _sd
+    assert _sd, "please call init_process_group first"
+    old_backend = _sd.backend
+    _sd.backend = new_backend
+    try:
+        yield
+    finally:
+        _sd.backend = old_backend
+
+
 def is_distributed() -> bool:
     """Return True if the distributed process group has been initialized."""
     return _sd is not None
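
`override_backend` swaps the process-wide backend for the duration of a `with` block and restores the previous value even if the block raises. A hypothetical usage sketch (assumes a build where the "shm" backend is available):

    import megengine as mge
    import megengine.distributed as dist
    from megengine.distributed.functional import all_reduce_sum

    @dist.launcher(n_gpus=2)
    def worker():
        x = mge.tensor([1.0])
        with dist.override_backend("shm"):
            y = all_reduce_sum(x)    # this collective runs over shm
        z = all_reduce_sum(x)        # back to the previously configured backend

    worker()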
@@ -22,8 +22,9 @@ from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
 from ..functional.tensor import copy
 from ..tensor import Tensor
 from ..utils.future import Future
+from . import group as _group
 from .functional import _bcast_param, all_reduce_sum, broadcast
-from .group import WORLD, Group, group_barrier, is_distributed
+from .group import WORLD, Group, group_barrier, is_distributed, override_backend


 def param_pack_split(inp: Tensor, offsets: list, shapes: list):
@@ -118,10 +119,30 @@ def get_offsets(shapes):
     return offsets


+_enable_p2p_cache = None
+
+
+def _check_enable_p2p():
+    global _enable_p2p_cache
+    if _enable_p2p_cache is not None:
+        return _enable_p2p_cache
+    cmd = ["nvidia-smi", "topo", "-p2p", "w"]
+    import subprocess
+
+    output = subprocess.run(cmd, stdout=subprocess.PIPE).stdout
+    if output.count(b"OK") > 1:
+        _enable_p2p_cache = True
+        return True
+    else:
+        _enable_p2p_cache = False
+        return False
+
+
 def pack_allreduce_split(pack_list, shapes, group, reduce_method):
     offsets_val = get_offsets(shapes)
     offsets = Tensor(offsets_val)
     packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
     packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
     if reduce_method == "mean":
         packed_grads /= group.size
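
`_check_enable_p2p` shells out to `nvidia-smi topo -p2p w` once, counts the "OK" links, and caches the boolean for later calls. A slightly defensive standalone sketch of the same probe (the FileNotFoundError fallback is an addition, not in the diff):

    import subprocess

    def has_gpu_p2p() -> bool:
        try:
            out = subprocess.run(
                ["nvidia-smi", "topo", "-p2p", "w"], stdout=subprocess.PIPE
            ).stdout
        except FileNotFoundError:
            return False  # assume no P2P when nvidia-smi is missing
        # mirror the diff: more than one "OK" in the matrix is treated as P2P available
        return out.count(b"OK") > 1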
@@ -207,9 +228,10 @@ class AllreduceCallback:
     :param reduce_method: the method to reduce gradiants.
     :param group: communication group.
+    :param backend: override distributed backend in allreduce
     """

-    def __init__(self, reduce_method: str, group: Group = WORLD):
+    def __init__(self, reduce_method: str, group: Group = WORLD, backend: str = None):
         reduce_method = reduce_method.lower()
         assert reduce_method in ["sum", "mean"], "reduce_method should be sum or mean"
         self._reduce_method = reduce_method
@@ -217,6 +239,15 @@ class AllreduceCallback:
         self._marked_gm = WeakSet()
         self._param_pack_thd = 10 * 1024 * 1024
         self._reset()
+        if backend is None:
+            assert _group._sd, "please call init_process_group first"
+            backend = _group._sd.backend
+        if backend == "auto":
+            if group.is_single_machine and not _check_enable_p2p():
+                backend = "shm"
+            else:
+                backend = "nccl"
+        self._backend = backend

     def _reset(self):
         self._params = []
@@ -231,9 +262,10 @@ class AllreduceCallback:
             return
         grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
         shapes = [p._tuple_shape for p in self._packing_list[dtype]]
-        reduced_grads = pack_allreduce_split(
-            grad_list, shapes, self._group, self._reduce_method
-        )
+        with override_backend(self._backend):
+            reduced_grads = pack_allreduce_split(
+                grad_list, shapes, self._group, self._reduce_method
+            )
         for param, grad in zip(self._packing_list[dtype], reduced_grads):
             self._gradients_dict[param] = grad
         self._packing_list[dtype] = []
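
`AllreduceCallback` now takes an optional `backend`: left as `None` it inherits the process-group setting, and `"auto"` resolves to "shm" for single-machine groups without GPU P2P, otherwise "nccl"; the choice is then applied around the packed all-reduce via `override_backend`. A hypothetical training-side sketch using the public alias `make_allreduce_cb`:

    import megengine.distributed as dist
    import megengine.module as M
    from megengine.autodiff import GradManager

    @dist.launcher(n_gpus=2)
    def worker():
        model = M.Linear(4, 4)
        gm = GradManager()
        # reduce_method="mean" divides packed gradients by the group size;
        # omitting backend keeps the automatic shm/nccl choice described above
        gm.attach(model.parameters(), callbacks=[dist.make_allreduce_cb("mean", dist.WORLD)])

    worker()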
@@ -14,7 +14,7 @@ import queue
 from .. import _exit
 from ..core._imperative_rt.core2 import full_sync
 from ..logger import get_logger
-from .group import group_barrier, init_process_group
+from .group import _set_machine_ranks, group_barrier, init_process_group
 from .helper import _check_device_initialized, get_device_count_by_fork
 from .server import Client, Server
@@ -34,7 +34,9 @@ def _run_wrapped(
     device_type,
     args,
     kwargs,
+    backend,
     queue: mp.Queue,
+    machine_ranks: list,
 ):
     """Init distributed process group and run wrapped function."""
     _check_device_initialized(device_type)
@@ -44,10 +46,12 @@ def _run_wrapped(
         world_size=world_size,
         rank=rank,
         device=dev,
+        backend=backend,
         device_type=device_type,
     )
     # set NCCL_LAUNCH_MODE to avoid deadlock
    os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
+    _set_machine_ranks(machine_ranks)
     if is_multimachine:
         group_barrier()
     ret = func(*args, **kwargs)
@@ -67,6 +71,7 @@ class launcher:
     :param rank_start: start number for rank.
     :param master_ip: ip address for master node (where the rank 0 is).
     :param port: server port for distributed server.
+    :param backend: set default collective communication backend.
     """

     def __new__(cls, *args, **kwargs):
@@ -83,6 +88,7 @@ class launcher:
         master_ip="localhost",
         port=0,
         device_type="xpu",
+        backend="auto",
     ):
         self.func = func
         self.n_gpus = (
@@ -93,6 +99,7 @@ class launcher:
         self.master_ip = master_ip
         self.port = port
         self.device_type = device_type
+        self.backend = backend
         # master node create server
         if self.rank_start == 0:
             self.server = Server(self.port)
@@ -104,6 +111,7 @@ class launcher:
         procs = []
         queue = mp.Queue(self.n_gpus)
         results = [None] * self.n_gpus
+        machine_ranks = [i + self.rank_start for i in range(self.n_gpus)]
         for dev in range(self.n_gpus):
             p = mp.Process(
                 target=_run_wrapped,
@@ -118,7 +126,9 @@ class launcher:
                     self.device_type,
                     args,
                     kwargs,
+                    self.backend,
                     queue,
+                    machine_ranks,
                 ),
             )
             p.start()
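
`launcher` now accepts a `backend` argument (default "auto") that is forwarded to `init_process_group` in every spawned worker, alongside the per-machine rank list consumed by `Group.is_single_machine`. A hypothetical entry point pinning the backend explicitly:

    import megengine.distributed as dist

    @dist.launcher(n_gpus=2, backend="nccl")
    def main():
        print(dist.get_rank(), dist.backend)    # each worker sees backend == "nccl"

    main()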
@@ -11,6 +11,7 @@
 #include "megbrain/opr/megray_helper.h"
 #include "megbrain/comp_node_env.h"
+#include "megray/common.h"

 using namespace mgb;
 using namespace opr;
@@ -39,6 +40,8 @@ MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
         return MegRay::MEGRAY_RCCL;
     } else if (backend == "ucx") {
         return MegRay::MEGRAY_UCX;
+    } else if (backend == "shm") {
+        return MegRay::MEGRAY_SHM;
     } else {
         mgb_throw(MegBrainError, "bad CollectiveComm backend");
     }
@@ -90,7 +93,7 @@ std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm(
     if (rank == root) {
         char* c = MegRay::get_host_ip();
         master_ip = std::string(c);
-        delete c;
+        delete [] c;
         port = MegRay::get_free_port();
         auto ret = MegRay::create_server(size, port);
         mgb_assert(ret == MegRay::Status::MEGRAY_OK);