Browse Source

perf(dist): add fastpath for bcast params

GitOrigin-RevId: aa40b3cd72
release-1.4
Megvii Engine Team 4 years ago
parent
commit
47138c06cf
2 changed files with 9 additions and 2 deletions
  1. +7
    -0
      imperative/python/megengine/distributed/functional.py
  2. +2
    -2
      imperative/python/megengine/distributed/helper.py

+ 7
- 0
imperative/python/megengine/distributed/functional.py View File

@@ -196,6 +196,13 @@ def broadcast(
return out


def _bcast_param(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:
mode = CollectiveComm.Mode.BROADCAST
return collective_comm(inp, mode, group, device)


def all_gather(
inp: Tensor, group: Optional[Group] = WORLD, device: Optional[str] = ""
) -> Tensor:


+ 2
- 2
imperative/python/megengine/distributed/helper.py View File

@@ -22,7 +22,7 @@ from ..core.ops.builtin import ParamPackConcat, ParamPackSplit
from ..functional.tensor import copy
from ..tensor import Tensor
from ..utils.future import Future
from .functional import all_reduce_sum, broadcast
from .functional import _bcast_param, all_reduce_sum, broadcast
from .group import WORLD, Group, group_barrier, is_distributed


@@ -186,7 +186,7 @@ def bcast_list_(inps: list, group: Group = WORLD):
:param group: communication group.
"""
for inp in inps:
inp._reset(broadcast(inp, group))
inp._reset(_bcast_param(inp, group))


class AllreduceCallback:


Loading…
Cancel
Save