diff --git a/imperative/python/megengine/distributed/functional.py b/imperative/python/megengine/distributed/functional.py index 453d3ba5..d832ae86 100644 --- a/imperative/python/megengine/distributed/functional.py +++ b/imperative/python/megengine/distributed/functional.py @@ -196,6 +196,13 @@ def broadcast( return out +def _bcast_param( + inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" +) -> Tensor: + mode = CollectiveComm.Mode.BROADCAST + return collective_comm(inp, mode, group, device) + + def all_gather( inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" ) -> Tensor: diff --git a/imperative/python/megengine/distributed/helper.py b/imperative/python/megengine/distributed/helper.py index 0a67f2dd..c0743958 100644 --- a/imperative/python/megengine/distributed/helper.py +++ b/imperative/python/megengine/distributed/helper.py @@ -22,7 +22,7 @@ from ..core.ops.builtin import ParamPackConcat, ParamPackSplit from ..functional.tensor import copy from ..tensor import Tensor from ..utils.future import Future -from .functional import all_reduce_sum, broadcast +from .functional import _bcast_param, all_reduce_sum, broadcast from .group import WORLD, Group, group_barrier, is_distributed @@ -186,7 +186,7 @@ def bcast_list_(inps: list, group: Group = WORLD): :param group: communication group. """ for inp in inps: - inp._reset(broadcast(inp, group)) + inp._reset(_bcast_param(inp, group)) class AllreduceCallback: