diff --git a/python_module/megengine/distributed/functional.py b/python_module/megengine/distributed/functional.py index b0e7cf0b..dd353bf5 100644 --- a/python_module/megengine/distributed/functional.py +++ b/python_module/megengine/distributed/functional.py @@ -112,26 +112,20 @@ def broadcast( rank = get_rank() if rank == root: - return _collective_comm( - tensor, - key, - CollParam.Mode.BROADCAST, - nr_ranks, - rank, - root, - device=tensor.device, - ) + inp = tensor else: - return _collective_comm( - get_default_graph(), - key, - CollParam.Mode.BROADCAST, - nr_ranks, - rank, - root, - dtype=tensor._symvar.dtype, - device=tensor.device, - ) + inp = tensor._symvar.owner_graph + + return _collective_comm( + inp, + key, + CollParam.Mode.BROADCAST, + nr_ranks, + rank, + root, + dtype=tensor.dtype, + device=tensor.device, + ) def all_gather(