GitOrigin-RevId: 30cf2f514b
tags/v1.0.0-rc1
@@ -0,0 +1,295 @@ | |||
# -*- coding: utf-8 -*- | |||
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") | |||
# | |||
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved. | |||
# | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple | |||
from ..core._imperative_rt.ops import CollectiveCommMode | |||
from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn | |||
from ..core.autodiff.grad import ( | |||
Tracer, | |||
check_backward_allow_noinput, | |||
get_grad_managers, | |||
get_op_has_grad_fn, | |||
tracer_apply, | |||
) | |||
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend | |||
from ..core.tensor.core import apply | |||
from ..core.tensor.tensor import Tensor, tensor_apply | |||
from ..tensor import tensor | |||
from ..device import get_default_device | |||
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank | |||
# Public API of this module: collective communication helpers plus
# point-to-point remote send/recv.
__all__ = [
    "reduce_sum",
    "broadcast",
    "all_gather",
    "reduce_scatter_sum",
    "all_reduce_sum",
    "all_reduce_max",
    "all_reduce_min",
    "gather",
    "scatter",
    "all_to_all",
    "remote_send",
    "remote_recv",
]
@apply.add
def _(op: RemoteSend, *args: Tensor):
    """Apply hook for RemoteSend: forward the op, then publish tracer info.

    The union of tracer keys attached to all Tensor arguments is sent to the
    coordination client so the receiving side can look it up in remote_recv.
    """
    result = tensor_apply(op, *args)
    keys = set()
    for arg in args:
        if isinstance(arg, Tensor):
            keys |= set(arg._extra_data)
    # check tracer_set in remote_recv
    get_client().set_remote_tracer(op.key, {k.name: True for k in keys})
    return result
@builtin_op_get_backward_fn.register(RemoteSend)
def _(op: RemoteSend, inputs, outputs, input_requires_grad):
    """Backward of RemoteSend: receive the gradient back from the peer rank."""
    sent = inputs[0]

    def backward(*args):
        # The incoming gradient mirrors the sent tensor's shape/dtype/device.
        grad = remote_recv(op.rank_to, sent.shape, sent.dtype, str(sent.device))
        return [grad]

    return backward, [True]
@get_op_has_grad_fn.register(RemoteSend)
def _(op: RemoteSend):
    """Grad presence for RemoteSend is decided by the remote receiver."""

    def has_grad(opnode, reached):
        # The peer records this flag via set_is_grad in RemoteRecv's has_grad.
        return get_client().check_is_grad(op.key)

    return has_grad
@check_backward_allow_noinput.register(RemoteSend)
def _(op: RemoteSend):
    # Returning True lets autodiff invoke RemoteSend's backward even when no
    # local input gradient is available (the gradient arrives from the peer).
    return True
@builtin_op_get_backward_fn.register(RemoteRecv)
def _(op: RemoteRecv, inputs, outputs, input_requires_grad):
    """Backward of RemoteRecv: send the output gradient back to the sender."""

    def backward(*output_grads):
        head = output_grads[0]
        return [remote_send(head, op.rank_from)]

    return backward, [True]
@get_op_has_grad_fn.register(RemoteRecv)
def _(op: RemoteRecv):
    """Grad presence for RemoteRecv, also reported to the sending side."""

    def has_grad(opnode, reached):
        # Grad is needed iff any (weakly referenced) output of this op node
        # was reached during the backward traversal.
        needed = any(out() in reached for out in opnode.outputs)
        # Tell the peer whether it should expect a gradient for this key.
        get_client().set_is_grad(op.key, needed)
        return needed

    return has_grad
def collective_comm(inp, mode, group, device):
    """Helper for building and applying a collective communication op.

    :param inp: input tensor
    :param mode: a CollectiveCommMode value selecting the collective
    :param group: communication group; ``None`` means "no communication"
        and the input is returned unchanged
    :param device: execute placement (comp node)
    """
    # Short-circuit on a None group BEFORE the isinstance assert; in the
    # original order the assert fired on None first, making this branch
    # unreachable.
    if group is None:
        return inp
    assert isinstance(group, Group)
    op = CollectiveComm()
    op.key = group.key
    op.nr_devices = group.size
    op.rank = group.rank
    op.is_root = op.rank == 0
    op.local_grad = False
    op.addr, op.port = get_mm_server_addr()
    op.mode = mode
    op.dtype = inp.dtype
    op.backend = get_backend()
    op.comp_node = device
    return apply(op, inp)[0]
def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Apply a reduce_sum collective communication over ``group``.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommMode.REDUCE_SUM, group, device)
def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Apply a broadcast collective communication over ``group``.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommMode.BROADCAST, group, device)
def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Apply an all_gather collective communication over ``group``.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommMode.ALL_GATHER, group, device)
def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Apply a reduce_scatter_sum collective communication over ``group``.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(
        inp, CollectiveCommMode.REDUCE_SCATTER_SUM, group, device
    )
def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Apply an all_reduce_sum collective communication over ``group``.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommMode.ALL_REDUCE_SUM, group, device)
def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Apply an all_reduce_max collective communication over ``group``.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommMode.ALL_REDUCE_MAX, group, device)
def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Apply an all_reduce_min collective communication over ``group``.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommMode.ALL_REDUCE_MIN, group, device)
def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Apply a gather collective communication over ``group``.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommMode.GATHER, group, device)
def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Apply a scatter collective communication over ``group``.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommMode.SCATTER, group, device)
def all_to_all(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Apply an all_to_all collective communication over ``group``.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommMode.ALL_TO_ALL, group, device)
def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
    """Send a Tensor to a remote process.

    :param inp: tensor to send
    :param dest_rank: destination process rank
    """
    op = RemoteSend()
    op.rank_to = dest_rank
    # Key identifies the channel endpoint pair: "<sender>-><receiver>".
    op.key = "{}->{}".format(get_rank(), dest_rank)
    op.addr, op.port = get_mm_server_addr()
    (out,) = apply(op, inp)
    return out
def remote_recv(
    src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
) -> Tensor:
    """Receive a Tensor from a remote process.

    :param src_rank: source process rank
    :param shape: shape of the tensor to receive
    :param dtype: data type of the tensor to receive
    :param device: device to place the received tensor; defaults to the
        current default device when ``None``
    """
    key = "{}->{}".format(src_rank, get_rank())
    device = get_default_device() if device is None else device
    # dummy input tensor for the op application
    inp = tensor([0])
    # Register the dummy input with every grad manager the sender traced,
    # so gradients can flow back through the recv.
    tracer_set = get_client().check_remote_tracer(key)
    for manager in get_grad_managers():
        if manager.name in tracer_set:
            manager.wrt(inp)
    op = RemoteRecv()
    op.key = key
    op.cn = device
    op.shape = shape
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank
    (out,) = apply(op, inp)
    return out
@@ -6,298 +6,19 @@ | |||
# Unless required by applicable law or agreed to in writing, | |||
# software distributed under the License is distributed on an | |||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple | |||
from ..core._imperative_rt.ops import CollectiveCommMode | |||
from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn | |||
from ..core.autodiff.grad import ( | |||
Tracer, | |||
check_backward_allow_noinput, | |||
get_grad_managers, | |||
get_op_has_grad_fn, | |||
tracer_apply, | |||
# pylint: disable=redefined-builtin | |||
from ..distributed.functional import ( | |||
all_gather, | |||
all_reduce_max, | |||
all_reduce_min, | |||
all_reduce_sum, | |||
all_to_all, | |||
broadcast, | |||
collective_comm, | |||
gather, | |||
reduce_scatter_sum, | |||
reduce_sum, | |||
remote_recv, | |||
remote_send, | |||
scatter, | |||
) | |||
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend | |||
from ..core.tensor.core import apply | |||
from ..core.tensor.tensor import Tensor | |||
from ..device import get_default_device | |||
from ..distributed.group import ( | |||
WORLD, | |||
Group, | |||
get_backend, | |||
get_client, | |||
get_mm_server_addr, | |||
get_rank, | |||
) | |||
from ..tensor import tensor | |||
# Names exported by this module: collective communication helpers plus
# point-to-point remote send/recv.
__all__ = [
    "reduce_sum",
    "broadcast",
    "all_gather",
    "reduce_scatter_sum",
    "all_reduce_sum",
    "all_reduce_max",
    "all_reduce_min",
    "gather",
    "scatter",
    "all_to_all",
    "remote_send",
    "remote_recv",
]
@apply.register()
def _(op: RemoteSend, *args: Tensor):
    """Apply hook for RemoteSend: run the default apply, then publish the
    union of tracer keys found on the Tensor arguments so remote_recv on the
    peer can check them."""
    ret = apply.super(op, *args)
    tracer_names = set()
    for arg in args:
        if isinstance(arg, Tensor):
            tracer_names |= set(arg._extra_data)
    # check tracer_set in remote_recv
    get_client().set_remote_tracer(op.key, {t.name: True for t in tracer_names})
    return ret
@builtin_op_get_backward_fn.register(RemoteSend)
def _(op: RemoteSend, inputs, outputs, input_requires_grad):
    """Backward of RemoteSend: the gradient is received from the peer rank."""
    sent = inputs[0]

    def backward(*args):
        # Gradient tensor has the same shape/dtype/device as the sent tensor.
        return [remote_recv(op.rank_to, sent.shape, sent.dtype, str(sent.device))]

    return backward, [True]
@get_op_has_grad_fn.register(RemoteSend)
def _(op: RemoteSend):
    """RemoteSend grad presence is looked up from the coordination client."""

    def has_grad(opnode, reached):
        # Flag was recorded by the receiving side (RemoteRecv's has_grad).
        return get_client().check_is_grad(op.key)

    return has_grad
@check_backward_allow_noinput.register(RemoteSend)
def _(op: RemoteSend):
    # Returning True lets autodiff invoke RemoteSend's backward even when no
    # local input gradient is available (the gradient arrives from the peer).
    return True
@builtin_op_get_backward_fn.register(RemoteRecv)
def _(op: RemoteRecv, inputs, outputs, input_requires_grad):
    """Backward of RemoteRecv: ship the output gradient back to the sender."""

    def backward(*output_grads):
        grad = output_grads[0]
        return [remote_send(grad, op.rank_from)]

    return backward, [True]
@get_op_has_grad_fn.register(RemoteRecv)
def _(op: RemoteRecv):
    """Grad presence for RemoteRecv; the result is shared with the sender."""

    def has_grad(opnode, reached):
        # True iff any live (weakref'd) output of this node was reached in
        # the backward traversal.
        found = any(out() in reached for out in opnode.outputs)
        # Inform the sending side whether a gradient will come back.
        get_client().set_is_grad(op.key, found)
        return found

    return has_grad
def collective_comm(inp, mode, group, device):
    """Helper for building and applying a collective communication op.

    :param inp: input tensor
    :param mode: a CollectiveCommMode value selecting the collective
    :param group: communication group; ``None`` means "no communication"
        and the input is returned unchanged
    :param device: execute placement (comp node)
    """
    # Handle a None group BEFORE asserting the type: in the original order
    # the isinstance assert fired on None first, so the short-circuit branch
    # below was unreachable.
    if group is None:
        return inp
    assert isinstance(group, Group)
    op = CollectiveComm()
    op.key = group.key
    op.nr_devices = group.size
    op.rank = group.rank
    op.is_root = op.rank == 0
    op.local_grad = False
    op.addr, op.port = get_mm_server_addr()
    op.mode = mode
    op.dtype = inp.dtype
    op.backend = get_backend()
    op.comp_node = device
    return apply(op, inp)[0]
def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Perform a reduce_sum collective communication over ``group``.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommMode.REDUCE_SUM, group, device)
def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Perform a broadcast collective communication over ``group``.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommMode.BROADCAST, group, device)
def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Perform an all_gather collective communication over ``group``.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommMode.ALL_GATHER, group, device)
def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Perform a reduce_scatter_sum collective communication over ``group``.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(
        inp, CollectiveCommMode.REDUCE_SCATTER_SUM, group, device
    )
def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Perform an all_reduce_sum collective communication over ``group``.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommMode.ALL_REDUCE_SUM, group, device)
def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Perform an all_reduce_max collective communication over ``group``.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommMode.ALL_REDUCE_MAX, group, device)
def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Perform an all_reduce_min collective communication over ``group``.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommMode.ALL_REDUCE_MIN, group, device)
def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Perform a gather collective communication over ``group``.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommMode.GATHER, group, device)
def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Perform a scatter collective communication over ``group``.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommMode.SCATTER, group, device)
def all_to_all(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Perform an all_to_all collective communication over ``group``.

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    return collective_comm(inp, CollectiveCommMode.ALL_TO_ALL, group, device)
def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
    """Send a Tensor to a remote process.

    :param inp: tensor to send
    :param dest_rank: destination process rank
    """
    op = RemoteSend()
    op.rank_to = dest_rank
    # Channel key is "<sender rank>-><receiver rank>".
    op.key = "{}->{}".format(get_rank(), dest_rank)
    op.addr, op.port = get_mm_server_addr()
    (result,) = apply(op, inp)
    return result
def remote_recv(
    src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
) -> Tensor:
    """Receive a Tensor from a remote process.

    :param src_rank: source process rank
    :param shape: shape of the tensor to receive
    :param dtype: data type of the tensor to receive
    :param device: device to place the received tensor;
        if None, use default device
    """
    key = "{}->{}".format(src_rank, get_rank())
    if device is None:
        device = get_default_device()
    # dummy input tensor for the op application
    inp = tensor([0])
    # Attach the dummy input to every grad manager the sender traced so
    # gradients can propagate back through this recv.
    tracer_set = get_client().check_remote_tracer(key)
    for manager in get_grad_managers():
        if manager.name in tracer_set:
            manager.wrt(inp)
    op = RemoteRecv()
    op.key = key
    op.cn = device
    op.shape = shape
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank
    (result,) = apply(op, inp)
    return result