|
|
@@ -8,9 +8,11 @@ |
|
|
|
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
from typing import Optional, Tuple |
|
|
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from ..core._imperative_rt.core2 import apply |
|
|
|
from ..core.autodiff.grad import _grad_manager_dict |
|
|
|
from ..core.ops.builtin import CollectiveComm, Copy, PyOpBase, RemoteRecv, RemoteSend |
|
|
|
from ..core.autodiff.grad import Function, _grad_manager_dict |
|
|
|
from ..core.ops.builtin import CollectiveComm, Copy, RemoteRecv, RemoteSend |
|
|
|
from ..core.tensor.utils import isscalar, setscalar |
|
|
|
from ..device import get_default_device |
|
|
|
from ..tensor import Tensor |
|
|
@@ -65,6 +67,77 @@ def collective_comm(inp, mode, group, device): |
|
|
|
return result |
|
|
|
|
|
|
|
|
|
|
|
def _save_output_for_autodiff(inp, out): |
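    # keep ``out`` alive in every GradManager attached to ``inp`` so the
    # backward graph on non-root ranks is not released before backward runs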
|
|
|
for g in _grad_manager_dict.values(): |
|
|
|
if g._is_attached_to(inp): |
|
|
|
g._refkeeper.append(out) |
|
|
|
|
|
|
|
|
|
|
|
def _bcast_has_grad(group, grad): |
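    # rank 0 broadcasts whether it received a gradient; the other ranks read
    # the flag from the client so every rank agrees on running backward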
|
|
|
if group.rank == 0: |
|
|
|
has_grad = grad is not None |
|
|
|
get_client().bcast_val(has_grad, group.key, group.size) |
|
|
|
else: |
|
|
|
has_grad = get_client().bcast_val(None, group.key, group.size) |
|
|
|
return has_grad |
|
|
|
|
|
|
|
|
|
|
|
def _bcast_shape_dtype(group, inp): |
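    # rank 0 broadcasts the shape and dtype of ``inp`` so the other ranks can
    # build a matching dummy input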
|
|
|
if group.rank == 0: |
|
|
|
# FIXME: in some cases the shape is not available (e.g. the output of cond_take)
|
|
|
shape = inp._tuple_shape |
|
|
|
dtype = np.dtype(inp.dtype).name |
|
|
|
get_client().bcast_val({"shape": shape, "dtype": dtype}, group.key, group.size) |
|
|
|
else: |
|
|
|
val = get_client().bcast_val(None, group.key, group.size) |
|
|
|
shape = val["shape"] |
|
|
|
dtype = val["dtype"] |
|
|
|
|
|
|
|
return shape, dtype |
|
|
|
|
|
|
|
|
|
|
|
def _bcast_tracer_state(group, inp): |
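    # rank 0 broadcasts the keys of the GradManagers attached to ``inp``;
    # the other ranks attach the same managers to their local input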
|
|
|
if group.rank == 0: |
|
|
|
tracer_keys = [] |
|
|
|
for n, g in _grad_manager_dict.items(): |
|
|
|
if g._is_attached_to(inp): |
|
|
|
tracer_keys.append(n) |
|
|
|
get_client().bcast_val(tracer_keys, group.key, group.size) |
|
|
|
else: |
|
|
|
tracer_keys = get_client().bcast_val(None, group.key, group.size) |
|
|
|
for n in tracer_keys: |
|
|
|
g = _grad_manager_dict.get(n) |
|
|
|
if g is not None: |
|
|
|
g.wrt(inp) |
|
|
|
g._refkeeper.append(inp) |
|
|
|
|
|
|
|
|
|
|
|
def _dummy_input(shape, dtype, device=""): |
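    # build a zero-filled placeholder tensor with the given shape and dtype,
    # on the default device when no device is given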
|
|
|
if device == "": |
|
|
|
device = get_default_device() |
|
|
|
inp = Tensor(0, dtype=dtype, device=device) |
|
|
|
if len(shape) > 0: |
|
|
|
inp = inp._broadcast(shape) |
|
|
|
return inp |
|
|
|
|
|
|
|
|
|
|
|
class _ReduceSum(Function): |
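    # autodiff wrapper for REDUCE_SUM: forward reduces onto rank 0, backward
    # broadcasts the gradient from rank 0 back to every rank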
|
|
|
def __init__(self, group=WORLD, device=""): |
|
|
|
self.group = group |
|
|
|
self.out_device = device |
|
|
|
|
|
|
|
def forward(self, data): |
|
|
|
self.in_device = str(data.device) |
|
|
|
return collective_comm( |
|
|
|
data, CollectiveComm.Mode.REDUCE_SUM, self.group, self.out_device |
|
|
|
) |
|
|
|
|
|
|
|
def backward(self, grad): |
|
|
|
has_grad = _bcast_has_grad(self.group, grad) |
|
|
|
if has_grad: |
|
|
|
return broadcast(grad, self.group, self.in_device) |
|
|
|
|
|
|
|
|
|
|
|
def reduce_sum( |
|
|
|
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" |
|
|
|
) -> Tensor: |
|
|
@@ -75,8 +148,30 @@ def reduce_sum( |
|
|
|
:param group: communication group. |
|
|
|
:param device: execution device. |
|
|
|
""" |
|
|
|
mode = CollectiveComm.Mode.REDUCE_SUM |
|
|
|
return collective_comm(inp, mode, group, device) |
|
|
|
op = _ReduceSum(group, device) |
|
|
|
(out,) = apply(op, inp) |
|
|
|
|
|
|
|
if group.rank == 0: |
|
|
|
return out |
|
|
|
else: |
|
|
|
_save_output_for_autodiff(inp, out) |
|
|
|
|
|
|
|
|
|
|
|
class _Broadcast(Function): |
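    # autodiff wrapper for BROADCAST: forward copies the tensor from rank 0 to
    # every rank, backward reduce-sums the gradients back onto rank 0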
|
|
|
def __init__(self, group=WORLD, device=""): |
|
|
|
self.group = group |
|
|
|
self.out_device = device |
|
|
|
|
|
|
|
def forward(self, data): |
|
|
|
self.in_device = str(data.device) |
|
|
|
return collective_comm( |
|
|
|
data, CollectiveComm.Mode.BROADCAST, self.group, self.out_device |
|
|
|
) |
|
|
|
|
|
|
|
def backward(self, grad): |
|
|
|
# TODO: support backward when only part of the gradient is available
|
|
|
if grad is not None: |
|
|
|
return reduce_sum(grad, self.group, self.in_device) |
|
|
|
|
|
|
|
|
|
|
|
def broadcast( |
|
|
@@ -89,8 +184,16 @@ def broadcast( |
|
|
|
:param group: communication group. |
|
|
|
:param device: execution device. |
|
|
|
""" |
|
|
|
mode = CollectiveComm.Mode.BROADCAST |
|
|
|
return collective_comm(inp, mode, group, device) |
|
|
|
shape, dtype = _bcast_shape_dtype(group, inp) |
|
|
|
if group.rank != 0: |
|
|
|
# non-root ranks build a dummy input from the broadcast shape and dtype
|
|
|
inp = _dummy_input(shape, dtype, device) |
|
|
|
|
|
|
|
_bcast_tracer_state(group, inp) |
|
|
|
|
|
|
|
op = _Broadcast(group, device) |
|
|
|
(out,) = apply(op, inp) |
|
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
def all_gather( |
|
|
@@ -163,6 +266,23 @@ def all_reduce_min( |
|
|
|
return collective_comm(inp, mode, group, device) |
|
|
|
|
|
|
|
|
|
|
|
class _Gather(Function): |
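    # autodiff wrapper for GATHER: forward gathers the per-rank tensors onto
    # rank 0, backward scatters the gradient back to the source ranks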
|
|
|
def __init__(self, group=WORLD, device=""): |
|
|
|
self.group = group |
|
|
|
self.out_device = device |
|
|
|
|
|
|
|
def forward(self, data): |
|
|
|
self.in_device = str(data.device) |
|
|
|
return collective_comm( |
|
|
|
data, CollectiveComm.Mode.GATHER, self.group, self.out_device |
|
|
|
) |
|
|
|
|
|
|
|
def backward(self, grad): |
|
|
|
has_grad = _bcast_has_grad(self.group, grad) |
|
|
|
if has_grad: |
|
|
|
return scatter(grad, self.group, self.in_device) |
|
|
|
|
|
|
|
|
|
|
|
def gather( |
|
|
|
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" |
|
|
|
) -> Tensor: |
|
|
@@ -173,8 +293,31 @@ def gather( |
|
|
|
:param group: communication group. |
|
|
|
:param device: execution device. |
|
|
|
""" |
|
|
|
mode = CollectiveComm.Mode.GATHER |
|
|
|
return collective_comm(inp, mode, group, device) |
|
|
|
|
|
|
|
op = _Gather(group, device) |
|
|
|
(out,) = apply(op, inp) |
|
|
|
|
|
|
|
if group.rank == 0: |
|
|
|
return out |
|
|
|
else: |
|
|
|
_save_output_for_autodiff(inp, out) |
|
|
|
|
|
|
|
|
|
|
|
class _Scatter(Function): |
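    # autodiff wrapper for SCATTER: forward splits the tensor on rank 0 across
    # the ranks, backward gathers the gradients back onto rank 0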
|
|
|
def __init__(self, group=WORLD, device=""): |
|
|
|
self.group = group |
|
|
|
self.out_device = device |
|
|
|
|
|
|
|
def forward(self, data): |
|
|
|
self.in_device = str(data.device) |
|
|
|
return collective_comm( |
|
|
|
data, CollectiveComm.Mode.SCATTER, self.group, self.out_device |
|
|
|
) |
|
|
|
|
|
|
|
def backward(self, grad): |
|
|
|
# TODO: support backward when only part of the gradient is available
|
|
|
if grad is not None: |
|
|
|
return gather(grad, self.group, self.in_device) |
|
|
|
|
|
|
|
|
|
|
|
def scatter( |
|
|
@@ -187,8 +330,16 @@ def scatter( |
|
|
|
:param group: communication group. |
|
|
|
:param device: execution device. |
|
|
|
""" |
|
|
|
mode = CollectiveComm.Mode.SCATTER |
|
|
|
return collective_comm(inp, mode, group, device) |
|
|
|
shape, dtype = _bcast_shape_dtype(group, inp) |
|
|
|
if group.rank != 0: |
|
|
|
# non-root ranks build a dummy input from the broadcast shape and dtype
|
|
|
inp = _dummy_input(shape, dtype, device) |
|
|
|
|
|
|
|
_bcast_tracer_state(group, inp) |
|
|
|
|
|
|
|
op = _Scatter(group, device) |
|
|
|
(out,) = apply(op, inp) |
|
|
|
return out |
|
|
|
|
|
|
|
|
|
|
|
def all_to_all( |
|
|
@@ -205,44 +356,46 @@ def all_to_all( |
|
|
|
return collective_comm(inp, mode, group, device) |
|
|
|
|
|
|
|
|
|
|
|
class _RemoteSend(PyOpBase): |
|
|
|
class _SendRecvGroup: |
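    # minimal two-rank group describing one point-to-point transfer:
    # the sender is rank 0 of the pair, the receiver is rank 1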
|
|
|
def __init__(self, rank_from, rank_to): |
|
|
|
self.key = "{}->{}".format(rank_from, rank_to) |
|
|
|
self.rank_from = rank_from |
|
|
|
self.rank_to = rank_to |
|
|
|
self.size = 2 |
|
|
|
|
|
|
|
@property |
|
|
|
def rank(self): |
|
|
|
if get_rank() == self.rank_from: |
|
|
|
return 0 |
|
|
|
else: |
|
|
|
return 1 |
|
|
|
|
|
|
|
|
|
|
|
class _RemoteSend(Function): |
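    # autodiff wrapper around the RemoteSend op: forward sends ``data`` and keeps
    # the dummy output alive, backward receives the gradient from the peer rank
    # when the receiver reports that one exists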
|
|
|
def __init__(self, op: RemoteSend): |
|
|
|
self.op = op |
|
|
|
|
|
|
|
def _default_rule(self, data): |
|
|
|
return apply(self.op, data) |
|
|
|
|
|
|
|
def _grad_rule(self, data): |
|
|
|
self.dtype = data.dtype |
|
|
|
self.shape = data.shape |
|
|
|
self.device = data.device |
|
|
|
(self.dummy,) = self._default_rule(data) |
|
|
|
return self.dummy, self.backward |
|
|
|
def forward(self, data): |
|
|
|
self.device = str(data.device) |
|
|
|
(self.dummy,) = apply(self.op, data) |
|
|
|
return self.dummy |
|
|
|
|
|
|
|
def backward(self, grad): |
|
|
|
assert grad is None |
|
|
|
if get_client().check_is_grad(self.op.key): |
|
|
|
return remote_recv( |
|
|
|
self.op.rank_to, |
|
|
|
self.shape, |
|
|
|
self.dtype, |
|
|
|
device=str(self.device), |
|
|
|
inp=self.dummy, |
|
|
|
) |
|
|
|
has_grad = get_client().bcast_val(None, self.op.key, 2) |
|
|
|
if has_grad: |
|
|
|
return remote_recv(self.op.rank_to, device=self.device, inp=self.dummy,) |
|
|
|
|
|
|
|
|
|
|
|
class _RemoteRecv(PyOpBase): |
|
|
|
class _RemoteRecv(Function): |
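    # autodiff wrapper around the RemoteRecv op: backward tells the sender
    # whether a gradient exists and, if so, sends it back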
|
|
|
def __init__(self, op: RemoteRecv): |
|
|
|
self.op = op |
|
|
|
|
|
|
|
def _default_rule(self, dummy): |
|
|
|
def forward(self, dummy): |
|
|
|
return apply(self.op, dummy) |
|
|
|
|
|
|
|
def _grad_rule(self, dummy): |
|
|
|
return self._default_rule(dummy), self.backward |
|
|
|
|
|
|
|
def backward(self, grad): |
|
|
|
get_client().set_is_grad(self.op.key, grad is not None) |
|
|
|
get_client().bcast_val(grad is not None, self.op.key, 2) |
|
|
|
if grad is not None: |
|
|
|
remote_send(grad, self.op.rank_from) |
|
|
|
|
|
|
@@ -254,53 +407,38 @@ def remote_send(inp: Tensor, dest_rank: int) -> Tensor: |
|
|
|
:param inp: tensor to send. |
|
|
|
:param dest_rank: destination process rank. |
|
|
|
""" |
|
|
|
key = "{}->{}".format(get_rank(), dest_rank) |
|
|
|
grad_keys = {} |
|
|
|
for n, g in _grad_manager_dict.items(): |
|
|
|
if g._is_attached_to(inp): |
|
|
|
grad_keys[n] = g |
|
|
|
get_client().set_remote_tracer(key, grad_keys) |
|
|
|
group = _SendRecvGroup(get_rank(), dest_rank) |
|
|
|
_bcast_shape_dtype(group, inp) |
|
|
|
|
|
|
|
_bcast_tracer_state(group, inp) |
|
|
|
|
|
|
|
op = RemoteSend() |
|
|
|
op.key = key |
|
|
|
op.key = group.key |
|
|
|
op.addr, op.port = get_mm_server_addr() |
|
|
|
op.rank_to = dest_rank |
|
|
|
op.backend = get_backend() |
|
|
|
(dummy,) = apply(_RemoteSend(op), inp) |
|
|
|
(out,) = apply(_RemoteSend(op), inp) |
|
|
|
|
|
|
|
for g in grad_keys.values(): |
|
|
|
g._refkeeper.append(dummy) |
|
|
|
_save_output_for_autodiff(inp, out) |
|
|
|
|
|
|
|
|
|
|
|
def remote_recv( |
|
|
|
src_rank: int, |
|
|
|
shape: Tuple[int], |
|
|
|
dtype: type, |
|
|
|
device: Optional[str] = None, |
|
|
|
inp=None, |
|
|
|
) -> Tensor: |
|
|
|
def remote_recv(src_rank: int, device: Optional[str] = None, inp=None,) -> Tensor: |
|
|
|
""" |
|
|
|
Receive a Tensor from a remote process. |
|
|
|
|
|
|
|
:param src_rank: source process rank. |
|
|
|
:param shape: the shape of the tensor to receive. |
|
|
|
:param dtype: the data type of the tensor to receive. |
|
|
|
:param device: the device to place the received tensor. |
|
|
|
:param inp: dummy input to determine the received tensor type.
|
|
|
""" |
|
|
|
key = "{}->{}".format(src_rank, get_rank()) |
|
|
|
group = _SendRecvGroup(src_rank, get_rank()) |
|
|
|
shape, dtype = _bcast_shape_dtype(group, None) |
|
|
|
|
|
|
|
if device is None: |
|
|
|
device = get_default_device() |
|
|
|
# dummy input |
|
|
|
if inp is None: |
|
|
|
inp = Tensor([0], device=device) |
|
|
|
tracer_set = get_client().check_remote_tracer(key) |
|
|
|
for n in tracer_set: |
|
|
|
g = _grad_manager_dict.get(n) |
|
|
|
if g is not None: |
|
|
|
g.wrt(inp) |
|
|
|
g._refkeeper.append(inp) |
|
|
|
inp = Tensor(0, device=device) |
|
|
|
_bcast_tracer_state(group, inp) |
|
|
|
|
|
|
|
_isscalar = False |
|
|
|
if len(shape) == 0: |
|
|
@@ -308,7 +446,7 @@ def remote_recv( |
|
|
|
_isscalar = True |
|
|
|
|
|
|
|
op = RemoteRecv() |
|
|
|
op.key = key |
|
|
|
op.key = group.key |
|
|
|
op.cn = device |
|
|
|
op.shape = shape |
|
|
|
op.dtype = dtype |
|
|
|