GitOrigin-RevId: 1dd5a02a51
release-1.5

@@ -1018,6 +1018,7 @@ endif()
if(MGE_WITH_DISTRIBUTED)
  set(MEGRAY_WITH_NCCL ${MGE_WITH_CUDA} CACHE BOOL "Override MegRay option" FORCE)
  set(MEGRAY_WITH_SHM ${MGE_WITH_CUDA} CACHE BOOL "Override MegRay option" FORCE)
  set(MEGRAY_WITH_RCCL ${MGE_WITH_ROCM} CACHE BOOL "Override MegRay option" FORCE)
  add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/MegRay)
endif()

@@ -6,6 +6,9 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from mprop import mproperty

from . import group
from .group import (
    WORLD,
    Group,
@@ -19,7 +22,20 @@ from .group import (
    init_process_group,
    is_distributed,
    new_group,
    override_backend,
)
from .helper import bcast_list_, make_allreduce_cb, synchronized
from .launcher import launcher
from .server import Client, Server

@mproperty
def backend(mod):
    assert group._sd, "please call init_process_group first"
    return group._sd.backend

@backend.setter
def backend(mod, val):
    assert group._sd, "please call init_process_group first"
    group._sd.backend = val
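With the mproperty hook above, backend behaves like a module-level property of megengine.distributed. A minimal usage sketch (illustrative, not part of the patch; assumes init_process_group has already run in the current worker):

    import megengine.distributed as dist

    print(dist.backend)    # "auto" unless set otherwise at init time
    dist.backend = "nccl"  # force NCCL for subsequent collectives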
@@ -14,9 +14,10 @@ from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import Function, _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device
from ..device import get_default_device, what_is_xpu
from ..tensor import Tensor
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank
from . import group
from .group import WORLD, Group, get_client, get_mm_server_addr, get_rank

__all__ = [
    "reduce_sum",
@@ -34,14 +35,30 @@ __all__ = [
]

_device2backend = {
    "gpu": "nccl",
    "cuda": "nccl",
    "rocm": "rccl",
}

def _backend():
    if group._sd.backend == "auto":
        return _device2backend[what_is_xpu()]
    else:
        return group._sd.backend
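# Note (not part of the patch): with the default backend="auto", _backend() derives
# the real backend from the active device reported by what_is_xpu(), e.g.
# "cuda" -> "nccl" and "rocm" -> "rccl"; an explicitly configured backend
# (including one set via override_backend) is returned unchanged.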

def collective_comm(inp, mode, group, device):
    """Helper function for applying collective communication functions."""
    assert isinstance(group, Group)
    if group is None:
        return inp
    if device is None:
        device = ""
    addr, port = get_mm_server_addr()
    op = CollectiveComm(
        key=group.key,
        key=group.key + _backend(),
        nr_devices=group.size,
        rank=group.rank,
        is_root=(group.rank == 0),
@@ -50,7 +67,7 @@ def collective_comm(inp, mode, group, device):
        port=port,
        mode=mode,
        dtype=inp.dtype,
        backend=get_backend(),
        backend=_backend(),
        comp_node=device,
    )
    (result,) = apply(op, inp)
@@ -112,8 +129,8 @@ def _bcast_tracer_state(group, inp):
        g._refkeeper.append(inp)

def _dummy_input(shape, dtype, device=""):
    if device == "":
def _dummy_input(shape, dtype, device=None):
    if device is None:
        device = get_default_device()
    inp = Tensor(0, dtype=dtype, device=device)
    if len(shape) > 0:
@@ -122,14 +139,14 @@ def _dummy_input(shape, dtype, device=""):

class _ReduceSum(Function):
    def __init__(self, group=WORLD, device=""):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device
            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device,
        )

    def backward(self, grad):
@@ -139,7 +156,7 @@ class _ReduceSum(Function):

def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create reduce_sum operator for collective communication.
@@ -158,14 +175,14 @@ def reduce_sum(

class _Broadcast(Function):
    def __init__(self, group=WORLD, device=""):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device
            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device,
        )

    def backward(self, grad):
@@ -175,7 +192,7 @@ class _Broadcast(Function):

def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create broadcast operator for collective communication.
@@ -197,14 +214,14 @@ def broadcast(

def _bcast_param(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None
) -> Tensor:
    mode = CollectiveComm.Mode.BROADCAST
    return collective_comm(inp, mode, group, device)

def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create all_gather operator for collective communication.
@@ -218,7 +235,7 @@ def all_gather(

def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create reduce_scatter_sum operator for collective communication.
@@ -232,7 +249,7 @@ def reduce_scatter_sum(

def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create all_reduce_sum operator for collective communication.
@@ -246,7 +263,7 @@ def all_reduce_sum(

def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create all_reduce_max operator for collective communication.
@@ -260,7 +277,7 @@ def all_reduce_max(

def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create all_reduce_min operator for collective communication.
@@ -274,7 +291,7 @@ def all_reduce_min(
class _Gather(Function):
    def __init__(self, group=WORLD, device=""):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device
@@ -291,7 +308,7 @@ class _Gather(Function):

def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create gather operator for collective communication.
@@ -311,7 +328,7 @@ def gather(

class _Scatter(Function):
    def __init__(self, group=WORLD, device=""):
    def __init__(self, group=WORLD, device=None):
        self.group = group
        self.out_device = device
@@ -328,7 +345,7 @@ class _Scatter(Function):

def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create scatter operator for collective communication.
@@ -350,7 +367,7 @@ def scatter(

def all_to_all(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    """
    Create all_to_all operator for collective communication.
@@ -407,7 +424,7 @@ class _RemoteRecv(Function):
        remote_send(grad, self.op.rank_from)

def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
def remote_send(inp: Tensor, dest_rank: int):
    """
    Send a Tensor to a remote process.
@@ -423,13 +440,13 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
    op.key = group.key
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    op.backend = get_backend()
    op.backend = _backend()
    (out,) = apply(_RemoteSend(op), inp)
    _save_output_for_autodiff(inp, out)

def remote_recv(src_rank: int, device: Optional[str] = None, inp=None,) -> Tensor:
def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor:
    """
    Receive a Tensor from a remote process.
@@ -459,7 +476,7 @@ def remote_recv(src_rank: int, device: Optional[str] = None, inp=None,) -> Tenso
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank
    op.backend = get_backend()
    op.backend = _backend()
    (ret,) = apply(_RemoteRecv(op), inp)
    if _isscalar:

@@ -7,8 +7,11 @@
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import time
from contextlib import contextmanager
from typing import List, Optional, Tuple

from mprop import mproperty

from ..device import set_default_device, what_is_xpu
from ..random import seed
from .server import Client, Server
@@ -26,6 +29,7 @@ class StaticData:
    backend = None
    next_stream = None
    device_type = None
    machine_ranks = None

_sd = None
@@ -55,6 +59,7 @@ class Group:
        self.proc_ranks = proc_ranks
        self.stream = _sd.next_stream
        _sd.next_stream += 1
        self.is_single_machine_cache = None

    def check(self, proc_ranks):
        assert _sd is not None, "please call init_process_group first"
@@ -83,17 +88,23 @@ class Group:
        assert len(self.proc_ranks) > 0, "invalid group"
        return "{}{}:{}".format(_sd.device_type, _sd.device, self.stream)

WORLD = Group([])

    @property
    def is_single_machine(self):
        if self.is_single_machine_cache is not None:
            return self.is_single_machine_cache
        assert _sd is not None, "please call init_process_group first"
        for rank in self.proc_ranks:
            if rank not in _sd.machine_ranks:
                self.is_single_machine_cache = False
                return False
        self.is_single_machine_cache = True
        return True
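    # Note (not part of the patch): machine_ranks is filled per machine by launcher
    # via _set_machine_ranks, so a group counts as single-machine only when all of
    # its ranks were spawned by the local launcher process.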

_device2backend = {
    "gpu": "nccl",
    "cuda": "nccl",
    "rocm": "rccl",
}

WORLD = Group([])

_backends = {"nccl", "rccl", "ucx"}
_devices = {"gpu", "cuda", "rocm"}
_backends = {"nccl", "rccl", "ucx", "auto"}

def init_process_group(
@@ -102,7 +113,7 @@ def init_process_group(
    world_size: int,
    rank: int,
    device: int,
    backend: Optional[str] = None,
    backend: Optional[str] = "auto",
    device_type: str = "xpu",
) -> None:
    """
@@ -113,10 +124,9 @@ def init_process_group(
    :param world_size: total number of processes participating in the job.
    :param rank: rank of the current process.
    :param device: the GPU device id to bind this process to.
    :param backend: communicator backend, currently support 'nccl' and 'ucx'.
    :param backend: communicator backend, currently support 'nccl' and 'shm'.
    """
    physical_device_type = what_is_xpu() if device_type == "xpu" else device_type
    backend = _device2backend[physical_device_type] if backend is None else backend
    if not isinstance(master_ip, str):
        raise TypeError("Expect type str but got {}".format(type(master_ip)))
    if not isinstance(port, int):
@@ -131,7 +141,7 @@ def init_process_group(
        raise ValueError(
            "backend should be one of {} but got {}".format(_backends, backend)
        )
    if physical_device_type not in _device2backend:
    if physical_device_type not in _devices:
        raise ValueError(
            "{} is not a valid distributed device type".format(device_type)
        )
@@ -161,6 +171,30 @@ def init_process_group(
    seed(int(time.time()) + rank)

def _set_machine_ranks(ranks) -> None:
    global _sd
    assert _sd is not None
    _sd.machine_ranks = ranks

@contextmanager
def override_backend(new_backend: str):
    """
    Override distributed backend

    :param new_backend: communicator backend set in this context.
    """
    global _sd
    assert _sd, "please call init_process_group first"
    old_backend = _sd.backend
    _sd.backend = new_backend
    try:
        yield
    finally:
        _sd.backend = old_backend
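
# Usage sketch (illustrative, not part of the patch): temporarily switch the
# collective backend; the previous value is restored even if the body raises.
#
#     with override_backend("nccl"):
#         x = all_reduce_sum(x)  # collectives issued here use "nccl"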

def is_distributed() -> bool:
    """Return True if the distributed process group has been initialized."""
    return _sd is not None

@@ -22,8 +22,9 @@ from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
from ..functional.tensor import copy
from ..tensor import Tensor
from ..utils.future import Future
from . import group as _group
from .functional import _bcast_param, all_reduce_sum, broadcast
from .group import WORLD, Group, group_barrier, is_distributed
from .group import WORLD, Group, group_barrier, is_distributed, override_backend

def param_pack_split(inp: Tensor, offsets: list, shapes: list):
@@ -118,10 +119,30 @@ def get_offsets(shapes):
    return offsets

_enable_p2p_cache = None

def _check_enable_p2p():
    global _enable_p2p_cache
    if _enable_p2p_cache is not None:
        return _enable_p2p_cache
    cmd = ["nvidia-smi", "topo", "-p2p", "w"]
    import subprocess
    output = subprocess.run(cmd, stdout=subprocess.PIPE).stdout
    if output.count(b"OK") > 1:
        _enable_p2p_cache = True
        return True
    else:
        _enable_p2p_cache = False
        return False

def pack_allreduce_split(pack_list, shapes, group, reduce_method):
    offsets_val = get_offsets(shapes)
    offsets = Tensor(offsets_val)
    packed_grads = param_pack_concat(pack_list, offsets, offsets_val)
    packed_grads = all_reduce_sum(packed_grads, group, group.comp_node)
    if reduce_method == "mean":
        packed_grads /= group.size
@@ -207,9 +228,10 @@ class AllreduceCallback:
    :param reduce_method: the method to reduce gradients.
    :param group: communication group.
    :param backend: override distributed backend in allreduce
    """

    def __init__(self, reduce_method: str, group: Group = WORLD):
    def __init__(self, reduce_method: str, group: Group = WORLD, backend: str = None):
        reduce_method = reduce_method.lower()
        assert reduce_method in ["sum", "mean"], "reduce_method should be sum or mean"
        self._reduce_method = reduce_method
@@ -217,6 +239,15 @@ class AllreduceCallback:
        self._marked_gm = WeakSet()
        self._param_pack_thd = 10 * 1024 * 1024
        self._reset()
        if backend is None:
            assert _group._sd, "please call init_process_group first"
            backend = _group._sd.backend
        if backend == "auto":
            if group.is_single_machine and not _check_enable_p2p():
                backend = "shm"
            else:
                backend = "nccl"
        self._backend = backend
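        # Note (not part of the patch): with backend="auto" the callback chooses
        # "shm" only when the whole group lives on this machine and
        # _check_enable_p2p() finds no usable GPU peer-to-peer links; otherwise it
        # falls back to "nccl". Passing backend="nccl" or backend="shm" explicitly
        # skips this heuristic.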

    def _reset(self):
        self._params = []
@@ -231,9 +262,10 @@ class AllreduceCallback:
            return
        grad_list = [self._gradients_dict[p] for p in self._packing_list[dtype]]
        shapes = [p._tuple_shape for p in self._packing_list[dtype]]
        reduced_grads = pack_allreduce_split(
            grad_list, shapes, self._group, self._reduce_method
        )
        with override_backend(self._backend):
            reduced_grads = pack_allreduce_split(
                grad_list, shapes, self._group, self._reduce_method
            )
        for param, grad in zip(self._packing_list[dtype], reduced_grads):
            self._gradients_dict[param] = grad
        self._packing_list[dtype] = []

@@ -14,7 +14,7 @@ import queue
from .. import _exit
from ..core._imperative_rt.core2 import full_sync
from ..logger import get_logger
from .group import group_barrier, init_process_group
from .group import _set_machine_ranks, group_barrier, init_process_group
from .helper import _check_device_initialized, get_device_count_by_fork
from .server import Client, Server
@@ -34,7 +34,9 @@ def _run_wrapped(
    device_type,
    args,
    kwargs,
    backend,
    queue: mp.Queue,
    machine_ranks: list,
):
    """Init distributed process group and run wrapped function."""
    _check_device_initialized(device_type)
@@ -44,10 +46,12 @@ def _run_wrapped(
        world_size=world_size,
        rank=rank,
        device=dev,
        backend=backend,
        device_type=device_type,
    )
    # set NCCL_LAUNCH_MODE to avoid deadlock
    os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
    _set_machine_ranks(machine_ranks)
    if is_multimachine:
        group_barrier()
    ret = func(*args, **kwargs)
@@ -67,6 +71,7 @@ class launcher:
    :param rank_start: start number for rank.
    :param master_ip: ip address for master node (where the rank 0 is).
    :param port: server port for distributed server.
    :param backend: set default collective communication backend.
    """

    def __new__(cls, *args, **kwargs):
@@ -83,6 +88,7 @@ class launcher:
        master_ip="localhost",
        port=0,
        device_type="xpu",
        backend="auto",
    ):
        self.func = func
        self.n_gpus = (
@@ -93,6 +99,7 @@ class launcher:
        self.master_ip = master_ip
        self.port = port
        self.device_type = device_type
        self.backend = backend
        # master node create server
        if self.rank_start == 0:
            self.server = Server(self.port)
@@ -104,6 +111,7 @@ class launcher:
        procs = []
        queue = mp.Queue(self.n_gpus)
        results = [None] * self.n_gpus
        machine_ranks = [i + self.rank_start for i in range(self.n_gpus)]
        for dev in range(self.n_gpus):
            p = mp.Process(
                target=_run_wrapped,
@@ -118,7 +126,9 @@ class launcher:
                    self.device_type,
                    args,
                    kwargs,
                    self.backend,
                    queue,
                    machine_ranks,
                ),
            )
            p.start()
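
The backend argument threads from launcher through _run_wrapped into init_process_group. A minimal usage sketch (illustrative, not part of the patch; assumes two local GPUs and the usual decorator-with-arguments form of launcher):

    import megengine.distributed as dist

    @dist.launcher(n_gpus=2, backend="nccl")
    def worker():
        # each spawned process sees its own rank and the configured backend
        print(dist.get_rank(), dist.backend)

    worker()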

@@ -11,6 +11,7 @@
#include "megbrain/opr/megray_helper.h"
#include "megbrain/comp_node_env.h"
#include "megray/common.h"

using namespace mgb;
using namespace opr;
@@ -39,6 +40,8 @@ MegRay::Backend mgb::opr::get_megray_backend(const std::string& backend) {
        return MegRay::MEGRAY_RCCL;
    } else if (backend == "ucx") {
        return MegRay::MEGRAY_UCX;
    } else if (backend == "shm") {
        return MegRay::MEGRAY_SHM;
    } else {
        mgb_throw(MegBrainError, "bad CollectiveComm backend");
    }
@@ -90,7 +93,7 @@ std::shared_ptr<MegRay::Communicator> MegRayCommBuilder::get_megray_comm(
    if (rank == root) {
        char* c = MegRay::get_host_ip();
        master_ip = std::string(c);
        delete c;
        delete [] c;
        port = MegRay::get_free_port();
        auto ret = MegRay::create_server(size, port);
        mgb_assert(ret == MegRay::Status::MEGRAY_OK);