Browse Source

feat(mge/distributed): scalar support for distributed functions

GitOrigin-RevId: 53f3575baf
tags/v1.3.0
Megvii Engine Team 4 years ago
parent
commit
ea8eb4cf72
2 changed files with 41 additions and 22 deletions
  1. +20
    -1
      imperative/python/megengine/distributed/functional.py
  2. +21
    -21
      imperative/python/test/unit/functional/test_functional_distributed.py

+ 20
- 1
imperative/python/megengine/distributed/functional.py View File

@@ -11,6 +11,7 @@ from typing import Optional, Tuple
from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, PyOpBase, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device
from ..tensor import Tensor
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank
@@ -50,7 +51,18 @@ def collective_comm(inp, mode, group, device):
backend=get_backend(),
comp_node=device,
)
return apply(op, inp)[0]
(result,) = apply(op, inp)
# assume all workers have homogeneous shape
if mode in (
CollectiveComm.Mode.REDUCE_SUM,
CollectiveComm.Mode.BROADCAST,
CollectiveComm.Mode.ALL_REDUCE_SUM,
CollectiveComm.Mode.ALL_REDUCE_MAX,
CollectiveComm.Mode.ALL_REDUCE_MIN,
):
if isscalar(inp):
setscalar(result)
return result


def reduce_sum(
@@ -289,6 +301,11 @@ def remote_recv(
g.wrt(inp)
g._refkeeper.append(inp)

_isscalar = False
if len(shape) == 0:
shape = (1,)
_isscalar = True

op = RemoteRecv()
op.key = key
op.cn = device
@@ -298,4 +315,6 @@ def remote_recv(
op.rank_from = src_rank

(ret,) = apply(_RemoteRecv(op), inp)
if _isscalar:
setscalar(ret)
return ret

+ 21
- 21
imperative/python/test/unit/functional/test_functional_distributed.py View File

@@ -13,7 +13,7 @@ import pytest

import megengine as mge
import megengine.distributed as dist
from megengine import Parameter, Tensor, tensor
from megengine import Parameter, tensor
from megengine.core._imperative_rt.core2 import sync
from megengine.device import get_default_device, set_default_device
from megengine.distributed.helper import get_device_count_by_fork
@@ -53,14 +53,14 @@ def test_reduce_sum():
assert np.allclose(output.numpy(), 0)

def check(shape):
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
x = np.random.rand(*shape)
y = np.random.rand(*shape)
z = x + y
data = (x, y)
expect = (z, None)
worker(data, expect)

for shape in [(2, 3), (8, 10), (99, 77)]:
for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
check(shape)


@@ -81,13 +81,13 @@ def test_broadcast():
assert np.allclose(output.numpy(), expect[rank])

def check(shape):
x = np.random.rand(*shape).astype("float32")
x = np.random.rand(*shape)
y = x + 1
data = (x, y)
expect = (x, x)
worker(data, expect)

for shape in [(2, 3), (8, 10), (99, 77)]:
for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
check(shape)


@@ -164,14 +164,14 @@ def test_all_reduce_sum():
assert np.allclose(output.numpy(), expect[rank])

def check(shape):
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
x = np.random.rand(*shape)
y = np.random.rand(*shape)
z = x + y
data = (x, y)
expect = (z, z)
worker(data, expect)

for shape in [(2, 3), (8, 10), (99, 77)]:
for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
check(shape)


@@ -192,14 +192,14 @@ def test_all_reduce_max():
assert np.allclose(output.numpy(), expect[rank])

def check(shape):
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
x = np.random.rand(*shape)
y = np.random.rand(*shape)
z = np.maximum(x, y)
data = (x, y)
expect = (z, z)
worker(data, expect)

for shape in [(2, 3), (8, 10), (99, 77)]:
for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
check(shape)


@@ -220,14 +220,14 @@ def test_all_reduce_min():
assert np.allclose(output.numpy(), expect[rank])

def check(shape):
x = np.random.rand(*shape).astype("float32")
y = np.random.rand(*shape).astype("float32")
x = np.random.rand(*shape)
y = np.random.rand(*shape)
z = np.minimum(x, y)
data = (x, y)
expect = (z, z)
worker(data, expect)

for shape in [(2, 3), (8, 10), (99, 77)]:
for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
check(shape)


@@ -327,18 +327,18 @@ def test_all_to_all():
@pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
@pytest.mark.isolated_distributed
def test_io_remote():
val = np.random.rand(4, 5).astype(np.float32)

@dist.launcher(n_gpus=2)
def worker():
def worker(val, shape):
rank = dist.get_rank()
if rank == 0: # remote send
x = Tensor(val, device="gpu0")
x = tensor(val, device="gpu0")
remote_send(x, 1)
sync()
else: # remote recv
y = remote_recv(0, val.shape, val.dtype)
y = remote_recv(0, shape, np.float32)
assert y.device == "gpu1"
np.testing.assert_almost_equal(val, y.numpy())

worker()
for shape in [(), (1,), (4, 5)]:
val = np.random.rand(*shape)
worker(val, shape)

Loading…
Cancel
Save