# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple

from ..core._imperative_rt.ops import CollectiveCommDefModeEnum
from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn
from ..core.autodiff.grad import (
    Tracer,
    check_backward_allow_noinput,
    get_grad_managers,
    get_op_has_grad_fn,
    tracer_apply,
)
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.core import apply
from ..core.tensor.tensor import Tensor, tensor_apply
from ..distributed.group import (
    WORLD,
    Group,
    get_backend,
    get_client,
    get_mm_server_addr,
    get_rank,
)
from ..tensor import tensor

__all__ = [
    "reduce_sum",
    "broadcast",
    "all_gather",
    "reduce_scatter_sum",
    "all_reduce_sum",
    "all_reduce_max",
    "all_reduce_min",
    "gather",
    "scatter",
    "all_to_all",
    "remote_send",
    "remote_recv",
]


@apply.add
def _(op: RemoteSend, *args: Tensor):
    ret = tensor_apply(op, *args)

    # record which gradient tracers are attached to the inputs
    tracer_set = dict()
    for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))):
        tracer_set[k.name] = True

    # the matching remote_recv checks tracer_set to decide what needs gradients
    get_client().set_remote_tracer(op.key, tracer_set)
    return ret


@builtin_op_get_backward_fn.register(RemoteSend)
def _(op: RemoteSend, inputs, outputs, input_requires_grad):
    def backward(*args):
        # the gradient of a send is the tensor received back from the peer rank
        return [
            remote_recv(
                op.rank_to, inputs[0].shape, inputs[0].dtype, str(inputs[0].device)
            )
        ]

    return backward, [True]


@get_op_has_grad_fn.register(RemoteSend)
def _(op: RemoteSend):
    def has_grad(opnode, reached):
        # ask the peer whether the matching remote_recv requires a gradient
        return get_client().check_is_grad(op.key)

    return has_grad


@check_backward_allow_noinput.register(RemoteSend)
def _(op: RemoteSend):
    return True


@builtin_op_get_backward_fn.register(RemoteRecv)
def _(op: RemoteRecv, inputs, outputs, input_requires_grad):
    def backward(*output_grads):
        # the gradient of a recv is sent back to the rank that produced the data
        return [remote_send(output_grads[0], op.rank_from)]

    return backward, [True]


@get_op_has_grad_fn.register(RemoteRecv)
def _(op: RemoteRecv):
    def has_grad(opnode, reached):
        ret = False
        for v in opnode.outputs:
            if v() in reached:
                ret = True
                break
        get_client().set_is_grad(op.key, ret)
        return ret

    return has_grad


def collective_comm(inp, mode, group, device):
    """Helper function for applying collective communication functions."""
    if group is None:
        return inp
    assert isinstance(group, Group)
    op = CollectiveComm()
    op.key = group.key
    op.nr_devices = group.size
    op.rank = group.rank
    op.is_root = op.rank == 0
    op.local_grad = False
    op.addr, op.port = get_mm_server_addr()
    op.mode = mode
    op.dtype = inp.dtype
    op.backend = get_backend()
    op.comp_node = device
    return apply(op, inp)[0]


def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create reduce_sum operator for collective communication.

    :param inp: input tensor
    :param group: communication group
    :param device: the device on which to execute the operator
    """
    mode = CollectiveCommDefModeEnum.REDUCE_SUM
    return collective_comm(inp, mode, group, device)
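
# Usage sketch for reduce_sum (illustrative, not part of the original module);
# it assumes two already-launched worker processes with ranks 0 and 1 sharing
# the default WORLD group:
#
#     out = reduce_sum(tensor([get_rank() + 1]))
#     # the summed tensor [3] is produced on the root rank (rank 0); the output
#     # on non-root ranks should not be relied upon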


def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create broadcast operator for collective communication.

    :param inp: input tensor
    :param group: communication group
    :param device: the device on which to execute the operator
    """
    mode = CollectiveCommDefModeEnum.BROADCAST
    return collective_comm(inp, mode, group, device)
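
# Usage sketch for broadcast (illustrative): the tensor passed on the root rank
# (rank 0) is copied to every rank in the group.
#
#     out = broadcast(tensor([get_rank()]))
#     # every rank now holds the root's data: [0]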


def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create all_gather operator for collective communication.

    :param inp: input tensor
    :param group: communication group
    :param device: the device on which to execute the operator
    """
    mode = CollectiveCommDefModeEnum.ALL_GATHER
    return collective_comm(inp, mode, group, device)
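
# Usage sketch for all_gather (illustrative): each rank contributes its local
# tensor and every rank receives the concatenation of all contributions along
# the first axis.
#
#     out = all_gather(tensor([get_rank()]))
#     # every rank holds [0, 1]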


def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create reduce_scatter_sum operator for collective communication.

    :param inp: input tensor
    :param group: communication group
    :param device: the device on which to execute the operator
    """
    mode = CollectiveCommDefModeEnum.REDUCE_SCATTER_SUM
    return collective_comm(inp, mode, group, device)
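
# Usage sketch for reduce_scatter_sum (illustrative): the inputs are summed
# element-wise across ranks, then the summed tensor is split evenly along the
# first axis and each rank keeps one slice.
#
#     out = reduce_scatter_sum(tensor([get_rank(), get_rank() + 1]))
#     # the sum is [1, 3]; rank 0 holds [1], rank 1 holds [3]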


def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create all_reduce_sum operator for collective communication.

    :param inp: input tensor
    :param group: communication group
    :param device: the device on which to execute the operator
    """
    mode = CollectiveCommDefModeEnum.ALL_REDUCE_SUM
    return collective_comm(inp, mode, group, device)
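
# Usage sketch for all_reduce_sum (illustrative): every rank receives the
# element-wise sum of all inputs; all_reduce_max and all_reduce_min below work
# the same way with max/min as the reduction.
#
#     out = all_reduce_sum(tensor([get_rank() + 1]))
#     # every rank holds [3]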


def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create all_reduce_max operator for collective communication.

    :param inp: input tensor
    :param group: communication group
    :param device: the device on which to execute the operator
    """
    mode = CollectiveCommDefModeEnum.ALL_REDUCE_MAX
    return collective_comm(inp, mode, group, device)


def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create all_reduce_min operator for collective communication.

    :param inp: input tensor
    :param group: communication group
    :param device: the device on which to execute the operator
    """
    mode = CollectiveCommDefModeEnum.ALL_REDUCE_MIN
    return collective_comm(inp, mode, group, device)


def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create gather operator for collective communication.

    :param inp: input tensor
    :param group: communication group
    :param device: the device on which to execute the operator
    """
    mode = CollectiveCommDefModeEnum.GATHER
    return collective_comm(inp, mode, group, device)
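
# Usage sketch for gather (illustrative): every rank sends its tensor and the
# root rank (rank 0) receives the concatenation along the first axis.
#
#     out = gather(tensor([get_rank()]))
#     # rank 0 holds [0, 1]; the output on non-root ranks should not be relied upon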


def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create scatter operator for collective communication.

    :param inp: input tensor
    :param group: communication group
    :param device: the device on which to execute the operator
    """
    mode = CollectiveCommDefModeEnum.SCATTER
    return collective_comm(inp, mode, group, device)
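
# Usage sketch for scatter (illustrative): the root rank's tensor is split
# evenly along the first axis and one slice is delivered to each rank.
#
#     out = scatter(tensor([10, 20]))
#     # rank 0 holds [10], rank 1 holds [20]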


def all_to_all(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create all_to_all operator for collective communication.

    :param inp: input tensor
    :param group: communication group
    :param device: the device on which to execute the operator
    """
    mode = CollectiveCommDefModeEnum.ALL_TO_ALL
    return collective_comm(inp, mode, group, device)
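
# Usage sketch for all_to_all (illustrative): every rank splits its tensor
# along the first axis, sends the i-th slice to rank i, and concatenates the
# slices it receives.
#
#     out = all_to_all(tensor([10 * get_rank(), 10 * get_rank() + 1]))
#     # rank 0 holds [0, 10], rank 1 holds [1, 11]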


def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
    """Send a Tensor to a remote process.

    :param inp: tensor to send
    :param dest_rank: destination process rank
    """
    op = RemoteSend()
    op.key = "{}->{}".format(get_rank(), dest_rank)
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    return apply(op, inp)[0]


def remote_recv(
    src_rank: int, shape: Tuple[int, ...], dtype: type, cn: Optional[str] = "gpu0"
) -> Tensor:
    """Receive a Tensor from a remote process.

    :param src_rank: source process rank
    :param shape: the shape of the tensor to receive
    :param dtype: the data type of the tensor to receive
    :param cn: the comp node on which to place the received tensor
    """
    key = "{}->{}".format(src_rank, get_rank())

    # dummy input used to attach this op to the sender's gradient tracers
    inp = tensor([0])
    tracer_set = get_client().check_remote_tracer(key)
    for grad_manager in get_grad_managers():
        if grad_manager.name in tracer_set:
            grad_manager.wrt(inp)

    op = RemoteRecv()
    op.key = key
    op.cn = cn
    op.shape = shape
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank

    return apply(op, inp)[0]
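
# Usage sketch for remote_send / remote_recv (illustrative): they form a
# point-to-point pair, so both ranks must agree on the shape and dtype being
# transferred ("import numpy as np" is assumed below).
#
#     if get_rank() == 0:
#         remote_send(tensor([1.0, 2.0]), dest_rank=1)   # sender side
#     else:
#         x = remote_recv(0, (2,), np.float32)           # receiver side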