
fix(mge/distributed): fix gather scatter reduce broadcast autodiff

GitOrigin-RevId: 1c2250a079
release-1.4
Megvii Engine Team, 4 years ago
commit 78d1e2b2ca
3 changed files with 214 additions and 62 deletions:

  1. imperative/python/megengine/distributed/functional.py  (+200, -62)
  2. imperative/python/megengine/distributed/launcher.py  (+3, -0)
  3. imperative/python/megengine/distributed/server.py  (+11, -0)
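The commit reworks the collectives in functional.py as autodiff-aware Function subclasses and adds a small rendezvous helper to the key-value server, so that shape, dtype and "does the peer have a gradient" can be exchanged at backward time. A rough usage sketch of what this enables, assuming the MegEngine 1.4 public API (dist.launcher, GradManager) on two devices; the data and the backward-call pattern on the non-root rank are illustrative and not taken from this commit:

import numpy as np

import megengine as mge
import megengine.distributed as dist
from megengine.autodiff import GradManager
from megengine.distributed.functional import reduce_sum


@dist.launcher(n_gpus=2)
def worker():
    rank = dist.get_rank()
    # every rank contributes a different tensor
    x = mge.tensor(np.full((2, 3), rank + 1, dtype="float32"))

    gm = GradManager().attach([x])
    with gm:
        y = reduce_sum(x)          # rank 0 gets the sum, other ranks get None
        if rank == 0:
            gm.backward(y.sum())   # backward broadcasts the gradient to every rank
        else:
            gm.backward()          # join the backward-time communication

    # after the fix, every rank ends up with dL/dx == 1
    print(rank, x.grad.numpy())


worker()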

imperative/python/megengine/distributed/functional.py  (+200, -62)
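In the hunks below, each collective's backward issues the adjoint collective: the gradient of REDUCE_SUM is a BROADCAST, the gradient of GATHER is a SCATTER, and vice versa. A single-process NumPy sketch of that adjoint relationship, with a plain list standing in for the two ranks (all names here are illustrative):

import numpy as np

# "two ranks" held in a plain list
xs = [np.arange(6.0).reshape(2, 3), np.ones((2, 3))]

# forward REDUCE_SUM: rank 0 ends up with the element-wise sum over ranks
y = xs[0] + xs[1]

# the adjoint of "sum over ranks" copies the output gradient back to every
# rank unchanged, i.e. a BROADCAST, which is what _ReduceSum.backward issues
dy = np.ones_like(y)
dxs = [dy.copy() for _ in xs]
assert all(dx.shape == x.shape for dx, x in zip(dxs, xs))

# forward GATHER: rank 0 ends up with the per-rank pieces concatenated
g = np.concatenate(xs, axis=0)

# the adjoint of "concatenate along axis 0" splits the output gradient back
# into per-rank pieces, i.e. a SCATTER, which is what _Gather.backward issues
dg = np.ones_like(g)
dpieces = np.split(dg, len(xs), axis=0)
assert [p.shape for p in dpieces] == [x.shape for x in xs]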

@@ -8,9 +8,11 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 from typing import Optional, Tuple

+import numpy as np
+
 from ..core._imperative_rt.core2 import apply
-from ..core.autodiff.grad import _grad_manager_dict
-from ..core.ops.builtin import CollectiveComm, Copy, PyOpBase, RemoteRecv, RemoteSend
+from ..core.autodiff.grad import Function, _grad_manager_dict
+from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend
 from ..core.tensor.utils import isscalar, setscalar
 from ..device import get_default_device
 from ..tensor import Tensor
@@ -65,6 +67,77 @@ def collective_comm(inp, mode, group, device):
     return result


+def _save_output_for_autodiff(inp, out):
+    for g in _grad_manager_dict.values():
+        if g._is_attached_to(inp):
+            g._refkeeper.append(out)
+
+
+def _bcast_has_grad(group, grad):
+    if group.rank == 0:
+        has_grad = grad is not None
+        get_client().bcast_val(has_grad, group.key, group.size)
+    else:
+        has_grad = get_client().bcast_val(None, group.key, group.size)
+    return has_grad
+
+
+def _bcast_shape_dtype(group, inp):
+    if group.rank == 0:
+        # FIXME in some cases, shape is not available(output of condtake)
+        shape = inp._tuple_shape
+        dtype = np.dtype(inp.dtype).name
+        get_client().bcast_val({"shape": shape, "dtype": dtype}, group.key, group.size)
+    else:
+        val = get_client().bcast_val(None, group.key, group.size)
+        shape = val["shape"]
+        dtype = val["dtype"]
+
+    return shape, dtype
+
+
+def _bcast_tracer_state(group, inp):
+    if group.rank == 0:
+        tracer_keys = []
+        for n, g in _grad_manager_dict.items():
+            if g._is_attached_to(inp):
+                tracer_keys.append(n)
+        get_client().bcast_val(tracer_keys, group.key, group.size)
+    else:
+        tracer_keys = get_client().bcast_val(None, group.key, group.size)
+        for n in tracer_keys:
+            g = _grad_manager_dict.get(n)
+            if g is not None:
+                g.wrt(inp)
+                g._refkeeper.append(inp)
+
+
+def _dummy_input(shape, dtype, device=""):
+    if device == "":
+        device = get_default_device()
+    inp = Tensor(0, dtype=dtype, device=device)
+    if len(shape) > 0:
+        inp = inp._broadcast(shape)
+    return inp
+
+
+class _ReduceSum(Function):
+    def __init__(self, group=WORLD, device=""):
+        self.group = group
+        self.out_device = device
+
+    def forward(self, data):
+        self.in_device = str(data.device)
+        return collective_comm(
+            data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device
+        )
+
+    def backward(self, grad):
+        has_grad = _bcast_has_grad(self.group, grad)
+        if has_grad:
+            return broadcast(grad, self.group, self.in_device)
+
+
 def reduce_sum(
     inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
 ) -> Tensor:
@@ -75,8 +148,30 @@ def reduce_sum(
     :param group: communication group.
     :param device: execution device.
     """
-    mode = CollectiveComm.Mode.REDUCE_SUM
-    return collective_comm(inp, mode, group, device)
+    op = _ReduceSum(group, device)
+    (out,) = apply(op, inp)
+
+    if group.rank == 0:
+        return out
+    else:
+        _save_output_for_autodiff(inp, out)
+
+
+class _Broadcast(Function):
+    def __init__(self, group=WORLD, device=""):
+        self.group = group
+        self.out_device = device
+
+    def forward(self, data):
+        self.in_device = str(data.device)
+        return collective_comm(
+            data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device
+        )
+
+    def backward(self, grad):
+        # TODO backward with a part of grad
+        if grad is not None:
+            return reduce_sum(grad, self.group, self.in_device)


 def broadcast(
@@ -89,8 +184,16 @@ def broadcast(
     :param group: communication group.
     :param device: execution device.
     """
-    mode = CollectiveComm.Mode.BROADCAST
-    return collective_comm(inp, mode, group, device)
+    shape, dtype = _bcast_shape_dtype(group, inp)
+    if group.rank != 0:
+        # dummy input to infer shape
+        inp = _dummy_input(shape, dtype, device)
+
+    _bcast_tracer_state(group, inp)
+
+    op = _Broadcast(group, device)
+    (out,) = apply(op, inp)
+    return out


 def all_gather(
@@ -163,6 +266,23 @@ def all_reduce_min(
     return collective_comm(inp, mode, group, device)


+class _Gather(Function):
+    def __init__(self, group=WORLD, device=""):
+        self.group = group
+        self.out_device = device
+
+    def forward(self, data):
+        self.in_device = str(data.device)
+        return collective_comm(
+            data, CollectiveComm.Mode.GATHER, self.group, self.out_device
+        )
+
+    def backward(self, grad):
+        has_grad = _bcast_has_grad(self.group, grad)
+        if has_grad:
+            return scatter(grad, self.group, self.in_device)
+
+
 def gather(
     inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
 ) -> Tensor:
@@ -173,8 +293,31 @@ def gather(
     :param group: communication group.
     :param device: execution device.
     """
-    mode = CollectiveComm.Mode.GATHER
-    return collective_comm(inp, mode, group, device)
+
+    op = _Gather(group, device)
+    (out,) = apply(op, inp)
+
+    if group.rank == 0:
+        return out
+    else:
+        _save_output_for_autodiff(inp, out)
+
+
+class _Scatter(Function):
+    def __init__(self, group=WORLD, device=""):
+        self.group = group
+        self.out_device = device
+
+    def forward(self, data):
+        self.in_device = str(data.device)
+        return collective_comm(
+            data, CollectiveComm.Mode.SCATTER, self.group, self.out_device
+        )
+
+    def backward(self, grad):
+        # TODO backward with a part of grad
+        if grad is not None:
+            return gather(grad, self.group, self.in_device)


 def scatter(
@@ -187,8 +330,16 @@ def scatter(
     :param group: communication group.
     :param device: execution device.
     """
-    mode = CollectiveComm.Mode.SCATTER
-    return collective_comm(inp, mode, group, device)
+    shape, dtype = _bcast_shape_dtype(group, inp)
+    if group.rank != 0:
+        # dummy input to infer shape
+        inp = _dummy_input(shape, dtype, device)
+
+    _bcast_tracer_state(group, inp)
+
+    op = _Scatter(group, device)
+    (out,) = apply(op, inp)
+    return out


 def all_to_all(
@@ -205,44 +356,46 @@ def all_to_all(
     return collective_comm(inp, mode, group, device)


-class _RemoteSend(PyOpBase):
+class _SendRecvGroup:
+    def __init__(self, rank_from, rank_to):
+        self.key = "{}->{}".format(rank_from, rank_to)
+        self.rank_from = rank_from
+        self.rank_to = rank_to
+        self.size = 2
+
+    @property
+    def rank(self):
+        if get_rank() == self.rank_from:
+            return 0
+        else:
+            return 1
+
+
+class _RemoteSend(Function):
     def __init__(self, op: RemoteSend):
         self.op = op

-    def _default_rule(self, data):
-        return apply(self.op, data)
-
-    def _grad_rule(self, data):
-        self.dtype = data.dtype
-        self.shape = data.shape
-        self.device = data.device
-        (self.dummy,) = self._default_rule(data)
-        return self.dummy, self.backward
+    def forward(self, data):
+        self.device = str(data.device)
+        (self.dummy,) = apply(self.op, data)
+        return self.dummy

     def backward(self, grad):
         assert grad is None
-        if get_client().check_is_grad(self.op.key):
-            return remote_recv(
-                self.op.rank_to,
-                self.shape,
-                self.dtype,
-                device=str(self.device),
-                inp=self.dummy,
-            )
+        has_grad = get_client().bcast_val(None, self.op.key, 2)
+        if has_grad:
+            return remote_recv(self.op.rank_to, device=self.device, inp=self.dummy,)


-class _RemoteRecv(PyOpBase):
+class _RemoteRecv(Function):
     def __init__(self, op: RemoteRecv):
         self.op = op

-    def _default_rule(self, dummy):
+    def forward(self, dummy):
         return apply(self.op, dummy)

-    def _grad_rule(self, dummy):
-        return self._default_rule(dummy), self.backward
-
     def backward(self, grad):
-        get_client().set_is_grad(self.op.key, grad is not None)
+        get_client().bcast_val(grad is not None, self.op.key, 2)
         if grad is not None:
             remote_send(grad, self.op.rank_from)

@@ -254,53 +407,38 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor:
     :param inp: tensor to send.
     :param dest_rank: destination process rank.
     """
-    key = "{}->{}".format(get_rank(), dest_rank)
-    grad_keys = {}
-    for n, g in _grad_manager_dict.items():
-        if g._is_attached_to(inp):
-            grad_keys[n] = g
-    get_client().set_remote_tracer(key, grad_keys)
+    group = _SendRecvGroup(get_rank(), dest_rank)
+    _bcast_shape_dtype(group, inp)
+
+    _bcast_tracer_state(group, inp)
+
     op = RemoteSend()
-    op.key = key
+    op.key = group.key
     op.addr, op.port = get_mm_server_addr()
     op.rank_to = dest_rank
     op.backend = get_backend()
-    (dummy,) = apply(_RemoteSend(op), inp)
+    (out,) = apply(_RemoteSend(op), inp)

-    for g in grad_keys.values():
-        g._refkeeper.append(dummy)
+    _save_output_for_autodiff(inp, out)


-def remote_recv(
-    src_rank: int,
-    shape: Tuple[int],
-    dtype: type,
-    device: Optional[str] = None,
-    inp=None,
-) -> Tensor:
+def remote_recv(src_rank: int, device: Optional[str] = None, inp=None,) -> Tensor:
     """
     Receive a Tensor from a remote process.

     :param src_rank: source process rank.
-    :param shape: the shape of the tensor to receive.
-    :param dtype: the data type of the tensor to receive.
     :param device: the device to place the received tensor.
     :param inp: dummy input to determine recved tensor type
     """
-    key = "{}->{}".format(src_rank, get_rank())
+    group = _SendRecvGroup(src_rank, get_rank())
+    shape, dtype = _bcast_shape_dtype(group, None)

     if device is None:
         device = get_default_device()
     # dummy input
     if inp is None:
-        inp = Tensor([0], device=device)
-        tracer_set = get_client().check_remote_tracer(key)
-        for n in tracer_set:
-            g = _grad_manager_dict.get(n)
-            if g is not None:
-                g.wrt(inp)
-                g._refkeeper.append(inp)
+        inp = Tensor(0, device=device)
+    _bcast_tracer_state(group, inp)

     _isscalar = False
     if len(shape) == 0:
@@ -308,7 +446,7 @@ def remote_recv(
         _isscalar = True

     op = RemoteRecv()
-    op.key = key
+    op.key = group.key
     op.cn = device
     op.shape = shape
     op.dtype = dtype
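With _RemoteSend and _RemoteRecv now expressed as Function subclasses, remote_recv no longer takes shape and dtype; they travel through the _SendRecvGroup key via bcast_val, and gradients flow back over the same channel. A pipeline-style sketch of the intended use, assuming the MegEngine 1.4 API (dist.launcher, GradManager, module.Linear) with two processes; the model and data are illustrative:

import numpy as np

import megengine as mge
import megengine.distributed as dist
import megengine.module as M
from megengine.autodiff import GradManager
from megengine.distributed.functional import remote_recv, remote_send


@dist.launcher(n_gpus=2)
def worker():
    rank = dist.get_rank()
    m = M.Linear(4, 4)                # one pipeline stage per rank
    gm = GradManager().attach(m.parameters())

    with gm:
        if rank == 0:
            x = mge.tensor(np.random.randn(2, 4).astype("float32"))
        else:
            x = remote_recv(0)        # shape/dtype arrive via bcast_val, not arguments
        y = m(x)
        if rank == 0:
            remote_send(y, dest_rank=1)
            gm.backward()             # the gradient for y is received from rank 1
        else:
            gm.backward(y.sum())

    print("rank", rank, "has grads:", all(p.grad is not None for p in m.parameters()))


worker()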


imperative/python/megengine/distributed/launcher.py  (+3, -0)

@@ -8,6 +8,7 @@
 # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 import functools
 import multiprocessing as mp
+import os
 import queue

 from ..core._imperative_rt.core2 import sync
@@ -43,6 +44,8 @@ def _run_wrapped(
         device=dev,
         device_type=device_type,
     )
+    # set NCCL_LAUNCH_MODE to avoid deadlock
+    os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
     if is_multimachine:
         group_barrier()
     ret = func(*args, **kwargs)
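The launcher change only exports NCCL_LAUNCH_MODE=PARALLEL inside each worker before the process group is set up, as the in-diff comment says, to avoid an NCCL launch-mode deadlock. If workers are spawned by hand instead of through dist.launcher, the variable can be set the same way; a minimal sketch (the worker body is a placeholder):

import multiprocessing as mp
import os


def worker(rank):
    # must be set before the first NCCL communicator is created in this process
    os.environ["NCCL_LAUNCH_MODE"] = "PARALLEL"
    # ... init_process_group(...) and the training code would go here ...
    print("worker", rank, "NCCL_LAUNCH_MODE =", os.environ["NCCL_LAUNCH_MODE"])


if __name__ == "__main__":
    procs = [mp.Process(target=worker, args=(r,)) for r in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()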


imperative/python/megengine/distributed/server.py  (+11, -0)

@@ -253,6 +253,17 @@ class Client:
         """Get user defined key-value pairs across processes."""
         return self.proxy.user_get(key)

+    def bcast_val(self, val, key, size):
+        if val is not None:
+            self.user_set(key + "_sync", val)
+            self.group_barrier(key, size)
+            self.group_barrier(key, size)
+        else:
+            self.group_barrier(key, size)
+            val = self.user_get(key + "_sync")
+            self.group_barrier(key, size)
+        return val
+

 def main(port=0, verbose=True):
     mm_server_port = create_mm_server("0.0.0.0", 0)
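Client.bcast_val is the rendezvous primitive the new backward hooks rely on: the rank that owns the value publishes it under key + "_sync", and two barriers make sure every peer has read it before the key can be reused. The same double-barrier pattern in a self-contained, threads-only sketch, where a dict and a threading.Barrier stand in for the key-value server and group_barrier (names are illustrative):

import threading

_store = {}                      # stands in for the server's user_set / user_get
_barrier = threading.Barrier(2)  # stands in for group_barrier(key, size) with size == 2


def bcast_val(val, key):
    if val is not None:
        _store[key + "_sync"] = val  # publish the value
        _barrier.wait()              # barrier 1: the value is now visible to peers
        _barrier.wait()              # barrier 2: every peer has read it
    else:
        _barrier.wait()              # barrier 1: wait for the publisher
        val = _store[key + "_sync"]  # read the published value
        _barrier.wait()              # barrier 2: the key may now be reused
    return val


def rank0():
    bcast_val({"shape": (2, 3), "dtype": "float32"}, "grp")


def rank1():
    print("rank 1 got", bcast_val(None, "grp"))


threads = [threading.Thread(target=rank0), threading.Thread(target=rank1)]
for t in threads:
    t.start()
for t in threads:
    t.join()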

