|
|
@@ -20,6 +20,7 @@ from ..core.autodiff.grad import ( |
|
|
|
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend |
|
|
|
from ..core.tensor.core import apply |
|
|
|
from ..core.tensor.tensor import Tensor |
|
|
|
from ..device import get_default_device |
|
|
|
from ..distributed.group import ( |
|
|
|
WORLD, |
|
|
|
Group, |
|
|
@@ -270,16 +271,19 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor: |
|
|
|
|
|
|
|
|
|
|
|
def remote_recv( |
|
|
|
src_rank: int, shape: Tuple[int], dtype: type, cn: Optional[str] = "gpu0" |
|
|
|
src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None |
|
|
|
) -> Tensor: |
|
|
|
"""Receive a Tensor from a remote process |
|
|
|
|
|
|
|
:param src_rank: source process rank |
|
|
|
:param shape: the shape of the tensor to receive |
|
|
|
:param dtype: the data type of the tensor to receive |
|
|
|
:param cn: the comp node to place the received tensor |
|
|
|
:param device: the device to place the received tensor, |
|
|
|
if None, use default device |
|
|
|
""" |
|
|
|
key = "{}->{}".format(src_rank, get_rank()) |
|
|
|
if device is None: |
|
|
|
device = get_default_device() |
|
|
|
|
|
|
|
# dummpy input |
|
|
|
inp = tensor([0]) |
|
|
@@ -290,7 +294,7 @@ def remote_recv( |
|
|
|
|
|
|
|
op = RemoteRecv() |
|
|
|
op.key = key |
|
|
|
op.cn = cn |
|
|
|
op.cn = device |
|
|
|
op.shape = shape |
|
|
|
op.dtype = dtype |
|
|
|
op.addr, op.port = get_mm_server_addr() |
|
|
|