@@ -185,7 +185,7 @@ def reduce_sum(
        output = reduce_sum(input)
        # Rank 0 # output: Tensor([1])
        # Rank 1 # output: None

        input = Tensor([rank])
        group = Group([1, 0]) # first rank is root
        output = reduce_sum(input, group)
@@ -248,7 +248,7 @@ def broadcast(
        output = broadcast(input)
        # Rank 0 # output: Tensor([0])
        # Rank 1 # output: Tensor([0])

        input = Tensor([rank])
        group = Group([1, 0]) # first rank is root
        output = broadcast(input, group)
@@ -276,7 +276,7 @@ def _bcast_param(


def all_gather(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
) -> Tensor:
    r"""
    Gather tensors across the specified group and concat them along the first dimension.
@@ -290,6 +290,8 @@ def all_gather(
        None default device means the device of inp will be used.
        Specify "gpu0:1" to execute this operator on a different cuda stream,
        1 is stream id, and default stream id is 0.
+        axis: The concat axis for the collective_comm result.
+            The default axis is 0.

    Returns:
        Result tensor.
@@ -304,7 +306,7 @@ def all_gather(
        output = all_gather(input)
        # Rank 0 # output: Tensor([0 1])
        # Rank 1 # output: Tensor([0 1])

        input = Tensor([rank])
        group = Group([1, 0])
        output = all_gather(input, group)
@@ -313,11 +315,28 @@ def all_gather(
    """
    mode = CollectiveComm.Mode.ALL_GATHER
-    return collective_comm(inp, mode, group, device)
+    out = collective_comm(inp, mode, group, device)
+    if axis == 0:
+        return out
+    else:
+        group_size = group.size if group is not None else 1
+        transformed_shape = list(inp._tuple_shape)
+        transformed_shape[axis] *= group_size
+        n, *shp = out._tuple_shape
+        index = (
+            [_ for _ in range(1, axis)]
+            + [axis, 0]
+            + [_ for _ in range(axis + 1, out.ndim + 1)]
+        )
+        return (
+            out.reshape(group_size, n // group_size, *shp)
+            .transpose(index)
+            .reshape(transformed_shape)
+        )
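
The reshuffle above moves the primitive's axis-0 concatenation to the requested axis. A minimal single-process NumPy sketch of the same reshape/transpose/reshape step (illustrative only, not MegEngine code; it assumes two ranks and axis=1):

```python
import numpy as np

group_size, axis = 2, 1
chunks = [np.arange(6).reshape(2, 3) + 10 * r for r in range(group_size)]  # per-rank inputs

out = np.concatenate(chunks, axis=0)   # mimic the ALL_GATHER primitive: concat along axis 0

transformed_shape = list(chunks[0].shape)
transformed_shape[axis] *= group_size  # target shape: (2, 6)
n, *shp = out.shape
index = list(range(1, axis)) + [axis, 0] + list(range(axis + 1, out.ndim + 1))
res = (
    out.reshape(group_size, n // group_size, *shp)  # split the leading dim back into (rank, chunk)
    .transpose(index)                               # move the rank dim next to `axis`
    .reshape(transformed_shape)                     # merge it into `axis`
)

assert np.array_equal(res, np.concatenate(chunks, axis=axis))
```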

def reduce_scatter_sum(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0
) -> Tensor:
    r"""
    Reduce tensors across the specified group by sum and split them along the first dimension.
@@ -331,6 +350,8 @@ def reduce_scatter_sum(
        None default device means the device of inp will be used.
        Specify "gpu0:1" to execute this operator on a different cuda stream,
        1 is stream id, and default stream id is 0.
+        axis: The split axis for the collective_comm result.
+            The default axis is 0; the data is split along axis 0.

    Returns:
        Split tensor.
@@ -345,7 +366,7 @@ def reduce_scatter_sum(
        output = reduce_scatter_sum(input)
        # Rank 0 # output: Tensor([0])
        # Rank 1 # output: Tensor([2])

        input = Tensor([0 1])
        group = Group([1, 0])
        output = reduce_scatter_sum(input, group)
@@ -353,6 +374,23 @@ def reduce_scatter_sum(
        # Rank 1 # output: Tensor([0])

    """
+    group_size = group.size if group is not None else 1
+    assert (
+        list(inp._tuple_shape)[axis] % group_size == 0
+    ), "current axis: {} can't be divided by group size".format(axis)
+    if axis != 0:
+        k_new_shape = list(inp._tuple_shape)
+        k_new_shape[axis] //= group_size
+        k_new_shape[0] *= group_size
+        new_shape = list(inp._tuple_shape)
+        new_shape[axis] //= group_size
+        new_shape.insert(axis, group_size)
+        index = (
+            [axis]
+            + [_ for _ in range(0, axis)]
+            + [_ for _ in range(axis + 1, inp.ndim + 1)]
+        )
+        inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
    mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM
    return collective_comm(inp, mode, group, device)
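
Here the input is reshuffled before the primitive runs, so that the axis-0 reduce-scatter hands each rank exactly its slice along `axis`. A single-process NumPy sketch of that pre-permutation (illustrative only, assuming two ranks and axis=1):

```python
import numpy as np

group_size, axis = 2, 1
inputs = [np.arange(8).reshape(2, 4) + 100 * r for r in range(group_size)]  # per-rank inputs

def pre_permute(x):
    # the reshuffle applied above when axis != 0
    k_new_shape = list(x.shape)
    k_new_shape[axis] //= group_size
    k_new_shape[0] *= group_size
    new_shape = list(x.shape)
    new_shape[axis] //= group_size
    new_shape.insert(axis, group_size)
    index = [axis] + list(range(0, axis)) + list(range(axis + 1, x.ndim + 1))
    return x.reshape(new_shape).transpose(index).reshape(k_new_shape)

# mimic the REDUCE_SCATTER_SUM primitive: elementwise sum, then an equal split along axis 0
summed = sum(pre_permute(x) for x in inputs)
per_rank = np.split(summed, group_size, axis=0)

# each rank ends up with its slice of the sum along `axis`
for got, want in zip(per_rank, np.split(sum(inputs), group_size, axis=axis)):
    assert np.array_equal(got, want)
```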

@@ -480,7 +518,7 @@ class _Gather(Function):


def gather(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
) -> Tensor:
    r"""
    Gather tensors across the specified group.
@@ -495,7 +533,8 @@ def gather(
        None default device means the device of inp will be used.
        Specify "gpu0:1" to execute this operator on a different cuda stream,
        1 is stream id, and default stream id is 0.
+        axis: The concat axis for the collective_comm result.
+            The default axis is 0.

    Returns:
        Result tensor if in root process, None if in other process
@@ -509,7 +548,7 @@ def gather(
        output = gather(input)
        # Rank 0 # output: Tensor([0 1])
        # Rank 1 # output: None

        input = Tensor([rank])
        group = Group([1, 0]) # first rank is root
        output = gather(input, group)
@@ -517,12 +556,33 @@ def gather(
        # Rank 1 # output: Tensor([1 0])

    """
+    assert (
+        axis < inp.ndim
+    ), "your concat_axis exceeds the dim of the tensor, the tensor shape is {}".format(
+        inp.shape
+    )
+
    op = _Gather(group, device)
    (out,) = apply(op, inp)

    if group.rank == 0:
-        return out
+        if axis == 0:
+            return out
+        else:
+            group_size = group.size
+            transformed_shape = list(inp._tuple_shape)
+            transformed_shape[axis] *= group_size
+            n, *shp = out._tuple_shape
+            index = (
+                [_ for _ in range(1, axis)]
+                + [axis, 0]
+                + [_ for _ in range(axis + 1, out.ndim + 1)]
+            )
+            return (
+                out.reshape(group_size, n // group_size, *shp)
+                .transpose(index)
+                .reshape(transformed_shape)
+            )
    else:
        _save_output_for_autodiff(inp, out)
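
On the root, gather reuses the post-permutation shown for all_gather to turn the axis-0 gather into a concatenation along `axis`; non-root ranks still receive None. A short NumPy check of the root-side result (illustrative only, two ranks, axis=1):

```python
import numpy as np

group_size, axis = 2, 1
chunks = [np.full((3, 2), r) for r in range(group_size)]   # rank r contributes a (3, 2) block of r's

out = np.concatenate(chunks, axis=0)                       # mimic what the _Gather primitive yields on the root
n, *shp = out.shape
index = list(range(1, axis)) + [axis, 0] + list(range(axis + 1, out.ndim + 1))
root_result = (
    out.reshape(group_size, n // group_size, *shp)
    .transpose(index)
    .reshape(3, 2 * group_size)
)

assert np.array_equal(root_result, np.concatenate(chunks, axis=axis))  # (3, 4), columns grouped by rank
```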

@@ -545,7 +605,7 @@ class _Scatter(Function):


def scatter(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
+    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None, axis=0,
) -> Tensor:
    r"""
    Split the tensor in the root process along the first dimension.
@@ -559,6 +619,8 @@ def scatter(
        None default device means the device of inp will be used.
        Specify "gpu0:1" to execute this operator on a different cuda stream,
        1 is stream id, and default stream id is 0.
+        axis: The split axis for the collective_comm result.
+            The default axis is 0.

    Returns:
        Split tensor.
@@ -573,7 +635,7 @@ def scatter(
        output = scatter(input)
        # Rank 0 # output: Tensor([0])
        # Rank 1 # output: Tensor([1])

        input = Tensor([0 1]) + rank*2
        group = Group([1, 0]) # first rank is root
        output = scatter(input, group)
@@ -588,13 +650,35 @@ def scatter(
    _bcast_tracer_state(group, inp)

+    assert (
+        list(inp._tuple_shape)[axis] % group.size == 0
+    ), "current axis: {} can't be divided by group size".format(axis)
+
+    if axis != 0:
+        group_size = group.size
+        k_new_shape = list(inp._tuple_shape)
+        k_new_shape[axis] //= group_size
+        k_new_shape[0] *= group_size
+        new_shape = list(inp._tuple_shape)
+        new_shape[axis] //= group_size
+        new_shape.insert(axis, group_size)
+        index = (
+            [axis]
+            + [_ for _ in range(0, axis)]
+            + [_ for _ in range(axis + 1, inp.ndim + 1)]
+        )
+        inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
+
    op = _Scatter(group, device)
    (out,) = apply(op, inp)
    return out
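
The root applies the same pre-permutation used in reduce_scatter_sum, so that the axis-0 scatter primitive delivers to each rank the corresponding slice along `axis`. A single-process NumPy check (illustrative only, two ranks, axis=1):

```python
import numpy as np

group_size, axis = 2, 1
root_inp = np.arange(12).reshape(2, 6)     # tensor held by the root rank

# the reshuffle applied on the root above when axis != 0
new_shape = list(root_inp.shape)
new_shape[axis] //= group_size
new_shape.insert(axis, group_size)
k_new_shape = list(root_inp.shape)
k_new_shape[axis] //= group_size
k_new_shape[0] *= group_size
index = [axis] + list(range(0, axis)) + list(range(axis + 1, root_inp.ndim + 1))
permuted = root_inp.reshape(new_shape).transpose(index).reshape(k_new_shape)

# the scatter primitive hands rank r the r-th axis-0 chunk of `permuted`
for r, chunk in enumerate(np.split(permuted, group_size, axis=0)):
    assert np.array_equal(chunk, np.split(root_inp, group_size, axis=axis)[r])
```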

def all_to_all(
-    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
+    inp: Tensor,
+    group: Optional[Group] = WORLD,
+    device: Optional[str] = None,
+    split_axis: int = 0,
+    concat_axis: int = 0,
) -> Tensor:
    r"""
    Each process scatters its input tensor to all processes and returns the gathered result.
@@ -608,6 +692,10 @@ def all_to_all(
        None default device means the device of inp will be used.
        Specify "gpu0:1" to execute this operator on a different cuda stream,
        1 is stream id, and default stream id is 0.
+        split_axis: The axis along which collective_comm will split the data.
+            The default axis is 0.
+        concat_axis: The axis along which collective_comm will concat the data.
+            The default axis is 0.

    Returns:
        Result tensor.
@@ -622,7 +710,7 @@ def all_to_all(
        output = all_to_all(input)
        # Rank 0 # output: Tensor([0 2])
        # Rank 1 # output: Tensor([1 3])

        input = Tensor([0 1]) + rank*2
        group = Group([1, 0])
        output = all_to_all(input, group)
@@ -630,8 +718,46 @@ def all_to_all(
        # Rank 1 # output: Tensor([2 1])

    """
+    group_size = group.size if group is not None else 1
+    assert (
+        list(inp._tuple_shape)[split_axis] % group_size == 0
+    ), "current axis: {} can't be divided by group size".format(split_axis)
+    origin_shape = inp._tuple_shape
+
+    if split_axis != 0:
+        k_new_shape = list(inp._tuple_shape)
+        k_new_shape[split_axis] //= group_size
+        k_new_shape[0] *= group_size
+        new_shape = list(inp._tuple_shape)
+        new_shape[split_axis] //= group_size
+        new_shape.insert(split_axis, group_size)
+        index = (
+            [split_axis]
+            + [_ for _ in range(0, split_axis)]
+            + [_ for _ in range(split_axis + 1, inp.ndim + 1)]
+        )
+        inp = inp.reshape(new_shape).transpose(index).reshape(k_new_shape)
+
    mode = CollectiveComm.Mode.ALL_TO_ALL
-    return collective_comm(inp, mode, group, device)
+    out = collective_comm(inp, mode, group, device)
+
+    if concat_axis == 0:
+        return out
+    transformed_shape = list(origin_shape)
+    transformed_shape[concat_axis] *= group_size
+    transformed_shape[split_axis] //= group_size
+
+    n, *shp = out._tuple_shape
+    index = (
+        [_ for _ in range(1, concat_axis)]
+        + [concat_axis, 0]
+        + [_ for _ in range(concat_axis + 1, out.ndim + 1)]
+    )
+    return (
+        out.reshape(group_size, n // group_size, *shp)
+        .transpose(index)
+        .reshape(transformed_shape)
+    )
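
all_to_all combines both tricks: the pre-permutation makes the axis-0 ALL_TO_ALL primitive split along split_axis, and the post-permutation re-reads its axis-0 output as a concatenation along concat_axis (the post step is skipped when concat_axis is 0). A single-process NumPy simulation (illustrative only, two ranks, split_axis=1, concat_axis=1):

```python
import numpy as np

group_size = 2
split_axis, concat_axis = 1, 1
inputs = [np.arange(8).reshape(2, 4) + 100 * r for r in range(group_size)]  # per-rank inputs

def pre_permute(x):
    # reshuffle so the axis-0 primitive splits along split_axis
    new_shape = list(x.shape)
    new_shape[split_axis] //= group_size
    new_shape.insert(split_axis, group_size)
    k_new_shape = list(x.shape)
    k_new_shape[split_axis] //= group_size
    k_new_shape[0] *= group_size
    index = [split_axis] + list(range(0, split_axis)) + list(range(split_axis + 1, x.ndim + 1))
    return x.reshape(new_shape).transpose(index).reshape(k_new_shape)

def post_permute(out, origin_shape):
    # reshuffle so the axis-0 output reads as a concat along concat_axis
    transformed_shape = list(origin_shape)
    transformed_shape[concat_axis] *= group_size
    transformed_shape[split_axis] //= group_size
    n, *shp = out.shape
    index = (
        list(range(1, concat_axis))
        + [concat_axis, 0]
        + list(range(concat_axis + 1, out.ndim + 1))
    )
    return out.reshape(group_size, n // group_size, *shp).transpose(index).reshape(transformed_shape)

# mimic the ALL_TO_ALL primitive: rank r receives the r-th axis-0 chunk from every rank, in group order
pre = [np.split(pre_permute(x), group_size, axis=0) for x in inputs]
for r in range(group_size):
    out = np.concatenate([pre[src][r] for src in range(group_size)], axis=0)
    got = post_permute(out, inputs[r].shape)
    # reference semantics: split along split_axis, exchange, concat along concat_axis
    want = np.concatenate(
        [np.split(inputs[src], group_size, axis=split_axis)[r] for src in range(group_size)],
        axis=concat_axis,
    )
    assert np.array_equal(got, want)
```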

class _SendRecvGroup: