Browse Source

perf(dist): add fastpath for bcast params

GitOrigin-RevId: aa40b3cd72
release-1.4
Megvii Engine Team 4 years ago
parent
commit
47138c06cf
2 changed files with 9 additions and 2 deletions
  1. +7
    -0
      imperative/python/megengine/distributed/functional.py
  2. +2
    -2
      imperative/python/megengine/distributed/helper.py

+ 7
- 0
imperative/python/megengine/distributed/functional.py View File

@@ -196,6 +196,13 @@ def broadcast(
return out return out




def _bcast_param(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
mode = CollectiveComm.Mode.BROADCAST
return collective_comm(inp, mode, group, device)


def all_gather( def all_gather(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = "" inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor: ) -> Tensor:


+ 2
- 2
imperative/python/megengine/distributed/helper.py View File

@@ -22,7 +22,7 @@ from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
from ..functional.tensor import copy from ..functional.tensor import copy
from ..tensor import Tensor from ..tensor import Tensor
from ..utils.future import Future from ..utils.future import Future
from .functional import all_reduce_sum, broadcast
from .functional import _bcast_param, all_reduce_sum, broadcast
from .group import WORLD, Group, group_barrier, is_distributed from .group import WORLD, Group, group_barrier, is_distributed




@@ -186,7 +186,7 @@ def bcast_list_(inps: list, group: Group = WORLD):
:param group: communication group. :param group: communication group.
""" """
for inp in inps: for inp in inps:
inp._reset(broadcast(inp, group))
inp._reset(_bcast_param(inp, group))




class AllreduceCallback: class AllreduceCallback:


Loading…
Cancel
Save