Browse Source

docs(distributed.functional): add return type for all_reduce_min

GitOrigin-RevId: 9f734902fe
release-1.5
Megvii Engine Team 3 years ago
parent
commit
9526ee521b
1 changed files with 44 additions and 4 deletions
  1. +44
    -4
      imperative/python/megengine/distributed/functional.py

+ 44
- 4
imperative/python/megengine/distributed/functional.py View File

@@ -320,12 +320,52 @@ def all_reduce_max(
def all_reduce_min(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = None,
) -> Tensor:
"""
r"""
Create all_reduce_min operator for collective communication.

:param inp: input tensor.
:param group: communication group.
:param device: execution device.
This operator calculates the minimum value of the tensor data by coordinates across the specified group and returns a tensor with the shape of the input tensor.

Args:
inp: The tensor data to apply this operator on.
group: The communication node list instance of :class:'Group' to apply this operator across. The default group is WORLD which means all processes available.
Specify a list of process ranks to apply this operator on specific processes, e.g. [1, 3, 5].
device: The specific device type of :class:'str' to execute this operator. The default device is None which mean the device of inp will be used.
Specify "cpu" or "gpu" to execute this operator on specific devices.

Returns:
opt: The reduce min tensor of the input tensor data across the specified group.

Examples:

.. code-block::

import megengine as mge
import megengine.distributed as dist
import numpy as np
from warnings import warn


def func(min_value):
# get the rank of this process, the ranks shold be 0, 1, 2, 3 for a 4 gpu task
rank = dist.get_rank()
data = mge.Tensor(rank)
# the result should be 0 for all processes
result = mge.functional.distributed.all_reduce_min(data).item()
assert result == min_value


def main():
p_num = dist.helper.get_device_count("gpu")
if p_num < 2:
warn('This opr only works on group with more than one gpu')
return
method = dist.launcher(func)
method(0)


if __name__ == '__main__':
main()

"""
mode = CollectiveComm.Mode.ALL_REDUCE_MIN
return collective_comm(inp, mode, group, device)


Loading…
Cancel
Save