Browse Source

feat(mge/distributed): add hybird parallel Opr

GitOrigin-RevId: ff26671746
release-1.6
Megvii Engine Team 3 years ago
parent
commit
03e80759b6
2 changed files with 284 additions and 16 deletions
  1. +142
    -16
      imperative/python/megengine/distributed/functional.py
  2. +142
    -0
      imperative/python/test/unit/functional/test_functional_distributed_axis.py

+ 142
- 16
imperative/python/megengine/distributed/functional.py View File

@@ -185,7 +185,7 @@ def reduce_sum(
output = reduce_sum(input) output = reduce_sum(input)
# Rank 0 # output: Tensor([1]) # Rank 0 # output: Tensor([1])
# Rank 1 # output: None # Rank 1 # output: None
input = Tensor([rank]) input = Tensor([rank])
group = Group([1, 0]) # first rank is root group = Group([1, 0]) # first rank is root
output = reduce_sum(input, group) output = reduce_sum(input, group)
@@ -248,7 +248,7 @@ def broadcast(
output = broadcast(input) output = broadcast(input)
# Rank 0 # output: Tensor([0]) # Rank 0 # output: Tensor([0])
# Rank 1 # output: Tensor([0]) # Rank 1 # output: Tensor([0])
input = Tensor([rank]) input = Tensor([rank])
group = Group([1, 0]) # first rank is root group = Group([1, 0]) # first rank is root
output = broadcast(input, group) output = broadcast(input, group)
@@ -276,7 +276,7 @@ def _bcast_param(




def all_gather( def all_gather(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
) -> Tensor: ) -> Tensor:
r""" r"""
Gather tensors across the specified group and concat them at first dimension. Gather tensors across the specified group and concat them at first dimension.
@@ -290,6 +290,8 @@ def all_gather(
None default device means the device of inp will be used. None default device means the device of inp will be used.
Specify "gpu0:1" to execute this operator on diffrent cuda stream, Specify "gpu0:1" to execute this operator on diffrent cuda stream,
1 is stream id, and default stream id is 0. 1 is stream id, and default stream id is 0.
axis: The concat axis for collective_comm result
The default axis is 0


Returns: Returns:
Result tensor. Result tensor.
@@ -304,7 +306,7 @@ def all_gather(
output = all_gather(input) output = all_gather(input)
# Rank 0 # output: Tensor([0 1]) # Rank 0 # output: Tensor([0 1])
# Rank 1 # output: Tensor([0 1]) # Rank 1 # output: Tensor([0 1])
input = Tensor([rank]) input = Tensor([rank])
group = Group([1, 0]) group = Group([1, 0])
output = all_gather(input, group) output = all_gather(input, group)
@@ -313,11 +315,28 @@ def all_gather(


""" """
mode = CollectiveComm.Mode.ALL_GATHER mode = CollectiveComm.Mode.ALL_GATHER
return collective_comm(inp, mode, group, device)
out = collective_comm(inp, mode, group, device)
if axis == 0:
return out
else:
group_size = group.size if group is not None else 1
transformed_shape = list(inp._tuple_shape)
transformed_shape[axis] *= group_size
n, *shp = out._tuple_shape
index = (
[_ for _ in range(1, axis)]
+ [axis, 0]
+ [_ for _ in range(axis + 1, out.ndim + 1)]
)
return (
out.reshape(group_size, n // group_size, *shp)
.transpose(index)
.reshape(transformed_shape)
)




def reduce_scatter_sum( def reduce_scatter_sum(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0
) -> Tensor: ) -> Tensor:
r""" r"""
Reduce tensors across the specified group by sum and split them at first dimension. Reduce tensors across the specified group by sum and split them at first dimension.
@@ -331,6 +350,8 @@ def reduce_scatter_sum(
None default device means the device of inp will be used. None default device means the device of inp will be used.
Specify "gpu0:1" to execute this operator on diffrent cuda stream, Specify "gpu0:1" to execute this operator on diffrent cuda stream,
1 is stream id, and default stream id is 0. 1 is stream id, and default stream id is 0.
axis: The split axis for collective_comm result
The default axis is 0, the data will split in the 0 axis


Returns: Returns:
Split tensor. Split tensor.
@@ -345,7 +366,7 @@ def reduce_scatter_sum(
output = reduce_scatter_sum(input) output = reduce_scatter_sum(input)
# Rank 0 # output: Tensor([0]) # Rank 0 # output: Tensor([0])
# Rank 1 # output: Tensor([2]) # Rank 1 # output: Tensor([2])
input = Tensor([0 1]) input = Tensor([0 1])
group = Group([1, 0]) group = Group([1, 0])
output = reduce_scatter_sum(input, group) output = reduce_scatter_sum(input, group)
@@ -353,6 +374,23 @@ def reduce_scatter_sum(
# Rank 1 # output: Tensor([0]) # Rank 1 # output: Tensor([0])


""" """
group_size = group.size if group is not None else 1
assert (
list(inp._tuple_shape)[axis] % group_size == 0
), "current axis: {} can't devided by group size".format(axis)
if axis != 0:
k_new_shape = list(inp._tuple_shape)
k_new_shape[axis] //= group_size
k_new_shape[0] *= group_size
new_shape = list(inp._tuple_shape)
new_shape[axis] //= group_size
new_shape.insert(axis, group_size)
index = (
[axis]
+ [_ for _ in range(0, axis)]
+ [_ for _ in range(axis + 1, inp.ndim + 1)]
)
inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM
return collective_comm(inp, mode, group, device) return collective_comm(inp, mode, group, device)


@@ -480,7 +518,7 @@ class _Gather(Function):




def gather( def gather(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
) -> Tensor: ) -> Tensor:
r""" r"""
Gather tensors across the specified group. Gather tensors across the specified group.
@@ -495,7 +533,8 @@ def gather(
None default device means the device of inp will be used. None default device means the device of inp will be used.
Specify "gpu0:1" to execute this operator on diffrent cuda stream, Specify "gpu0:1" to execute this operator on diffrent cuda stream,
1 is stream id, and default stream id is 0. 1 is stream id, and default stream id is 0.

axis: The concat axis for collective_comm result
The default axis is 0
Returns: Returns:
Result tensor if in root process, None if in other process Result tensor if in root process, None if in other process


@@ -509,7 +548,7 @@ def gather(
output = gather(input) output = gather(input)
# Rank 0 # output: Tensor([0 1]) # Rank 0 # output: Tensor([0 1])
# Rank 1 # output: None # Rank 1 # output: None
input = Tensor([rank]) input = Tensor([rank])
group = Group([1, 0]) # first rank is root group = Group([1, 0]) # first rank is root
output = gather(input, group) output = gather(input, group)
@@ -517,12 +556,33 @@ def gather(
# Rank 1 # output: Tensor([1 0]) # Rank 1 # output: Tensor([1 0])


""" """
assert (
axis < inp.ndim
), "your concat_axis exceeds the dim of the tensor, the tensor shape is {}".format(
inp.shape
)


op = _Gather(group, device) op = _Gather(group, device)
(out,) = apply(op, inp) (out,) = apply(op, inp)


if group.rank == 0: if group.rank == 0:
return out
if axis == 0:
return out
else:
group_size = group.size
transformed_shape = list(inp._tuple_shape)
transformed_shape[axis] *= group_size
n, *shp = out._tuple_shape
index = (
[_ for _ in range(1, axis)]
+ [axis, 0]
+ [_ for _ in range(axis + 1, out.ndim + 1)]
)
return (
out.reshape(group_size, n // group_size, *shp)
.transpose(index)
.reshape(transformed_shape)
)
else: else:
_save_output_for_autodiff(inp, out) _save_output_for_autodiff(inp, out)


@@ -545,7 +605,7 @@ class _Scatter(Function):




def scatter( def scatter(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
) -> Tensor: ) -> Tensor:
r""" r"""
Split tensor in root process at first dimension. Split tensor in root process at first dimension.
@@ -559,6 +619,8 @@ def scatter(
None default device means the device of inp will be used. None default device means the device of inp will be used.
Specify "gpu0:1" to execute this operator on diffrent cuda stream, Specify "gpu0:1" to execute this operator on diffrent cuda stream,
1 is stream id, and default stream id is 0. 1 is stream id, and default stream id is 0.
axis: The concat axis for collective_comm result
The default axis is 0


Returns: Returns:
Split tensor. Split tensor.
@@ -573,7 +635,7 @@ def scatter(
output = scatter(input) output = scatter(input)
# Rank 0 # output: Tensor([0]) # Rank 0 # output: Tensor([0])
# Rank 1 # output: Tensor([1]) # Rank 1 # output: Tensor([1])
input = Tensor([0 1]) + rank*2 input = Tensor([0 1]) + rank*2
group = Group([1, 0]) # first rank is root group = Group([1, 0]) # first rank is root
output = scatter(input, group) output = scatter(input, group)
@@ -588,13 +650,35 @@ def scatter(


_bcast_tracer_state(group, inp) _bcast_tracer_state(group, inp)


assert (
list(inp._tuple_shape)[axis] % group.size == 0
), "current axis: {} can't devided by group size".format(axis)

if axis != 0:
group_size = group.size
k_new_shape = list(inp._tuple_shape)
k_new_shape[axis] //= group_size
k_new_shape[0] *= group_size
new_shape = list(inp._tuple_shape)
new_shape[axis] //= group_size
new_shape.insert(axis, group_size)
index = (
[axis]
+ [_ for _ in range(0, axis)]
+ [_ for _ in range(axis + 1, inp.ndim + 1)]
)
inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
op = _Scatter(group, device) op = _Scatter(group, device)
(out,) = apply(op, inp) (out,) = apply(op, inp)
return out return out




def all_to_all( def all_to_all(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
inp: Tensor,
group: Optional[Group] = WORLD,
device: Optional[str] = None,
split_axis: int = 0,
concat_axis: int = 0,
) -> Tensor: ) -> Tensor:
r""" r"""
Each process scatter input tensor to all processes and return gathered tensor. Each process scatter input tensor to all processes and return gathered tensor.
@@ -608,6 +692,10 @@ def all_to_all(
None default device means the device of inp will be used. None default device means the device of inp will be used.
Specify "gpu0:1" to execute this operator on diffrent cuda stream, Specify "gpu0:1" to execute this operator on diffrent cuda stream,
1 is stream id, and default stream id is 0. 1 is stream id, and default stream id is 0.
split_axis: The axis that collectivecomm will split data
the default axis is 0
split_axis: The axis that collectivecomm will concat data
the default axis is 0


Returns: Returns:
Result tensor. Result tensor.
@@ -622,7 +710,7 @@ def all_to_all(
output = all_to_all(input) output = all_to_all(input)
# Rank 0 # output: Tensor([0 2]) # Rank 0 # output: Tensor([0 2])
# Rank 1 # output: Tensor([1 3]) # Rank 1 # output: Tensor([1 3])
input = Tensor([0 1]) + rank*2 input = Tensor([0 1]) + rank*2
group = Group([1, 0]) group = Group([1, 0])
output = all_to_all(input, group) output = all_to_all(input, group)
@@ -630,8 +718,46 @@ def all_to_all(
# Rank 1 # output: Tensor([2 1]) # Rank 1 # output: Tensor([2 1])


""" """
group_size = group.size if group is not None else 1
assert (
list(inp._tuple_shape)[split_axis] % group_size == 0
), "current axis: {} can't devided by group size".format(split_axis)
origin_shape = inp._tuple_shape
if split_axis != 0:
k_new_shape = list(inp._tuple_shape)
k_new_shape[split_axis] //= group_size
k_new_shape[0] *= group_size
new_shape = list(inp._tuple_shape)
new_shape[split_axis] //= group_size
new_shape.insert(split_axis, group_size)
index = (
[split_axis]
+ [_ for _ in range(0, split_axis)]
+ [_ for _ in range(split_axis + 1, inp.ndim + 1)]
)
inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)

mode = CollectiveComm.Mode.ALL_TO_ALL mode = CollectiveComm.Mode.ALL_TO_ALL
return collective_comm(inp, mode, group, device)
out = collective_comm(inp, mode, group, device)

if concat_axis == 0:
return out

transformed_shape = list(origin_shape)
transformed_shape[concat_axis] *= group_size
transformed_shape[split_axis] //= group_size

n, *shp = out._tuple_shape
index = (
[_ for _ in range(1, concat_axis)]
+ [concat_axis, 0]
+ [_ for _ in range(concat_axis + 1, out.ndim + 1)]
)
return (
out.reshape(group_size, n // group_size, *shp)
.transpose(index)
.reshape(transformed_shape)
)




class _SendRecvGroup: class _SendRecvGroup:


+ 142
- 0
imperative/python/test/unit/functional/test_functional_distributed_axis.py View File

@@ -0,0 +1,142 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
import numpy as np
import pytest

import megengine as mge
import megengine.distributed as dist
from megengine import tensor
from megengine.distributed.functional import (
all_gather,
all_to_all,
gather,
reduce_scatter_sum,
scatter,
)
from megengine.jit import trace


@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("shape", [(2, 3), (8, 10), (99, 77), (2, 2, 2, 2)], ids=str)
@pytest.mark.parametrize("symbolic", [False, True], ids=str)
@pytest.mark.parametrize("axis", [0, 1], ids=str)
@pytest.mark.isolated_distributed
def test_all_gather(shape, symbolic, axis):
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
inp = tensor(data[rank])

def func():
output = all_gather(inp, axis=axis)
return output

func = trace(symbolic=symbolic)(func)
output = func()
assert np.allclose(output.numpy(), expect[rank])

x = np.random.random_sample(shape).astype("float32")
y = np.random.random_sample(shape).astype("float32")
z = np.concatenate((x, y), axis=axis)
data = (x, y)
expect = (z, z)
worker(data, expect)


@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize(
"shape,symbolic", [((2, 4, 6, 8), False), ((2, 4, 6, 8), True)], ids=str
)
@pytest.mark.parametrize("axis", [1, 0, 2, 3], ids=str)
@pytest.mark.isolated_distributed
def test_reduce_scatter_sum(shape, symbolic, axis):
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
inp = tensor(data[rank])

def func():
output = reduce_scatter_sum(inp, axis=axis)
return output

func = trace(symbolic=symbolic)(func)
output = func()
assert np.allclose(output.numpy(), expect[rank])

x = np.random.random_sample(shape).astype("float32")
y = np.random.random_sample(shape).astype("float32")
z = x + y
data = (x, y)
z = np.split(z, 2, axis=axis)
z = np.concatenate(z, axis=0)
expect = (z[: z.shape[0] // 2], z[z.shape[0] // 2 :])
worker(data, expect)


@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize(
"shape,symbolic", [((2, 4, 6, 8), True), ((2, 4, 6, 8), False)], ids=str
)
@pytest.mark.parametrize("axis", [1, 0, 2, 3], ids=str)
@pytest.mark.isolated_distributed
def test_scatter(shape, symbolic, axis):
@dist.launcher(n_gpus=2)
def worker(data, expect):
rank = dist.get_rank()
inp = tensor(data[rank])

def func():
output = scatter(inp, axis=axis)
return output

func = trace(symbolic=symbolic)(func)
output = func()
assert np.allclose(output.numpy(), expect[rank])

x = np.random.random_sample(shape).astype("float32")
y = x + 1
data = (x, y)
_x = np.split(x, 2, axis=axis)
_x = np.concatenate(_x, axis=0)
expect = (_x[: _x.shape[0] // 2], _x[_x.shape[0] // 2 :])
worker(data, expect)


@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("shape", [(2, 4, 6, 8)], ids=str)
@pytest.mark.parametrize("symbolic", [False, True], ids=str)
@pytest.mark.parametrize(
"split_axis,concat_axis", [(0, 1), (1, 0), (2, 0), (0, 2), (2, 3)], ids=str
)
@pytest.mark.isolated_distributed
def test_all_to_all(shape, symbolic, split_axis, concat_axis):
@dist.launcher(n_gpus=2)
def worker(data):
rank = dist.get_rank()
inp = tensor(data[rank])

def func():
all_to_all_output = all_to_all(
inp, split_axis=split_axis, concat_axis=concat_axis
)
gather_C = gather(inp, axis=concat_axis)
gather_B = gather(all_to_all_output, axis=split_axis)
if rank == 0:
return gather_B, gather_C
return all_to_all_output

func = trace(symbolic=symbolic)(func)
ret = func()
if rank == 0:
assert np.allclose(ret[0], ret[1])

x = np.random.random_sample(shape).astype("float32")
y = np.random.random_sample(shape).astype("float32")
data = (x, y)
worker(data)

Loading…
Cancel
Save