diff --git a/imperative/python/megengine/distributed/group.py b/imperative/python/megengine/distributed/group.py index fb0f3f11..8fc0634d 100644 --- a/imperative/python/megengine/distributed/group.py +++ b/imperative/python/megengine/distributed/group.py @@ -104,7 +104,7 @@ class Group: WORLD = Group([]) _devices = {"gpu", "cuda", "rocm"} -_backends = {"nccl", "rccl", "ucx", "auto"} +_backends = {"nccl", "rccl", "shm", "auto"} def init_process_group( diff --git a/imperative/python/megengine/distributed/launcher.py b/imperative/python/megengine/distributed/launcher.py index f963bd61..531962ee 100644 --- a/imperative/python/megengine/distributed/launcher.py +++ b/imperative/python/megengine/distributed/launcher.py @@ -89,7 +89,7 @@ class launcher: master_ip="localhost", port=0, device_type="xpu", - backend="auto", + backend="nccl", ): self.func = func self.n_gpus = n_gpus if n_gpus is not None else get_device_count(device_type) diff --git a/imperative/python/test/unit/utils/test_network_node.py b/imperative/python/test/unit/utils/test_network_node.py index e6c71379..f6bcddac 100644 --- a/imperative/python/test/unit/utils/test_network_node.py +++ b/imperative/python/test/unit/utils/test_network_node.py @@ -14,6 +14,10 @@ from megengine.core._imperative_rt.core2 import apply from megengine.core._wrap import Device from megengine.core.ops import builtin from megengine.device import get_device_count, is_cuda_available +from megengine.functional.debug_param import ( + get_execution_strategy, + set_execution_strategy, +) from megengine.functional.external import tensorrt_runtime_opr from megengine.jit.tracing import trace from megengine.tensor import Tensor @@ -106,10 +110,13 @@ def test_matmul(): def fwd(data1, data2): return F.matmul(data1, data2) + old = get_execution_strategy() + set_execution_strategy("HEURISTIC_REPRODUCIBLE") data1 = Tensor(np.random.random((32, 64))) data2 = Tensor(np.random.random((64, 16))) result = fwd(data1, data2) check_pygraph_dump(fwd, [data1, data2], [result]) + set_execution_strategy(old) def test_batchmatmul():