|
|
@@ -15,12 +15,12 @@ from weakref import WeakSet |
|
|
|
import numpy as np |
|
|
|
|
|
|
|
from megengine.autodiff.grad_manager import GradManager, get_backwarding_grad_manager |
|
|
|
from megengine.device import get_default_device, get_device_count |
|
|
|
|
|
|
|
from ..core._imperative_rt.core2 import apply |
|
|
|
from ..core.ops.builtin import ParamPackConcat, ParamPackSplit |
|
|
|
from ..functional.tensor import copy |
|
|
|
from ..tensor import Tensor |
|
|
|
from ..utils.deprecation import deprecated_func |
|
|
|
from ..utils.future import Future |
|
|
|
from . import group as _group |
|
|
|
from .functional import _bcast_param, all_reduce_sum, broadcast |
|
|
@@ -193,6 +193,11 @@ def _check_device_initialized(device_type: str, rank: int): |
|
|
|
raise RuntimeError(errmsg) |
|
|
|
|
|
|
|
|
|
|
|
get_device_count_by_fork = deprecated_func( |
|
|
|
"1.5", "megengine.device", "get_device_count", False |
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
def bcast_list_(inps: list, group: Group = WORLD): |
|
|
|
""" |
|
|
|
Broadcast tensors between given group. |
|
|
|