|
@@ -22,7 +22,7 @@ from ..core.ops.builtin import ParamPackConcat, ParamPackSplit |
|
|
from ..functional.tensor import copy |
|
|
from ..functional.tensor import copy |
|
|
from ..tensor import Tensor |
|
|
from ..tensor import Tensor |
|
|
from ..utils.future import Future |
|
|
from ..utils.future import Future |
|
|
from .functional import all_reduce_sum, broadcast |
|
|
|
|
|
|
|
|
from .functional import _bcast_param, all_reduce_sum, broadcast |
|
|
from .group import WORLD, Group, group_barrier, is_distributed |
|
|
from .group import WORLD, Group, group_barrier, is_distributed |
|
|
|
|
|
|
|
|
|
|
|
|
|
@@ -186,7 +186,7 @@ def bcast_list_(inps: list, group: Group = WORLD): |
|
|
:param group: communication group. |
|
|
:param group: communication group. |
|
|
""" |
|
|
""" |
|
|
for inp in inps: |
|
|
for inp in inps: |
|
|
inp._reset(broadcast(inp, group)) |
|
|
|
|
|
|
|
|
inp._reset(_bcast_param(inp, group)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AllreduceCallback: |
|
|
class AllreduceCallback: |
|
|