
fix(mge/imperative): move functional/distributed.py to distributed/functional.py

GitOrigin-RevId: 30cf2f514b
tags/v1.0.0-rc1
Megvii Engine Team, 4 years ago
parent
commit 3f2eac2fe1
2 changed files with 310 additions and 294 deletions
  1. +295  -0    imperative/python/megengine/distributed/functional.py
  2. +15   -294  imperative/python/megengine/functional/distributed.py
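
After this commit, functional/distributed.py is reduced to a thin shim that re-imports the same names from the new module, so existing callers keep working. A minimal sketch of the two equivalent import paths (the function chosen is arbitrary):

    # new canonical location introduced by this commit
    from megengine.distributed.functional import all_reduce_sum

    # old path still resolves through the re-import shim in functional/distributed.py
    from megengine.functional.distributed import all_reduce_sum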

+295  -0    imperative/python/megengine/distributed/functional.py

@@ -0,0 +1,295 @@
# -*- coding: utf-8 -*-
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple

from ..core._imperative_rt.ops import CollectiveCommMode
from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn
from ..core.autodiff.grad import (
Tracer,
check_backward_allow_noinput,
get_grad_managers,
get_op_has_grad_fn,
tracer_apply,
)
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.core import apply
from ..core.tensor.tensor import Tensor, tensor_apply
from ..tensor import tensor
from ..device import get_default_device
from .group import WORLD, Group, get_backend, get_client, get_mm_server_addr, get_rank

__all__ = [
"reduce_sum",
"broadcast",
"all_gather",
"reduce_scatter_sum",
"all_reduce_sum",
"all_reduce_max",
"all_reduce_min",
"gather",
"scatter",
"all_to_all",
"remote_send",
"remote_recv",
]


@apply.add
def _(op: RemoteSend, *args: Tensor):
ret = tensor_apply(op, *args)

# set extra information
tracer_set = dict()
for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))):
tracer_set[k.name] = True

# check tracer_set in remote_recv
get_client().set_remote_tracer(op.key, tracer_set)
return ret


@builtin_op_get_backward_fn.register(RemoteSend)
def _(op: RemoteSend, inputs, outputs, input_requires_grad):
def backward(*args):
return [
remote_recv(
op.rank_to, inputs[0].shape, inputs[0].dtype, str(inputs[0].device)
)
]

return backward, [True]


@get_op_has_grad_fn.register(RemoteSend)
def _(op: RemoteSend):
def has_grad(opnode, reached):
return get_client().check_is_grad(op.key)

return has_grad


@check_backward_allow_noinput.register(RemoteSend)
def _(op: RemoteSend):
return True


@builtin_op_get_backward_fn.register(RemoteRecv)
def _(op: RemoteRecv, inputs, outputs, input_requires_grad):
def backward(*output_grads):
return [remote_send(output_grads[0], op.rank_from)]

return backward, [True]


@get_op_has_grad_fn.register(RemoteRecv)
def _(op: RemoteRecv):
def has_grad(opnode, reached):
ret = False
for v in opnode.outputs:
if v() in reached:
ret = True
break
get_client().set_is_grad(op.key, ret)
return ret

return has_grad


def collective_comm(inp, mode, group, device):
"""Helper function for applying collective communication functions"""
if group is None:
return inp
assert isinstance(group, Group)
op = CollectiveComm()
op.key = group.key
op.nr_devices = group.size
op.rank = group.rank
op.is_root = op.rank == 0
op.local_grad = False
op.addr, op.port = get_mm_server_addr()
op.mode = mode
op.dtype = inp.dtype
op.backend = get_backend()
op.comp_node = device
return apply(op, inp)[0]


def reduce_sum(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create reduce_sum operator for collective communication

:param inp: input tensor
:param group: communication group
:param device: execute placement
"""
mode = CollectiveCommMode.REDUCE_SUM
return collective_comm(inp, mode, group, device)


def broadcast(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create broadcast operator for collective communication

:param inp: input tensor
:param group: communication group
:param device: execute placement
"""
mode = CollectiveCommMode.BROADCAST
return collective_comm(inp, mode, group, device)


def all_gather(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create all_gather operator for collective communication

:param inp: input tensor
:param group: communication group
:param device: execute placement
"""
mode = CollectiveCommMode.ALL_GATHER
return collective_comm(inp, mode, group, device)


def reduce_scatter_sum(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create reduce_scatter_sum operator for collective communication

:param inp: input tensor
:param group: communication group
:param device: execute placement
"""
mode = CollectiveCommMode.REDUCE_SCATTER_SUM
return collective_comm(inp, mode, group, device)


def all_reduce_sum(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create all_reduce_sum operator for collective communication

:param inp: input tensor
:param group: communication group
:param device: execute placement
"""
mode = CollectiveCommMode.ALL_REDUCE_SUM
return collective_comm(inp, mode, group, device)


def all_reduce_max(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create all_reduce_max operator for collective communication

:param inp: input tensor
:param group: communication group
:param device: execute placement
"""
mode = CollectiveCommMode.ALL_REDUCE_MAX
return collective_comm(inp, mode, group, device)


def all_reduce_min(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create all_reduce_min operator for collective communication

:param inp: input tensor
:param group: communication group
:param device: execute placement
"""
mode = CollectiveCommMode.ALL_REDUCE_MIN
return collective_comm(inp, mode, group, device)


def gather(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create gather operator for collective communication

:param inp: input tensor
:param group: communication group
:param device: execute placement
"""
mode = CollectiveCommMode.GATHER
return collective_comm(inp, mode, group, device)


def scatter(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create scatter operator for collective communication

:param inp: input tensor
:param group: communication group
:param device: execute placement
"""
mode = CollectiveCommMode.SCATTER
return collective_comm(inp, mode, group, device)


def all_to_all(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create all_to_all operator for collective communication

:param inp: input tensor
:param group: communication group
:param device: execute placement
"""
mode = CollectiveCommMode.ALL_TO_ALL
return collective_comm(inp, mode, group, device)


def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
"""Send a Tensor to a remote process

:param inp: tensor to send
:param dest_rank: destination process rank
"""
op = RemoteSend()
op.key = "{}->{}".format(get_rank(), dest_rank)
op.addr, op.port = get_mm_server_addr()
op.rank_to = dest_rank
return apply(op, inp)[0]


def remote_recv(
src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
) -> Tensor:
"""Receive a Tensor from a remote process

:param src_rank: source process rank
:param shape: the shape of the tensor to receive
:param dtype: the data type of the tensor to receive
:param device: the device to place the received tensor; if None, the default device is used
"""
key = "{}->{}".format(src_rank, get_rank())

if device is None:
device = get_default_device()
# dummy input
inp = tensor([0])
tracer_set = get_client().check_remote_tracer(key)
for grad_manager in get_grad_managers():
if grad_manager.name in tracer_set:
grad_manager.wrt(inp)

op = RemoteRecv()
op.key = key
op.cn = device
op.shape = shape
op.dtype = dtype
op.addr, op.port = get_mm_server_addr()
op.rank_from = src_rank

return apply(op, inp)[0]
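
To show how the relocated functions are meant to be called, here is a hedged usage sketch. It assumes a distributed group of at least two workers has already been initialized elsewhere via megengine.distributed (that setup is not part of this diff); everything else uses only names visible in the file above, plus numpy for the dummy data.

    import numpy as np

    from megengine import tensor
    from megengine.distributed.functional import all_reduce_sum, remote_recv, remote_send
    from megengine.distributed.group import get_rank

    # runs inside every worker of an already-initialized group
    x = tensor(np.ones((3,), dtype=np.float32))
    summed = all_reduce_sum(x)  # group defaults to WORLD, device to ""

    # point-to-point transfer from rank 0 to rank 1;
    # shape and dtype must be agreed on by both sides out of band
    if get_rank() == 0:
        remote_send(x, 1)
    else:
        received = remote_recv(0, (3,), np.float32)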

+15  -294  imperative/python/megengine/functional/distributed.py

@@ -6,298 +6,19 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from typing import Optional, Tuple

from ..core._imperative_rt.ops import CollectiveCommMode
from ..core.autodiff.builtin_op_utils import builtin_op_get_backward_fn
from ..core.autodiff.grad import (
Tracer,
check_backward_allow_noinput,
get_grad_managers,
get_op_has_grad_fn,
tracer_apply,
# pylint: disable=redefined-builtin
from ..distributed.functional import (
all_gather,
all_reduce_max,
all_reduce_min,
all_reduce_sum,
all_to_all,
broadcast,
collective_comm,
gather,
reduce_scatter_sum,
reduce_sum,
remote_recv,
remote_send,
scatter,
)
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
from ..core.tensor.core import apply
from ..core.tensor.tensor import Tensor
from ..device import get_default_device
from ..distributed.group import (
WORLD,
Group,
get_backend,
get_client,
get_mm_server_addr,
get_rank,
)
from ..tensor import tensor

__all__ = [
"reduce_sum",
"broadcast",
"all_gather",
"reduce_scatter_sum",
"all_reduce_sum",
"all_reduce_max",
"all_reduce_min",
"gather",
"scatter",
"all_to_all",
"remote_send",
"remote_recv",
]


@apply.register()
def _(op: RemoteSend, *args: Tensor):
ret = apply.super(op, *args)

# set extra information
tracer_set = dict()
for k in set().union(*(i._extra_data for i in args if isinstance(i, Tensor))):
tracer_set[k.name] = True

# check tracer_set in remote_recv
get_client().set_remote_tracer(op.key, tracer_set)
return ret


@builtin_op_get_backward_fn.register(RemoteSend)
def _(op: RemoteSend, inputs, outputs, input_requires_grad):
def backward(*args):
return [
remote_recv(
op.rank_to, inputs[0].shape, inputs[0].dtype, str(inputs[0].device)
)
]

return backward, [True]


@get_op_has_grad_fn.register(RemoteSend)
def _(op: RemoteSend):
def has_grad(opnode, reached):
return get_client().check_is_grad(op.key)

return has_grad


@check_backward_allow_noinput.register(RemoteSend)
def _(op: RemoteSend):
return True


@builtin_op_get_backward_fn.register(RemoteRecv)
def _(op: RemoteRecv, inputs, outputs, input_requires_grad):
def backward(*output_grads):
return [remote_send(output_grads[0], op.rank_from)]

return backward, [True]


@get_op_has_grad_fn.register(RemoteRecv)
def _(op: RemoteRecv):
def has_grad(opnode, reached):
ret = False
for v in opnode.outputs:
if v() in reached:
ret = True
break
get_client().set_is_grad(op.key, ret)
return ret

return has_grad


def collective_comm(inp, mode, group, device):
"""Helper function for applying collective communication functions"""
assert isinstance(group, Group)
if group is None:
return inp
op = CollectiveComm()
op.key = group.key
op.nr_devices = group.size
op.rank = group.rank
op.is_root = op.rank == 0
op.local_grad = False
op.addr, op.port = get_mm_server_addr()
op.mode = mode
op.dtype = inp.dtype
op.backend = get_backend()
op.comp_node = device
return apply(op, inp)[0]


def reduce_sum(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create reduce_sum operator for collective communication

:param inp: input tensor
:param group: communication group
:param device: execute placement
"""
mode = CollectiveCommMode.REDUCE_SUM
return collective_comm(inp, mode, group, device)


def broadcast(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create broadcast operator for collective communication

:param inp: input tensor
:param group: communication group
:param device: execute placement
"""
mode = CollectiveCommMode.BROADCAST
return collective_comm(inp, mode, group, device)


def all_gather(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create all_gather operator for collective communication

:param inp: input tensor
:param group: communication group
:param device: execute placement
"""
mode = CollectiveCommMode.ALL_GATHER
return collective_comm(inp, mode, group, device)


def reduce_scatter_sum(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create reduce_scatter_sum operator for collective communication

:param inp: input tensor
:param group: communication group
:param device: execute placement
"""
mode = CollectiveCommMode.REDUCE_SCATTER_SUM
return collective_comm(inp, mode, group, device)


def all_reduce_sum(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create all_reduce_sum operator for collective communication

:param inp: input tensor
:param group: communication group
:param device: execute placement
"""
mode = CollectiveCommMode.ALL_REDUCE_SUM
return collective_comm(inp, mode, group, device)


def all_reduce_max(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create all_reduce_max operator for collective communication

:param inp: input tensor
:param group: communication group
:param device: execute placement
"""
mode = CollectiveCommMode.ALL_REDUCE_MAX
return collective_comm(inp, mode, group, device)


def all_reduce_min(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create all_reduce_min operator for collective communication

:param inp: input tensor
:param group: communication group
:param device: execute placement
"""
mode = CollectiveCommMode.ALL_REDUCE_MIN
return collective_comm(inp, mode, group, device)


def gather(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create gather operator for collective communication

:param inp: input tensor
:param group: communication group
:param device: execute placement
"""
mode = CollectiveCommMode.GATHER
return collective_comm(inp, mode, group, device)


def scatter(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create scatter operator for collective communication

:param inp: input tensor
:param group: communication group
:param device: execute placement
"""
mode = CollectiveCommMode.SCATTER
return collective_comm(inp, mode, group, device)


def all_to_all(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
"""Create all_to_all operator for collective communication

:param inp: input tensor
:param group: communication group
:param device: execute placement
"""
mode = CollectiveCommMode.ALL_TO_ALL
return collective_comm(inp, mode, group, device)


def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
"""Send a Tensor to a remote process

:param inp: tensor to send
:param dest_rank: destination process rank
"""
op = RemoteSend()
op.key = "{}->{}".format(get_rank(), dest_rank)
op.addr, op.port = get_mm_server_addr()
op.rank_to = dest_rank
return apply(op, inp)[0]


def remote_recv(
src_rank: int, shape: Tuple[int], dtype: type, device: Optional[str] = None
) -> Tensor:
"""Receive a Tensor from a remote process

:param src_rank: source process rank
:param shape: the shape of the tensor to receive
:param dtype: the data type of the tensor to receive
:param device: the device to place the received tensor,
if None, use default device
"""
key = "{}->{}".format(src_rank, get_rank())
if device is None:
device = get_default_device()

# dummy input
inp = tensor([0])
tracer_set = get_client().check_remote_tracer(key)
for grad_manager in get_grad_managers():
if grad_manager.name in tracer_set:
grad_manager.wrt(inp)

op = RemoteRecv()
op.key = key
op.cn = device
op.shape = shape
op.dtype = dtype
op.addr, op.port = get_mm_server_addr()
op.rank_from = src_rank

return apply(op, inp)[0]
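
Since functional/distributed.py now only re-imports these functions, the two module paths should expose the very same objects; a quick sanity check, assuming a build containing this commit is installed:

    import megengine.distributed.functional as new_path
    import megengine.functional.distributed as old_path

    # the shim re-imports, so identity (not just equality) holds
    assert old_path.all_reduce_sum is new_path.all_reduce_sum
    assert old_path.remote_send is new_path.remote_send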
