|
|
@@ -112,26 +112,20 @@ def broadcast( |
|
|
|
rank = get_rank() |
|
|
|
|
|
|
|
if rank == root: |
|
|
|
return _collective_comm( |
|
|
|
tensor, |
|
|
|
key, |
|
|
|
CollParam.Mode.BROADCAST, |
|
|
|
nr_ranks, |
|
|
|
rank, |
|
|
|
root, |
|
|
|
device=tensor.device, |
|
|
|
) |
|
|
|
inp = tensor |
|
|
|
else: |
|
|
|
return _collective_comm( |
|
|
|
get_default_graph(), |
|
|
|
key, |
|
|
|
CollParam.Mode.BROADCAST, |
|
|
|
nr_ranks, |
|
|
|
rank, |
|
|
|
root, |
|
|
|
dtype=tensor._symvar.dtype, |
|
|
|
device=tensor.device, |
|
|
|
) |
|
|
|
inp = tensor._symvar.owner_graph |
|
|
|
|
|
|
|
return _collective_comm( |
|
|
|
inp, |
|
|
|
key, |
|
|
|
CollParam.Mode.BROADCAST, |
|
|
|
nr_ranks, |
|
|
|
rank, |
|
|
|
root, |
|
|
|
dtype=tensor.dtype, |
|
|
|
device=tensor.device, |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def all_gather( |
|
|
|