|
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import multiprocessing as mp
- import queue
- from time import sleep
-
- import pytest
-
- import megengine as mge
- import megengine._internal as mgb
- import megengine.distributed as dist
-
# Loopback address passed as the master IP to dist.init_process_group and
# expected back from dist.get_master_ip() by the tests in this file.
_LOCALHOST = "127.0.0.1"
-
-
def _assert_q_empty(q, timeout=1):
    """Assert that nothing arrives on ``q`` within ``timeout`` seconds.

    Args:
        q: Queue-like object whose ``get(timeout=...)`` raises
            ``queue.Empty`` when no item becomes available in time.
        timeout: Seconds to wait for an item before the queue is treated
            as empty. Defaults to 1, the wait used by the tests in this
            file.

    Raises:
        AssertionError: If an item was retrieved. The item has already been
            removed from the queue when the error is raised.

    Any exception other than ``queue.Empty`` raised by ``q.get`` propagates
    unchanged, so the real cause of a failure stays visible.
    """
    try:
        q.get(timeout=timeout)
    except queue.Empty:
        # Timing out is the success path: nothing was queued.
        return
    # Raised explicitly instead of via ``assert False`` so the check is not
    # compiled away when Python runs with -O.
    raise AssertionError("queue is not empty")
-
-
def _assert_q_val(q, val):
    """Take the next item from ``q`` and assert that it equals ``val``.

    The call to ``q.get()`` carries no timeout, so it waits until an item
    is available. The item is removed from the queue whether or not the
    comparison succeeds.
    """
    assert q.get() == val
-
-
def _init_process_group_wrapper(world_size, rank, dev, backend, q):
    """Initialise the process group, passing the master port through ``q``.

    Rank 0 calls ``dist.init_process_group`` with port 0 and then publishes
    the value reported by ``dist.get_master_port()`` on ``q``. Every other
    rank takes one port from ``q`` and initialises with it.

    Args:
        world_size: Number of ranks in the group.
        rank: Rank of the calling process; rank 0 acts as the master.
        dev: Device argument forwarded to ``dist.init_process_group``.
        backend: Backend name forwarded to ``dist.init_process_group``.
        q: Queue shared by all ranks, used here to hand over the port.
    """
    if rank == 0:
        dist.init_process_group(_LOCALHOST, 0, world_size, rank, dev, backend)
        master_port = dist.get_master_port()
        # Each non-zero rank removes exactly one item from the queue, so
        # publish one copy per peer. A single put would leave every rank
        # beyond the first peer blocked on q.get() when world_size > 2.
        for _ in range(world_size - 1):
            q.put(master_port)
    else:
        master_port = q.get()
        dist.init_process_group(
            _LOCALHOST, master_port, world_size, rank, dev, backend
        )
-
-
@pytest.mark.isolated_distributed
def test_create_mm_server():
    """Check that a port handed out by ``create_mm_server`` cannot be reused.

    The checks run in a child process. The test passes when that process
    exits with code 0 within the 10 second join timeout.
    """

    def worker():
        # Without CUDA the checks are skipped and the child exits cleanly.
        if mge.is_cuda_available():
            # A request with port 0 must report a positive port number.
            first_port = mgb.config.create_mm_server("0.0.0.0", 0)
            assert first_port > 0
            # A second request for that same port is expected to return -1.
            second_try = mgb.config.create_mm_server("0.0.0.0", first_port)
            assert second_try == -1

    child = mp.Process(target=worker)
    child.start()
    child.join(10)

    assert child.exitcode == 0
-
-
@pytest.mark.isolated_distributed
def test_init_process_group():
    """Initialise a two-rank group per backend and check the reported state.

    For each of the "nccl" and "ucx" backends, two child processes (ranks 0
    and 1) set up the group and compare what ``dist`` reports with the
    values they were started with. Both children must exit with code 0
    within their 10 second join timeouts.
    """
    world_size = 2

    def worker(rank, backend, q):
        # Nothing to verify without CUDA; the child then exits cleanly.
        if mge.is_cuda_available():
            # The rank is passed twice: as the rank and as the dev argument.
            _init_process_group_wrapper(world_size, rank, rank, backend, q)
            assert dist.is_distributed() == True
            assert dist.get_master_ip() == _LOCALHOST
            assert dist.get_master_port() > 0
            assert dist.get_world_size() == world_size
            assert dist.get_rank() == rank
            assert dist.get_backend() == backend

    def check(backend):
        # One fresh queue per backend carries the master port between ranks.
        port_queue = mp.Queue()
        procs = [
            mp.Process(target=worker, args=(rank, backend, port_queue))
            for rank in range(world_size)
        ]

        for proc in procs:
            proc.start()

        for proc in procs:
            proc.join(10)

        assert all(proc.exitcode == 0 for proc in procs)

    for backend_name in ("nccl", "ucx"):
        check(backend_name)
-
-
@pytest.mark.isolated_distributed
def test_group_barrier():
    """Check that rank 0 cannot pass ``dist.group_barrier`` ahead of rank 1.

    Two child processes join a two-rank "nccl" group. Rank 0 puts 0 on the
    shared queue only after its second barrier call. Rank 1 first verifies
    that the queue stays empty for the duration of ``_assert_q_empty``,
    then enters the second barrier, and only afterwards expects to read
    the 0. Both children must exit with code 0 within their 10 second join
    timeouts.
    """
    world_size = 2
    backend = "nccl"

    def worker(rank, q):
        if not mge.is_cuda_available():
            return
        _init_process_group_wrapper(world_size, rank, rank, backend, q)
        # First barrier: both ranks are past initialisation, so rank 1 has
        # already taken the master port off the queue.
        dist.group_barrier()
        if rank == 0:
            dist.group_barrier()
            q.put(0)  # only after the second barrier; observed by rank 1
        else:
            _assert_q_empty(q)  # rank 0 has not executed q.put(0) yet
            dist.group_barrier()
            _assert_q_val(q, 0)  # rank 0's q.put(0) has been executed

    Q = mp.Queue()
    p0 = mp.Process(target=worker, args=(0, Q))
    p1 = mp.Process(target=worker, args=(1, Q))

    p0.start()
    p1.start()

    p0.join(10)
    p1.join(10)

    assert p0.exitcode == 0 and p1.exitcode == 0
-
-
@pytest.mark.isolated_distributed
def test_synchronized():
    """Check the ordering that ``dist.synchronized`` imposes across ranks.

    Two child processes join a two-rank "nccl" group and share one queue.
    Rank 0 calls the synchronized ``func`` and then puts 2 directly. Rank 1
    expects to observe, in order: the 0 put by rank 0's ``func``, an empty
    queue, the 1 put by its own ``func`` call, and finally the 2 from
    rank 0. Both children must exit with code 0 within their 10 second
    join timeouts.
    """
    world_size = 2
    backend = "nccl"

    @dist.synchronized
    def func(rank, q):
        q.put(rank)

    def worker(rank, q):
        if not mge.is_cuda_available():
            return
        _init_process_group_wrapper(world_size, rank, rank, backend, q)
        dist.group_barrier()

        if rank != 0:
            # rank 0 has already run func, so its 0 is on the queue
            _assert_q_val(q, 0)
            # rank 0's plain q.put(2) has not been executed yet
            _assert_q_empty(q)
            func(1, q)
            # this rank's func ran before rank 0's q.put(2)
            _assert_q_val(q, 1)
            # rank 0's q.put(2) has now been executed
            _assert_q_val(q, 2)
            return

        func(0, q)  # puts 0 on the queue
        q.put(2)

    shared_queue = mp.Queue()
    procs = [
        mp.Process(target=worker, args=(rank, shared_queue))
        for rank in range(world_size)
    ]

    for proc in procs:
        proc.start()

    for proc in procs:
        proc.join(10)

    assert all(proc.exitcode == 0 for proc in procs)
|