
fix(mge/functional): fix indexing_one_hot and remote_recv

commit 8dc23e0fdf (tag: v1.0.0-rc1)
GitOrigin-RevId: 00bdfb502b
Author: Megvii Engine Team (4 years ago)
3 changed files with 11 additions and 5 deletions:

  1. imperative/python/megengine/functional/distributed.py  (+7, -3)
  2. imperative/python/megengine/functional/nn.py  (+1, -1)
  3. imperative/python/test/unit/functional/test_distributed.py  (+3, -1)

imperative/python/megengine/functional/distributed.py (+7, -3)

@@ -20,6 +20,7 @@ from ..core.autodiff.grad import (
 from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
 from ..core.tensor.core import apply
 from ..core.tensor.tensor import Tensor
+from ..device import get_default_device
 from ..distributed.group import (
     WORLD,
     Group,
@@ -270,16 +271,19 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:


 def remote_recv(
-    src_rank: int, shape: Tuple[int], dtype: type, cn: Optional[str] = "gpu0"
+    src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
 ) -> Tensor:
     """Receive a Tensor from a remote process

     :param src_rank: source process rank
     :param shape: the shape of the tensor to receive
     :param dtype: the data type of the tensor to receive
-    :param cn: the comp node to place the received tensor
+    :param device: the device to place the received tensor,
+        if None, use default device
     """
     key = "{}->{}".format(src_rank, get_rank())
+    if device is None:
+        device = get_default_device()

     # dummy input
     inp = tensor([0])
@@ -290,7 +294,7 @@ def remote_recv(

     op = RemoteRecv()
     op.key = key
-    op.cn = cn
+    op.cn = device
     op.shape = shape
     op.dtype = dtype
     op.addr, op.port = get_mm_server_addr()
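
After this change, device is optional: with device=None the result is placed on get_default_device() instead of the previously hard-coded "gpu0". A minimal receiver-side sketch, assuming the process group is already initialized and rank 0 has sent a matching tensor (shape and dtype here are illustrative):

    import numpy as np
    from megengine.functional.distributed import remote_recv

    # device is omitted, so the received tensor lands on get_default_device()
    y = remote_recv(0, shape=(1,), dtype=np.float32)
    print(y.device, y.numpy())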


imperative/python/megengine/functional/nn.py (+1, -1)

@@ -1447,7 +1447,7 @@ def indexing_one_hot(
         src, (TensorWrapperBase, TensorBase)
     ), "src must be of Tensor type"
     op = builtin.IndexingOneHot(axis=axis)
-    index = utils.convert_single_value(index, (src,), dtype="int32")
+    index = utils.convert_single_value(index, (src,), dtype="int32", device=src.device)
     (result,) = apply(op, src, index)
     if not keepdims:
         result = remove_axis(result, axis)
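
The added device=src.device keeps the converted index on src's device, so the subsequent apply() no longer sees a cross-device operand when src lives off the default device. A usage sketch (the device name is illustrative; values are chosen to show the gather):

    import megengine.functional as F
    from megengine import tensor

    src = tensor([[1.0, 2.0], [3.0, 4.0]], device="gpu1")  # non-default device
    index = tensor([0, 1], device="gpu1")
    # the int32 conversion of index now happens on src.device
    out = F.indexing_one_hot(src, index, axis=1)
    print(out.numpy())  # [1., 4.]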


imperative/python/test/unit/functional/test_distributed.py (+3, -1)

@@ -15,6 +15,7 @@ import pytest
 import megengine as mge
 import megengine.distributed as dist
 from megengine import Parameter, Tensor, tensor
+from megengine.device import get_default_device, set_default_device
 from megengine.functional.distributed import (
     all_gather,
     all_reduce_max,
@@ -449,7 +450,8 @@ def test_io_remote():
             assert y.numpy()[0] == 0
         else:  # remote recv
             dist.init_process_group("localhost", port, world_size, rank, rank)
-            y = remote_recv(0, val.shape, val.dtype, cn="gpu1")
+            y = remote_recv(0, val.shape, val.dtype)
+            assert y.device == "gpu1"
             np.testing.assert_almost_equal(val, y.numpy())

     procs = []
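
Dropping the explicit cn="gpu1" argument means the test now exercises the default-device fallback; the new set_default_device import suggests the worker pins the default device to "gpu1" before receiving, though that line is outside this hunk. A sketch of the receiver-side pattern under that assumption:

    import numpy as np
    from megengine.device import get_default_device, set_default_device
    from megengine.functional.distributed import remote_recv

    # hypothetical setup mirroring the updated test: pin the default
    # device, then let remote_recv fall back to it
    set_default_device("gpu1")
    y = remote_recv(0, (1,), np.float32)
    assert y.device == get_default_device() == "gpu1"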

