GitOrigin-RevId: 30cf2f514b
@@ -0,0 +1,295 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

from typing import Optional, Tuple

from ..core._imperative_rt.ops import CollectiveCommMode
from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn
from ..core.autodiff.grad import (
    Tracer,
    check_backward_allow_noinput,
    get_grad_managers,
    get_op_has_grad_fn,
    tracer_apply,
)
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.core import apply
from ..core.tensor.tensor import Tensor, tensor_apply
from ..tensor import tensor
from ..device import get_default_device
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank

__all__ = [
    "reduce_sum",
    "broadcast",
    "all_gather",
    "reduce_scatter_sum",
    "all_reduce_sum",
    "all_reduce_max",
    "all_reduce_min",
    "gather",
    "scatter",
    "all_to_all",
    "remote_send",
    "remote_recv",
]


@apply.add
def _(op: RemoteSend, *args: Tensor):
    ret = tensor_apply(op, *args)

    # set extra information
    tracer_set = dict()
    for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))):
        tracer_set[k.name] = True

    # check tracer_set in remote_recv
    get_client().set_remote_tracer(op.key, tracer_set)
    return ret


@builtin_op_get_backward_fn.register(RemoteSend)
def _(op: RemoteSend, inputs, outputs, input_requires_grad):
    def backward(*args):
        return [
            remote_recv(
                op.rank_to, inputs[0].shape, inputs[0].dtype, str(inputs[0].device)
            )
        ]

    return backward, [True]


@get_op_has_grad_fn.register(RemoteSend)
def _(op: RemoteSend):
    def has_grad(opnode, reached):
        return get_client().check_is_grad(op.key)

    return has_grad


@check_backward_allow_noinput.register(RemoteSend)
def _(op: RemoteSend):
    return True


@builtin_op_get_backward_fn.register(RemoteRecv)
def _(op: RemoteRecv, inputs, outputs, input_requires_grad):
    def backward(*output_grads):
        return [remote_send(output_grads[0], op.rank_from)]

    return backward, [True]


@get_op_has_grad_fn.register(RemoteRecv)
def _(op: RemoteRecv):
    def has_grad(opnode, reached):
        ret = False
        for v in opnode.outputs:
            if v() in reached:
                ret = True
                break
        get_client().set_is_grad(op.key, ret)
        return ret

    return has_grad


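# How the registrations above fit together (descriptive note, not in the
# original source): the backward of a RemoteSend is a RemoteRecv of the
# gradient from the peer rank, and vice versa, while
# get_client().set_is_grad / check_is_grad let the two processes agree,
# under the same "src->dst" key, on whether that gradient is needed at all.

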
def collective_comm(inp, mode, group, device):
    """Helper function for applying collective communication functions"""
    # no group means this process does not participate; pass the input through
    if group is None:
        return inp
    assert isinstance(group, Group)
    op = CollectiveComm()
    op.key = group.key
    op.nr_devices = group.size
    op.rank = group.rank
    op.is_root = op.rank == 0
    op.local_grad = False
    op.addr, op.port = get_mm_server_addr()
    op.mode = mode
    op.dtype = inp.dtype
    op.backend = get_backend()
    op.comp_node = device
    return apply(op, inp)[0]


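# A minimal usage sketch (illustrative only, not part of the original module):
# the wrappers below are thin shims that pick a CollectiveCommMode and forward
# to this helper, so calling it directly is equivalent, e.g.
#
#     out = collective_comm(x, CollectiveCommMode.ALL_REDUCE_SUM, WORLD, "")
#
# where x is any input tensor held by this rank.

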
def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create reduce_sum operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.REDUCE_SUM
    return collective_comm(inp, mode, group, device)


def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create broadcast operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.BROADCAST
    return collective_comm(inp, mode, group, device)


def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create all_gather operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.ALL_GATHER
    return collective_comm(inp, mode, group, device)


def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create reduce_scatter_sum operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.REDUCE_SCATTER_SUM
    return collective_comm(inp, mode, group, device)


def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create all_reduce_sum operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.ALL_REDUCE_SUM
    return collective_comm(inp, mode, group, device)


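# A minimal usage sketch for all_reduce_sum (illustrative only; it assumes the
# process group behind WORLD has already been initialized on every rank):
#
#     x = tensor([get_rank()])        # each rank contributes its own value
#     total = all_reduce_sum(x)       # every rank receives the global sum
#     mean = total / WORLD.size       # e.g. for averaging gradients

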
def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create all_reduce_max operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.ALL_REDUCE_MAX
    return collective_comm(inp, mode, group, device)


def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create all_reduce_min operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.ALL_REDUCE_MIN
    return collective_comm(inp, mode, group, device)


def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create gather operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.GATHER
    return collective_comm(inp, mode, group, device)


def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create scatter operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.SCATTER
    return collective_comm(inp, mode, group, device)


def all_to_all(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create all_to_all operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.ALL_TO_ALL
    return collective_comm(inp, mode, group, device)


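# Note on the wrappers above (descriptive, not in the original source): they
# all share the same signature and differ only in the CollectiveCommMode they
# select. For the rooted collectives (broadcast, gather, scatter),
# collective_comm marks rank 0 of the group as the root via op.is_root.

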
def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
    """Send a Tensor to a remote process

    :param inp: tensor to send
    :param dest_rank: destination process rank
    """
    op = RemoteSend()
    op.key = "{}->{}".format(get_rank(), dest_rank)
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    return apply(op, inp)[0]


def remote_recv(
    src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
) -> Tensor:
    """Receive a Tensor from a remote process

    :param src_rank: source process rank
    :param shape: the shape of the tensor to receive
    :param dtype: the data type of the tensor to receive
    :param device: the device to place the received tensor,
        if None, use default device
    """
    key = "{}->{}".format(src_rank, get_rank())
    if device is None:
        device = get_default_device()

    # dummy input
    inp = tensor([0])
    tracer_set = get_client().check_remote_tracer(key)
    for grad_manager in get_grad_managers():
        if grad_manager.name in tracer_set:
            grad_manager.wrt(inp)

    op = RemoteRecv()
    op.key = key
    op.cn = device
    op.shape = shape
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank
    return apply(op, inp)[0]


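# A minimal send/recv sketch (illustrative only; x, x_shape and x_dtype are
# placeholders, and both processes must belong to the same distributed job so
# that get_mm_server_addr() resolves to the same server):
#
#     if get_rank() == 0:
#         remote_send(x, 1)                        # registers key "0->1"
#     elif get_rank() == 1:
#         y = remote_recv(0, x_shape, x_dtype)     # matches key "0->1"
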
@@ -6,298 +6,19 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple

from ..core._imperative_rt.ops import CollectiveCommMode
from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn
from ..core.autodiff.grad import (
    Tracer,
    check_backward_allow_noinput,
    get_grad_managers,
    get_op_has_grad_fn,
    tracer_apply,
# pylint: disable=redefined-builtin
from ..distributed.functional import (
    all_gather,
    all_reduce_max,
    all_reduce_min,
    all_reduce_sum,
    all_to_all,
    broadcast,
    collective_comm,
    gather,
    reduce_scatter_sum,
    reduce_sum,
    remote_recv,
    remote_send,
    scatter,
)
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.core import apply
from ..core.tensor.tensor import Tensor
from ..device import get_default_device
from ..distributed.group import (
    WORLD,
    Group,
    get_backend,
    get_client,
    get_mm_server_addr,
    get_rank,
)
from ..tensor import tensor

__all__ = [
    "reduce_sum",
    "broadcast",
    "all_gather",
    "reduce_scatter_sum",
    "all_reduce_sum",
    "all_reduce_max",
    "all_reduce_min",
    "gather",
    "scatter",
    "all_to_all",
    "remote_send",
    "remote_recv",
]


@apply.register()
def _(op: RemoteSend, *args: Tensor):
    ret = apply.super(op, *args)

    # set extra information
    tracer_set = dict()
    for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))):
        tracer_set[k.name] = True

    # check tracer_set in remote_recv
    get_client().set_remote_tracer(op.key, tracer_set)
    return ret


@builtin_op_get_backward_fn.register(RemoteSend)
def _(op: RemoteSend, inputs, outputs, input_requires_grad):
    def backward(*args):
        return [
            remote_recv(
                op.rank_to, inputs[0].shape, inputs[0].dtype, str(inputs[0].device)
            )
        ]

    return backward, [True]


@get_op_has_grad_fn.register(RemoteSend)
def _(op: RemoteSend):
    def has_grad(opnode, reached):
        return get_client().check_is_grad(op.key)

    return has_grad


@check_backward_allow_noinput.register(RemoteSend)
def _(op: RemoteSend):
    return True


@builtin_op_get_backward_fn.register(RemoteRecv)
def _(op: RemoteRecv, inputs, outputs, input_requires_grad):
    def backward(*output_grads):
        return [remote_send(output_grads[0], op.rank_from)]

    return backward, [True]


@get_op_has_grad_fn.register(RemoteRecv)
def _(op: RemoteRecv):
    def has_grad(opnode, reached):
        ret = False
        for v in opnode.outputs:
            if v() in reached:
                ret = True
                break
        get_client().set_is_grad(op.key, ret)
        return ret

    return has_grad


def collective_comm(inp, mode, group, device):
    """Helper function for applying collective communication functions"""
    # no group means this process does not participate; pass the input through
    if group is None:
        return inp
    assert isinstance(group, Group)
    op = CollectiveComm()
    op.key = group.key
    op.nr_devices = group.size
    op.rank = group.rank
    op.is_root = op.rank == 0
    op.local_grad = False
    op.addr, op.port = get_mm_server_addr()
    op.mode = mode
    op.dtype = inp.dtype
    op.backend = get_backend()
    op.comp_node = device
    return apply(op, inp)[0]


def reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create reduce_sum operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.REDUCE_SUM
    return collective_comm(inp, mode, group, device)


def broadcast(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create broadcast operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.BROADCAST
    return collective_comm(inp, mode, group, device)


def all_gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create all_gather operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.ALL_GATHER
    return collective_comm(inp, mode, group, device)


def reduce_scatter_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create reduce_scatter_sum operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.REDUCE_SCATTER_SUM
    return collective_comm(inp, mode, group, device)


def all_reduce_sum(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create all_reduce_sum operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.ALL_REDUCE_SUM
    return collective_comm(inp, mode, group, device)


def all_reduce_max(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create all_reduce_max operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.ALL_REDUCE_MAX
    return collective_comm(inp, mode, group, device)


def all_reduce_min(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create all_reduce_min operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.ALL_REDUCE_MIN
    return collective_comm(inp, mode, group, device)


def gather(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create gather operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.GATHER
    return collective_comm(inp, mode, group, device)


def scatter(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create scatter operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.SCATTER
    return collective_comm(inp, mode, group, device)


def all_to_all(
    inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
    """Create all_to_all operator for collective communication

    :param inp: input tensor
    :param group: communication group
    :param device: execute placement
    """
    mode = CollectiveCommMode.ALL_TO_ALL
    return collective_comm(inp, mode, group, device)


def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
    """Send a Tensor to a remote process

    :param inp: tensor to send
    :param dest_rank: destination process rank
    """
    op = RemoteSend()
    op.key = "{}->{}".format(get_rank(), dest_rank)
    op.addr, op.port = get_mm_server_addr()
    op.rank_to = dest_rank
    return apply(op, inp)[0]


def remote_recv(
    src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
) -> Tensor:
    """Receive a Tensor from a remote process

    :param src_rank: source process rank
    :param shape: the shape of the tensor to receive
    :param dtype: the data type of the tensor to receive
    :param device: the device to place the received tensor,
        if None, use default device
    """
    key = "{}->{}".format(src_rank, get_rank())
    if device is None:
        device = get_default_device()

    # dummy input
    inp = tensor([0])
    tracer_set = get_client().check_remote_tracer(key)
    for grad_manager in get_grad_managers():
        if grad_manager.name in tracer_set:
            grad_manager.wrt(inp)

    op = RemoteRecv()
    op.key = key
    op.cn = device
    op.shape = shape
    op.dtype = dtype
    op.addr, op.port = get_mm_server_addr()
    op.rank_from = src_rank
    return apply(op, inp)[0]