# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple

import numpy as np

from ..core._imperative_rt.core2 import apply
from ..core.autodiff.grad import Function, _grad_manager_dict
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.utils import isscalar, setscalar
from ..device import get_default_device
from ..tensor import Tensor
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank

__all__ = [
    "reduce_sum",
    "broadcast",
    "all_gather",
    "reduce_scatter_sum",
    "all_reduce_sum",
    "all_reduce_max",
    "all_reduce_min",
    "gather",
    "scatter",
    "all_to_all",
    "remote_send",
    "remote_recv",
]


def collective_comm(inp, mode, group, device):
    """Helper function for applying collective communication functions."""
    if group is None:
        # no communication group: nothing to communicate
        return inp
    assert isinstance(group, Group)
    addr, port = get_mm_server_addr()
    op = CollectiveComm(
        key=group.key,
        nr_devices=group.size,
        rank=group.rank,
        is_root=(group.rank == 0),
        local_grad=False,
        addr=addr,
        port=port,
        mode=mode,
        dtype=inp.dtype,
        backend=get_backend(),
        comp_node=device,
    )
    (result,) = apply(op, inp)
    # assume all workers have homogeneous shape
    if mode in (
        CollectiveComm.Mode.REDUCE_SUM,
        CollectiveComm.Mode.BROADCAST,
        CollectiveComm.Mode.ALL_REDUCE_SUM,
        CollectiveComm.Mode.ALL_REDUCE_MAX,
        CollectiveComm.Mode.ALL_REDUCE_MIN,
    ):
        if isscalar(inp):
            setscalar(result)
    return result


def _save_output_for_autodiff(inp, out):
    for g in _grad_manager_dict.values():
        if g._is_attached_to(inp):
            g._refkeeper.append(out)


def _bcast_has_grad(group, grad):
    if group.rank == 0:
        has_grad = grad is not None
        get_client().bcast_val(has_grad, group.key, group.size)
    else:
        has_grad = get_client().bcast_val(None, group.key, group.size)
    return has_grad


def _bcast_shape_dtype(group, inp):
    if group.rank == 0:
        # FIXME: in some cases the shape is not available (e.g. the output of cond_take)
        shape = inp._tuple_shape
        dtype = np.dtype(inp.dtype).name
        get_client().bcast_val({"shape": shape, "dtype": dtype}, group.key, group.size)
    else:
        val = get_client().bcast_val(None, group.key, group.size)
        shape = val["shape"]
        dtype = val["dtype"]

    return shape, dtype


def _bcast_tracer_state(group, inp):
    if group.rank == 0:
        tracer_keys = []
        for n, g in _grad_manager_dict.items():
            if g._is_attached_to(inp):
                tracer_keys.append(n)
        get_client().bcast_val(tracer_keys, group.key, group.size)
    else:
        tracer_keys = get_client().bcast_val(None, group.key, group.size)
        for n in tracer_keys:
            g = _grad_manager_dict.get(n)
            if g is not None:
                g.wrt(inp)
                g._refkeeper.append(inp)


def _dummy_input(shape, dtype, device=""):
    if device == "":
        device = get_default_device()
    inp = Tensor(0, dtype=dtype, device=device)
    if len(shape) > 0:
        inp = inp._broadcast(shape)
    return inp


class _ReduceSum(Function):
    def __init__(self, group=WORLD, device=""):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device
        )

    def backward(self, grad):
        has_grad = _bcast_has_grad(self.group, grad)
        if has_grad:
            return broadcast(grad, self.group, self.in_device)


def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Optional[Tensor]:
    """
    Reduce tensors across the specified group by summation. Only the root
    process (rank 0) receives the reduced result; other processes return None.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    op = _ReduceSum(group, device)
    (out,) = apply(op, inp)

    if group.rank == 0:
        return out
    else:
        _save_output_for_autodiff(inp, out)
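
# Illustrative usage (a minimal sketch, not part of the library): the comments
# below assume two processes launched with ``megengine.distributed.launcher``
# so that the WORLD group is initialized, e.g.
#
#     x = Tensor([get_rank()])      # rank 0 holds [0], rank 1 holds [1]
#     out = reduce_sum(x)           # rank 0 gets [1]; rank 1 gets None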


class _Broadcast(Function):
    def __init__(self, group=WORLD, device=""):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device
        )

    def backward(self, grad):
        # TODO: support backward when only part of the gradient is available
        if grad is not None:
            return reduce_sum(grad, self.group, self.in_device)


def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Broadcast the tensor on the root process (rank 0) to every process in the group.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    shape, dtype = _bcast_shape_dtype(group, inp)
    if group.rank != 0:
        # dummy input to infer shape
        inp = _dummy_input(shape, dtype, device)

    _bcast_tracer_state(group, inp)

    op = _Broadcast(group, device)
    (out,) = apply(op, inp)
    return out
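
# Illustrative usage (sketch, assumptions as above); the value passed on
# non-root ranks is ignored and replaced by a dummy tensor:
#
#     x = Tensor([1, 2]) if get_rank() == 0 else Tensor([0, 0])
#     y = broadcast(x)              # every rank now holds [1, 2]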


def _bcast_param(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    mode = CollectiveComm.Mode.BROADCAST
    return collective_comm(inp, mode, group, device)


def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Gather tensors across the specified group and concatenate them along the
    first dimension; every process receives the result.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_GATHER
    return collective_comm(inp, mode, group, device)
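
# Illustrative usage (sketch, assumptions as above):
#
#     x = Tensor([get_rank()])      # rank i holds [i]
#     y = all_gather(x)             # every rank holds [0, 1]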


def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Sum tensors across the specified group, then split the result along the
    first dimension and scatter one part to each process.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.REDUCE_SCATTER_SUM
    return collective_comm(inp, mode, group, device)
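
# Illustrative usage (sketch, assumptions as above, two ranks):
#
#     x = Tensor([1, 2])            # both ranks hold [1, 2]
#     y = reduce_scatter_sum(x)     # sum is [2, 4]; rank 0 gets [2], rank 1 gets [4]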


def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Reduce tensors across the specified group by summation; every process
    receives the result.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_SUM
    return collective_comm(inp, mode, group, device)
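
# Illustrative usage (sketch, assumptions as above; all_reduce_max and
# all_reduce_min below behave the same way with max/min in place of sum):
#
#     x = Tensor([get_rank()])      # rank 0: [0], rank 1: [1]
#     y = all_reduce_sum(x)         # every rank: [1]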


def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Reduce tensors across the specified group by element-wise maximum; every
    process receives the result.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MAX
    return collective_comm(inp, mode, group, device)


def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Reduce tensors across the specified group by element-wise minimum; every
    process receives the result.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_REDUCE_MIN
    return collective_comm(inp, mode, group, device)


class _Gather(Function):
    def __init__(self, group=WORLD, device=""):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.GATHER, self.group, self.out_device
        )

    def backward(self, grad):
        has_grad = _bcast_has_grad(self.group, grad)
        if has_grad:
            return scatter(grad, self.group, self.in_device)


def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Optional[Tensor]:
    """
    Gather tensors across the specified group and concatenate them along the
    first dimension. Only the root process (rank 0) receives the result; other
    processes return None.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    op = _Gather(group, device)
    (out,) = apply(op, inp)

    if group.rank == 0:
        return out
    else:
        _save_output_for_autodiff(inp, out)
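
# Illustrative usage (sketch, assumptions as above):
#
#     x = Tensor([get_rank()])
#     y = gather(x)                 # rank 0 gets [0, 1]; rank 1 gets None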


class _Scatter(Function):
    def __init__(self, group=WORLD, device=""):
        self.group = group
        self.out_device = device

    def forward(self, data):
        self.in_device = str(data.device)
        return collective_comm(
            data, CollectiveComm.Mode.SCATTER, self.group, self.out_device
        )

    def backward(self, grad):
        # TODO: support backward when only part of the gradient is available
        if grad is not None:
            return gather(grad, self.group, self.in_device)


def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Split the tensor on the root process (rank 0) along the first dimension
    and scatter one part to each process.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    shape, dtype = _bcast_shape_dtype(group, inp)
    if group.rank != 0:
        # dummy input to infer shape
        inp = _dummy_input(shape, dtype, device)

    _bcast_tracer_state(group, inp)

    op = _Scatter(group, device)
    (out,) = apply(op, inp)
    return out
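
# Illustrative usage (sketch, assumptions as above):
#
#     x = Tensor([1, 2]) if get_rank() == 0 else Tensor([0, 0])
#     y = scatter(x)                # rank 0 gets [1], rank 1 gets [2]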


def all_to_all(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """
    Each process scatters its input tensor along the first dimension to all
    processes and returns the parts gathered from them.

    :param inp: input tensor.
    :param group: communication group.
    :param device: execution device.
    """
    mode = CollectiveComm.Mode.ALL_TO_ALL
    return collective_comm(inp, mode, group, device)
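
# Illustrative usage (sketch, assumptions as above, two ranks):
#
#     x = Tensor([2 * get_rank(), 2 * get_rank() + 1])   # rank 0: [0, 1], rank 1: [2, 3]
#     y = all_to_all(x)                                  # rank 0: [0, 2], rank 1: [1, 3]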


class _SendRecvGroup:
    def __init__(self, rank_from, rank_to):
        self.key = "{}->{}".format(rank_from, rank_to)
        self.rank_from = rank_from
        self.rank_to = rank_to
        self.size = 2

    @property
    def rank(self):
        if get_rank() == self.rank_from:
            return 0
        else:
            return 1


class _RemoteSend(Function):
    def __init__(self, op: RemoteSend):
        self.op = op

    def forward(self, data):
        self.device = str(data.device)
        (self.dummy,) = apply(self.op, data)
        return self.dummy

    def backward(self, grad):
        assert grad is None
        has_grad = get_client().bcast_val(None, self.op.key, 2)
        if has_grad:
            return remote_recv(self.op.rank_to, device=self.device, inp=self.dummy)


class _RemoteRecv(Function):
    def __init__(self, op: RemoteRecv):
        self.op = op

    def forward(self, dummy):
        return apply(self.op, dummy)

    def backward(self, grad):
        get_client().bcast_val(grad is not None, self.op.key, 2)
        if grad is not None:
            remote_send(grad, self.op.rank_from)


def remote_send(inp: Tensor, dest_rank: int) -> None:
    """
    Send a tensor to a remote process.

    :param inp: tensor to send.
    :param dest_rank: destination process rank.
    """
    group = _SendRecvGroup(get_rank(), dest_rank)
    _bcast_shape_dtype(group, inp)

    _bcast_tracer_state(group, inp)

    op = RemoteSend()
    op.key = group.key
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    op.backend = get_backend()
    (out,) = apply(_RemoteSend(op), inp)

    _save_output_for_autodiff(inp, out)


def remote_recv(src_rank: int, device: Optional[str] = None, inp=None) -> Tensor:
    """
    Receive a tensor from a remote process.

    :param src_rank: source process rank.
    :param device: the device on which to place the received tensor.
    :param inp: dummy input used to determine the type of the received tensor.
    """
    group = _SendRecvGroup(src_rank, get_rank())
    shape, dtype = _bcast_shape_dtype(group, None)

    if device is None:
        device = get_default_device()
    # dummy input
    if inp is None:
        inp = Tensor(0, device=device)
    _bcast_tracer_state(group, inp)

    _isscalar = False
    if len(shape) == 0:
        shape = (1,)
        _isscalar = True

    op = RemoteRecv()
    op.key = group.key
    op.cn = device
    op.shape = shape
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank
    op.backend = get_backend()

    (ret,) = apply(_RemoteRecv(op), inp)
    if _isscalar:
        setscalar(ret)
    return ret
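
# Illustrative usage (a minimal sketch, assumptions as above): point-to-point
# transfer between two launched processes.
#
#     if get_rank() == 0:
#         remote_send(Tensor([1.0]), 1)     # returns None; output kept alive for autodiff
#     else:
#         x = remote_recv(0)                # receives Tensor([1.0]) from rank 0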