# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Union

import megengine._internal as mgb
from megengine._internal.opr_param_defs import CollectiveComm as Param

from ..core import Buffer, Parameter, Tensor, wrap_io_tensor
from ..functional import add_update
from .helper import collective_comm_symvar
from .util import get_rank, is_distributed


@wrap_io_tensor
def _collective_comm(*args, **kwargs):
    return collective_comm_symvar(*args, **kwargs)


def _group_check(*args):
    """Return True when the arguments are either all None or all not None."""
    l = [val is None for val in args]
    return len(set(l)) <= 1

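
# Illustrative note (not part of the original module): _group_check accepts only the
# "everything left to defaults" or "everything explicitly given" combinations, e.g.
#   _group_check(None, None, None)   -> True   (all defaults)
#   _group_check("k", 4, True)       -> True   (all specified)
#   _group_check("k", None, True)    -> False  (mixed, rejected by the asserts below)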

def reduce_sum(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
) -> Tensor:
    """Create reduce_sum operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks; uses util.get_world_size() by default
    :param is_root: whether this is a root node
    """
    assert _group_check(
        key, nr_ranks, is_root
    ), "key, nr_ranks, is_root should be set at the same time"
    return _collective_comm(
        tensor, key, Param.Mode.REDUCE_SUM, nr_ranks, is_root, device=tensor.device,
    )
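
# Usage sketch (illustrative only; assumes the distributed process group has already
# been initialized on every worker): each rank passes its local tensor and the
# elementwise sum is produced on the root rank.
#
#   out = reduce_sum(x)                                        # rely on the defaults
#   out = reduce_sum(x, key="grad_sum", nr_ranks=4, is_root=(get_rank() == 0))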


def gather(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
    rank: Optional[int] = None,
) -> Tensor:
    """Create gather operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks; uses util.get_world_size() by default
    :param is_root: whether this is a root node
    :param rank: rank of this node
    """
    assert _group_check(
        key, nr_ranks, is_root, rank
    ), "key, nr_ranks, is_root, rank should be set at the same time"
    return _collective_comm(
        tensor, key, Param.Mode.GATHER, nr_ranks, is_root, rank, device=tensor.device,
    )
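
# Usage sketch (illustrative only): every rank contributes its local tensor and the
# root rank receives the gathered result (the exact output layout is determined by
# the underlying GATHER operator).
#
#   out = gather(x)                                            # rely on the defaults
#   out = gather(x, key="act", nr_ranks=4, is_root=(get_rank() == 0), rank=get_rank())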


def broadcast(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
) -> Tensor:
    """Create broadcast operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks; uses util.get_world_size() by default
    :param is_root: whether this is a root node
    """
    assert _group_check(
        key, nr_ranks, is_root
    ), "key, nr_ranks, is_root should be set at the same time"

    if is_root is None:
        is_root = get_rank() == 0
    if is_root:
        inp = tensor
    else:
        # non-root ranks contribute no input data; pass the owner graph so the
        # operator can still be constructed (dtype and device are given explicitly)
        inp = tensor._symvar.owner_graph

    return _collective_comm(
        inp,
        key,
        Param.Mode.BROADCAST,
        nr_ranks,
        is_root,
        dtype=tensor.dtype,
        device=tensor.device,
    )
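
# Usage sketch (illustrative only): the root rank supplies the data and every rank
# receives a copy with the same shape and dtype.
#
#   y = broadcast(x)                                           # root defaults to rank 0
#   y = broadcast(x, key="init_w", nr_ranks=4, is_root=(get_rank() == 0))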


def scatter(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
    rank: Optional[int] = None,
) -> Tensor:
    """Create scatter operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks; uses util.get_world_size() by default
    :param is_root: whether this is a root node
    :param rank: rank of this node
    """
    assert _group_check(
        key, nr_ranks, is_root, rank
    ), "key, nr_ranks, is_root, rank should be set at the same time"
    if key is None:
        key = tensor._symvar.name
    if is_root is None:
        is_root = get_rank() == 0

    if is_root:
        inp = tensor
    else:
        # non-root ranks contribute no input data; pass the owner graph so the
        # operator can still be constructed (dtype and device are given explicitly)
        inp = tensor._symvar.owner_graph

    return _collective_comm(
        inp,
        key,
        Param.Mode.SCATTER,
        nr_ranks,
        is_root,
        rank,
        dtype=tensor.dtype,
        device=tensor.device,
    )
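
# Usage sketch (illustrative only): the root rank provides the full tensor and each
# rank receives its own piece (the partitioning is handled by the SCATTER operator).
#
#   part = scatter(full)                                       # root defaults to rank 0
#   part = scatter(full, key="shard", nr_ranks=4, is_root=(get_rank() == 0), rank=get_rank())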


def all_to_all(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    rank: Optional[int] = None,
    local_grad: Optional[bool] = False,
) -> Tensor:
    """Create all_to_all operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks; uses util.get_world_size() by default
    :param rank: rank of this node
    :param local_grad: whether to use local gradient
    """
    assert _group_check(
        key, nr_ranks, rank
    ), "key, nr_ranks, rank should be set at the same time"
    return _collective_comm(
        tensor, key, Param.Mode.ALL_TO_ALL, nr_ranks, rank=rank, local_grad=local_grad,
    )
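
# Usage sketch (illustrative only): every rank exchanges partitions of its tensor with
# every other rank, so each rank ends up holding one piece from each participant.
#
#   y = all_to_all(x)                                          # rely on the defaults
#   y = all_to_all(x, key="a2a", nr_ranks=4, rank=get_rank())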


def all_gather(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    rank: Optional[int] = None,
    local_grad: Optional[bool] = False,
) -> Tensor:
    """Create all_gather operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks; uses util.get_world_size() by default
    :param rank: rank of this node
    :param local_grad: whether to use local gradient
    """
    assert _group_check(
        key, nr_ranks, rank
    ), "key, nr_ranks, rank should be set at the same time"
    return _collective_comm(
        tensor, key, Param.Mode.ALL_GATHER, nr_ranks, rank=rank, local_grad=local_grad
    )
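
# Usage sketch (illustrative only): every rank contributes its local tensor and every
# rank receives the gathered result, e.g. for collecting per-worker activations.
#
#   y = all_gather(x)
#   y = all_gather(x, key="feat", nr_ranks=4, rank=get_rank())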


def reduce_scatter_sum(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    rank: Optional[int] = None,
    local_grad: Optional[bool] = False,
) -> Tensor:
    """Create reduce_scatter_sum operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks; uses util.get_world_size() by default
    :param rank: rank of this node
    :param local_grad: whether to use local gradient
    """
    assert _group_check(
        key, nr_ranks, rank
    ), "key, nr_ranks, rank should be set at the same time"
    return _collective_comm(
        tensor,
        key,
        Param.Mode.REDUCE_SCATTER_SUM,
        nr_ranks,
        rank=rank,
        local_grad=local_grad,
    )
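
# Usage sketch (illustrative only): the tensors from all ranks are summed elementwise
# and the result is split across ranks, so each rank keeps only its own shard of the sum.
#
#   shard = reduce_scatter_sum(grad)
#   shard = reduce_scatter_sum(grad, key="rs_grad", nr_ranks=4, rank=get_rank())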


def all_reduce_sum(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    local_grad: Optional[bool] = False,
) -> Tensor:
    """Create all_reduce_sum operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks; uses util.get_world_size() by default
    :param local_grad: whether to use local gradient
    """
    assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time"
    return _collective_comm(
        tensor, key, Param.Mode.ALL_REDUCE_SUM, nr_ranks, local_grad=local_grad
    )


def all_reduce_max(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    local_grad: Optional[bool] = False,
) -> Tensor:
    """Create all_reduce_max operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks; uses util.get_world_size() by default
    :param local_grad: whether to use local gradient
    """
    assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time"
    return _collective_comm(
        tensor, key, Param.Mode.ALL_REDUCE_MAX, nr_ranks, local_grad=local_grad
    )


def all_reduce_min(
    tensor: Tensor,
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    local_grad: Optional[bool] = False,
) -> Tensor:
    """Create all_reduce_min operator for collective communication

    :param tensor: input tensor
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks; uses util.get_world_size() by default
    :param local_grad: whether to use local gradient
    """
    assert _group_check(key, nr_ranks), "key, nr_ranks should be set at the same time"
    return _collective_comm(
        tensor, key, Param.Mode.ALL_REDUCE_MIN, nr_ranks, local_grad=local_grad
    )
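
# Usage sketch (illustrative only): the three all_reduce variants apply an elementwise
# reduction (sum, max or min) over the tensors from all ranks and return the same
# reduced tensor on every rank, e.g. for gradient averaging or global statistics.
#
#   g = all_reduce_sum(grad) / world_size              # world_size: e.g. util.get_world_size()
#   m = all_reduce_max(loss)
#   n = all_reduce_min(loss)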


def bcast_param(
    inp: Union[Buffer, Parameter],
    key: Optional[str] = None,
    nr_ranks: Optional[int] = None,
    is_root: Optional[bool] = None,
) -> None:
    """Broadcast parameters among devices

    :param inp: input Buffer or Parameter to be synchronized
    :param key: unique identifier for collective communication
    :param nr_ranks: number of ranks; uses util.get_world_size() by default
    :param is_root: whether this is a root node
    """
    if not is_distributed():
        return
    assert _group_check(
        key, nr_ranks, is_root
    ), "key, nr_ranks, is_root should be set at the same time"
    assert isinstance(inp, (Buffer, Parameter))
    bcast_res = broadcast(inp, key, nr_ranks, is_root)
    # alpha=0 discards the old value, so inp is overwritten in place with the broadcast result
    add_update(inp, bcast_res, alpha=0)
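
# Usage sketch (illustrative only): typically called once after model construction so
# that every worker starts from the root rank's parameter values; outside of
# distributed mode the call is a no-op.
#
#   for p in model.parameters():            # "model" is a hypothetical module instance
#       bcast_param(p)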