Browse Source

fix(mge/functional): fix indexing_one_hot and remote_recv

GitOrigin-RevId: 00bdfb502b
tags/v1.0.0-rc1
Megvii Engine Team 4 years ago
parent
commit
8dc23e0fdf
3 changed files with 11 additions and 5 deletions
  1. +7
    -3
      imperative/python/megengine/functional/distributed.py
  2. +1
    -1
      imperative/python/megengine/functional/nn.py
  3. +3
    -1
      imperative/python/test/unit/functional/test_distributed.py

+ 7
- 3
imperative/python/megengine/functional/distributed.py View File

@@ -20,6 +20,7 @@ from ..core.autodiff.grad import (
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.core import apply
from ..core.tensor.tensor import Tensor
from ..device import get_default_device
from ..distributed.group import (
WORLD,
Group,
@@ -270,16 +271,19 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:


def remote_recv(
src_rank: int, shape: Tuple[int], dtype: type, cn: Optional[str] = "gpu0"
src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
) -> Tensor:
"""Receive a Tensor from a remote process

:param src_rank: source process rank
:param shape: the shape of the tensor to receive
:param dtype: the data type of the tensor to receive
:param cn: the comp node to place the received tensor
:param device: the device to place the received tensor,
if None, use default device
"""
key = "{}->{}".format(src_rank, get_rank())
if device is None:
device = get_default_device()

# dummy input
inp = tensor([0])
@@ -290,7 +294,7 @@ def remote_recv(

op = RemoteRecv()
op.key = key
op.cn = cn
op.cn = device
op.shape = shape
op.dtype = dtype
op.addr, op.port = get_mm_server_addr()


+ 1
- 1
imperative/python/megengine/functional/nn.py View File

@@ -1447,7 +1447,7 @@ def indexing_one_hot(
src, (TensorWrapperBase, TensorBase)
), "src must be of Tensor type"
op = builtin.IndexingOneHot(axis=axis)
index = utils.convert_single_value(index, (src,), dtype="int32")
index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device)
(result,) = apply(op, src, index)
if not keepdims:
result = remove_axis(result, axis)


+ 3
- 1
imperative/python/test/unit/functional/test_distributed.py View File

@@ -15,6 +15,7 @@ import pytest
import megengine as mge
import megengine.distributed as dist
from megengine import Parameter, Tensor, tensor
from megengine.device import get_default_device, set_default_device
from megengine.functional.distributed import (
all_gather,
all_reduce_max,
@@ -449,7 +450,8 @@ def test_io_remote():
assert y.numpy()[0] == 0
else: # remote recv
dist.init_process_group("localhost", port, world_size, rank, rank)
y = remote_recv(0, val.shape, val.dtype, cn="gpu1")
y = remote_recv(0, val.shape, val.dtype)
assert y.device == "gpu1"
np.testing.assert_almost_equal(val, y.numpy())

procs = []


Loading…
Cancel
Save