
feat(mge/distributed): scalar support for distributed functions

GitOrigin-RevId: 53f3575baf
tags/v1.3.0
Megvii Engine Team · 4 years ago
commit ea8eb4cf72

2 changed files with 41 additions and 22 deletions:
  1. imperative/python/megengine/distributed/functional.py (+20, -1)
  2. imperative/python/test/unit/functional/test_functional_distributed.py (+21, -21)

imperative/python/megengine/distributed/functional.py  (+20, -1)

@@ -11,6 +11,7 @@ from typing import Optional, Tuple
 from ..core._imperative_rt.core2 import apply
 from ..core.autodiff.grad import _grad_manager_dict
 from ..core.ops.builtin import CollectiveComm, Copy, PyOpBase, RemoteRecv, RemoteSend
+from ..core.tensor.utils import isscalar, setscalar
 from ..device import get_default_device
 from ..tensor import Tensor
 from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank
@@ -50,7 +51,18 @@ def collective_comm(inp, mode, group, device):
         backend=get_backend(),
         comp_node=device,
     )
-    return apply(op, inp)[0]
+    (result,) = apply(op, inp)
+    # assume all workers have homogeneous shape
+    if mode in (
+        CollectiveComm.Mode.REDUCE_SUM,
+        CollectiveComm.Mode.BROADCAST,
+        CollectiveComm.Mode.ALL_REDUCE_SUM,
+        CollectiveComm.Mode.ALL_REDUCE_MAX,
+        CollectiveComm.Mode.ALL_REDUCE_MIN,
+    ):
+        if isscalar(inp):
+            setscalar(result)
+    return result


 def reduce_sum(
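A minimal usage sketch of what this hunk enables (hypothetical, not part of the commit): assuming a 2-GPU host, that all_reduce_sum is importable from the patched module, and that a tensor built from a Python/NumPy scalar is 0-dim, an all-reduce over scalars should now stay 0-dim.

import numpy as np
import megengine as mge
import megengine.distributed as dist
from megengine.distributed.functional import all_reduce_sum

@dist.launcher(n_gpus=2)
def scalar_all_reduce():
    # each rank contributes a 0-dim (scalar) tensor: 1.0 on rank 0, 2.0 on rank 1
    x = mge.tensor(np.float32(dist.get_rank() + 1))
    y = all_reduce_sum(x)
    # the mode is ALL_REDUCE_SUM and the input is a scalar, so setscalar()
    # should mark the result as 0-dim again
    assert y.shape == ()
    np.testing.assert_allclose(y.numpy(), 3.0)

scalar_all_reduce()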
@@ -289,6 +301,11 @@ def remote_recv(
             g.wrt(inp)
             g._refkeeper.append(inp)

+    _isscalar = False
+    if len(shape) == 0:
+        shape = (1,)
+        _isscalar = True
+
     op = RemoteRecv()
     op.key = key
     op.cn = device
@@ -298,4 +315,6 @@ def remote_recv(
     op.rank_from = src_rank

     (ret,) = apply(_RemoteRecv(op), inp)
+    if _isscalar:
+        setscalar(ret)
     return ret
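Similarly, a hypothetical sketch of the remote_recv scalar path added above; it reuses the positional signatures exercised by the test below, remote_send(inp, dest_rank) and remote_recv(src_rank, shape, dtype).

import numpy as np
import megengine as mge
import megengine.distributed as dist
from megengine.core._imperative_rt.core2 import sync
from megengine.distributed.functional import remote_recv, remote_send

@dist.launcher(n_gpus=2)
def scalar_io():
    if dist.get_rank() == 0:
        # send a 0-dim tensor from rank 0 to rank 1
        remote_send(mge.tensor(np.float32(3.0)), 1)
        sync()
    else:
        # shape=() takes the new branch: the RemoteRecv op is built with
        # shape (1,), then setscalar(ret) restores the 0-dim view
        y = remote_recv(0, (), np.float32)
        assert y.shape == ()
        np.testing.assert_allclose(y.numpy(), 3.0)

scalar_io()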

imperative/python/test/unit/functional/test_functional_distributed.py  (+21, -21)

@@ -13,7 +13,7 @@ import pytest

 import megengine as mge
 import megengine.distributed as dist
-from megengine import Parameter, Tensor, tensor
+from megengine import Parameter, tensor
 from megengine.core._imperative_rt.core2 import sync
 from megengine.device import get_default_device, set_default_device
 from megengine.distributed.helper import get_device_count_by_fork
@@ -53,14 +53,14 @@ def test_reduce_sum():
             assert np.allclose(output.numpy(), 0)

     def check(shape):
-        x = np.random.rand(*shape).astype("float32")
-        y = np.random.rand(*shape).astype("float32")
+        x = np.random.rand(*shape)
+        y = np.random.rand(*shape)
         z = x + y
         data = (x, y)
         expect = (z, None)
         worker(data, expect)

-    for shape in [(2, 3), (8, 10), (99, 77)]:
+    for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
         check(shape)
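One detail worth noting about the new shape list (an observation, not stated in the diff): np.random.rand(*()) returns a plain Python float, which has no .astype method, which is presumably why the .astype("float32") calls are dropped here. tensor() then turns that float into a 0-dim scalar tensor, while np.random.rand(1) yields a shape-(1,) array, so () and (1,) cover the scalar and one-element-vector paths separately (assuming, as elsewhere in this commit, that a tensor built from a Python scalar is 0-dim):

import numpy as np
from megengine import tensor

x0 = np.random.rand(*())    # plain Python float: no .astype, no shape
x1 = np.random.rand(*(1,))  # ndarray of shape (1,)
assert tensor(x0).shape == ()    # scalar tensor: the path this commit fixes
assert tensor(x1).shape == (1,)  # ordinary one-element tensor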




@@ -81,13 +81,13 @@ def test_broadcast():
         assert np.allclose(output.numpy(), expect[rank])

     def check(shape):
-        x = np.random.rand(*shape).astype("float32")
+        x = np.random.rand(*shape)
         y = x + 1
         data = (x, y)
         expect = (x, x)
         worker(data, expect)

-    for shape in [(2, 3), (8, 10), (99, 77)]:
+    for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
         check(shape)




@@ -164,14 +164,14 @@ def test_all_reduce_sum():
         assert np.allclose(output.numpy(), expect[rank])

     def check(shape):
-        x = np.random.rand(*shape).astype("float32")
-        y = np.random.rand(*shape).astype("float32")
+        x = np.random.rand(*shape)
+        y = np.random.rand(*shape)
         z = x + y
         data = (x, y)
         expect = (z, z)
         worker(data, expect)

-    for shape in [(2, 3), (8, 10), (99, 77)]:
+    for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
         check(shape)




@@ -192,14 +192,14 @@ def test_all_reduce_max():
         assert np.allclose(output.numpy(), expect[rank])

     def check(shape):
-        x = np.random.rand(*shape).astype("float32")
-        y = np.random.rand(*shape).astype("float32")
+        x = np.random.rand(*shape)
+        y = np.random.rand(*shape)
         z = np.maximum(x, y)
         data = (x, y)
         expect = (z, z)
         worker(data, expect)

-    for shape in [(2, 3), (8, 10), (99, 77)]:
+    for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
         check(shape)




@@ -220,14 +220,14 @@ def test_all_reduce_min():
         assert np.allclose(output.numpy(), expect[rank])

     def check(shape):
-        x = np.random.rand(*shape).astype("float32")
-        y = np.random.rand(*shape).astype("float32")
+        x = np.random.rand(*shape)
+        y = np.random.rand(*shape)
         z = np.minimum(x, y)
         data = (x, y)
         expect = (z, z)
         worker(data, expect)

-    for shape in [(2, 3), (8, 10), (99, 77)]:
+    for shape in [(), (1,), (2, 3), (8, 10), (99, 77)]:
         check(shape)




@@ -327,18 +327,18 @@ def test_all_to_all():
 @pytest.mark.skipif(get_device_count_by_fork("gpu") < 2, reason="need more gpu device")
 @pytest.mark.isolated_distributed
 def test_io_remote():
-    val = np.random.rand(4, 5).astype(np.float32)
-
     @dist.launcher(n_gpus=2)
-    def worker():
+    def worker(val, shape):
         rank = dist.get_rank()
         if rank == 0:  # remote send
-            x = Tensor(val, device="gpu0")
+            x = tensor(val, device="gpu0")
             remote_send(x, 1)
             sync()
         else:  # remote recv
-            y = remote_recv(0, val.shape, val.dtype)
+            y = remote_recv(0, shape, np.float32)
             assert y.device == "gpu1"
             np.testing.assert_almost_equal(val, y.numpy())

-    worker()
+    for shape in [(), (1,), (4, 5)]:
+        val = np.random.rand(*shape)
+        worker(val, shape)
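Since val is now generated per shape inside the loop, the data is passed to the workers as arguments rather than captured from module scope; the call worker(val, shape) relies on dist.launcher forwarding positional arguments into every spawned process, and the switch from Tensor(...) to tensor(...) matches the trimmed import at the top of the file. The forwarding pattern in isolation (hypothetical snippet, not part of the commit):

import numpy as np
import megengine.distributed as dist

@dist.launcher(n_gpus=2)
def show(val, shape):
    # both ranks receive the same positional arguments
    print("rank", dist.get_rank(), "shape", shape, "val", val)

for shape in [(), (1,), (4, 5)]:
    show(np.random.rand(*shape), shape)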
