Browse Source

fix(mge/distributed): fix input comp_graph of broadcast operator

GitOrigin-RevId: 039fd06a93
tags/v0.5.0
Megvii Engine Team 5 years ago
parent
commit
6972bfdeba
1 changed files with 13 additions and 19 deletions
  1. +13
    -19
      python_module/megengine/distributed/functional.py

+ 13
- 19
python_module/megengine/distributed/functional.py View File

@@ -112,26 +112,20 @@ def broadcast(
rank = get_rank() rank = get_rank()


if rank == root: if rank == root:
return _collective_comm(
tensor,
key,
CollParam.Mode.BROADCAST,
nr_ranks,
rank,
root,
device=tensor.device,
)
inp = tensor
else: else:
return _collective_comm(
get_default_graph(),
key,
CollParam.Mode.BROADCAST,
nr_ranks,
rank,
root,
dtype=tensor._symvar.dtype,
device=tensor.device,
)
inp = tensor._symvar.owner_graph

return _collective_comm(
inp,
key,
CollParam.Mode.BROADCAST,
nr_ranks,
rank,
root,
dtype=tensor.dtype,
device=tensor.device,
)




def all_gather( def all_gather(


Loading…
Cancel
Save