|
|
@@ -320,12 +320,52 @@ def all_reduce_max( |
|
|
|
def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
    r"""Create all_reduce_min operator for collective communication.

    This operator calculates the minimum value of the tensor data by coordinates across
    the specified group and returns a tensor with the shape of the input tensor.

    Args:
        inp: The tensor data to apply this operator on.
        group: The communication node list instance of :class:`Group` to apply this
            operator across. The default group is WORLD which means all processes
            available. Specify a list of process ranks to apply this operator on
            specific processes, e.g. [1, 3, 5].
        device: The specific device type of :class:`str` to execute this operator.
            The default device is None which means the device of inp will be used.
            Specify "cpu" or "gpu" to execute this operator on specific devices.

    Returns:
        opt: The reduce min tensor of the input tensor data across the specified group.

    Examples:

        .. code-block::

            import megengine as mge
            import megengine.distributed as dist
            import numpy as np
            from warnings import warn

            def func(min_value):
                # get the rank of this process, the ranks should be 0, 1, 2, 3 for a 4 gpu task
                rank = dist.get_rank()
                data = mge.Tensor(rank)
                # the result should be 0 for all processes
                result = mge.functional.distributed.all_reduce_min(data).item()
                assert result == min_value

            def main():
                p_num = dist.helper.get_device_count("gpu")
                if p_num < 2:
                    warn('This opr only works on group with more than one gpu')
                    return
                method = dist.launcher(func)
                method(0)

            if __name__ == '__main__':
                main()
    """
    # Delegate to the generic collective-communication entry point with the
    # MIN reduction mode; group/device semantics are handled there.
    mode = CollectiveComm.Mode.ALL_REDUCE_MIN
    return collective_comm(inp, mode, group, device)
|
|
|