|
- # -*- coding: utf-8 -*-
- # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
- #
- # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
- #
- # Unless required by applicable law or agreed to in writing,
- # software distributed under the License is distributed on an
- # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- import multiprocessing as mp
- import platform
- import queue
-
- import pytest
-
- import megengine as mge
- import megengine.distributed as dist
-
-
- def _assert_q_empty(q):
- try:
- res = q.get(timeout=1)
- except Exception as e:
- assert isinstance(e, queue.Empty)
- else:
- assert False, "queue is not empty"
-
-
- def _assert_q_val(q, val):
- ret = q.get()
- assert ret == val
-
-
@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
    platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
)
@pytest.mark.isolated_distributed
def test_init_process_group():
    """Check the ``dist`` getters after ``init_process_group``, for nccl and ucx.

    Each worker returns immediately, and therefore exits with code 0, when
    fewer than ``world_size`` GPUs are visible.
    """
    world_size = 2
    port = dist.get_free_ports(1)[0]
    # Bound to a name (otherwise unused), presumably so the server object is
    # not garbage-collected while the workers run.
    server = dist.Server(port)

    def worker(rank, backend):
        if mge.get_device_count("gpu") < world_size:
            return
        dist.init_process_group("localhost", port, world_size, rank, rank, backend)
        assert dist.is_distributed()
        assert dist.get_rank() == rank
        assert dist.get_world_size() == world_size
        assert dist.get_backend() == backend

        py_server_addr = dist.get_py_server_addr()
        assert py_server_addr[0] == "localhost"
        assert py_server_addr[1] == port

        mm_server_addr = dist.get_mm_server_addr()
        assert mm_server_addr[0] == "localhost"
        assert mm_server_addr[1] > 0

        assert isinstance(dist.get_client(), dist.Client)

    def check(backend):
        procs = []
        for rank in range(world_size):
            p = mp.Process(target=worker, args=(rank, backend))
            p.start()
            procs.append(p)

        for p in procs:
            p.join(20)
        # A child still alive after its join timeout is hung. Kill it so it
        # does not outlive the test. Its non-zero exit code then fails the
        # assertion below.
        for p in procs:
            if p.is_alive():
                p.terminate()
                p.join()
        assert [p.exitcode for p in procs] == [0] * world_size

    check("nccl")
    check("ucx")
-
-
@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
    platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
)
@pytest.mark.isolated_distributed
def test_new_group():
    """Form a sub-group of ranks [2, 0] out of 3 processes and check its attributes.

    Only the member ranks call ``dist.new_group``. Rank 1 initializes the
    process group and exits. Workers return immediately, with exit code 0,
    when fewer than ``world_size`` GPUs are visible.
    """
    world_size = 3
    ranks = [2, 0]
    port = dist.get_free_ports(1)[0]
    # Bound to a name (otherwise unused), presumably so the server object is
    # not garbage-collected while the workers run.
    server = dist.Server(port)

    def worker(rank):
        if mge.get_device_count("gpu") < world_size:
            return
        dist.init_process_group("localhost", port, world_size, rank, rank)
        if rank in ranks:
            group = dist.new_group(ranks)
            assert group.size == 2
            assert group.key == "2,0"
            # The group-local rank follows the order of ``ranks``, not global order.
            assert group.rank == ranks.index(rank)
            assert group.comp_node == "gpu{}:2".format(rank)

    procs = []
    for rank in range(world_size):
        p = mp.Process(target=worker, args=(rank,))
        p.start()
        procs.append(p)

    for p in procs:
        p.join(20)
    # A child still alive after its join timeout is hung. Kill it so it does
    # not outlive the test. Its non-zero exit code then fails the assertion
    # below.
    for p in procs:
        if p.is_alive():
            p.terminate()
            p.join()
    assert [p.exitcode for p in procs] == [0] * world_size
-
-
@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
    platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
)
@pytest.mark.isolated_distributed
def test_group_barrier():
    """Check that ``dist.group_barrier`` holds rank 0 until rank 1 reaches it.

    Rank 0 puts a value on the shared queue only after a barrier. Rank 1
    first observes the queue as empty and then passes the same barrier.
    After that, rank 1 observes the value. Workers return immediately, with
    exit code 0, when fewer than ``world_size`` GPUs are visible.
    """
    world_size = 2
    port = dist.get_free_ports(1)[0]
    # Bound to a name (otherwise unused), presumably so the server object is
    # not garbage-collected while the workers run.
    server = dist.Server(port)

    def worker(rank, q):
        if mge.get_device_count("gpu") < world_size:
            return
        dist.init_process_group("localhost", port, world_size, rank, rank)
        dist.group_barrier()
        if rank == 0:
            dist.group_barrier()
            q.put(0)  # to be observed in rank 1
        else:
            _assert_q_empty(q)  # q.put(0) is not executed in rank 0
            dist.group_barrier()
            _assert_q_val(q, 0)  # q.put(0) executed in rank 0

    shared_q = mp.Queue()
    procs = []
    for rank in range(world_size):
        p = mp.Process(target=worker, args=(rank, shared_q))
        p.start()
        procs.append(p)

    for p in procs:
        p.join(20)
    # A child still alive after its join timeout is hung. Kill it so it does
    # not outlive the test. Its non-zero exit code then fails the assertion
    # below.
    for p in procs:
        if p.is_alive():
            p.terminate()
            p.join()
    assert [p.exitcode for p in procs] == [0] * world_size
-
-
@pytest.mark.skipif(
    platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
)
@pytest.mark.skipif(
    platform.system() == "Windows", reason="do not imp GPU mode at Windows now"
)
@pytest.mark.isolated_distributed
def test_synchronized():
    """Check the ordering imposed by ``dist.synchronized`` across two ranks.

    The assertions in rank 1 check the following sequence:

    1. rank 0's call to ``func`` is seen first;
    2. rank 0's subsequent ``q.put(2)`` has not happened yet;
    3. rank 1's own ``func`` call lands before rank 0's ``q.put(2)``.

    Workers return immediately, with exit code 0, when fewer than
    ``world_size`` GPUs are visible.
    """
    world_size = 2
    port = dist.get_free_ports(1)[0]
    # Bound to a name (otherwise unused), presumably so the server object is
    # not garbage-collected while the workers run.
    server = dist.Server(port)

    @dist.synchronized
    def func(rank, q):
        q.put(rank)

    def worker(rank, q):
        if mge.get_device_count("gpu") < world_size:
            return
        dist.init_process_group("localhost", port, world_size, rank, rank)
        dist.group_barrier()
        if rank == 0:
            func(0, q)  # q.put(0)
            q.put(2)
        else:
            _assert_q_val(q, 0)  # func executed in rank 0
            _assert_q_empty(q)  # q.put(2) is not executed
            func(1, q)
            # func in rank 1 executed earlier than q.put(2) in rank 0
            _assert_q_val(q, 1)
            _assert_q_val(q, 2)  # q.put(2) executed in rank 0

    shared_q = mp.Queue()
    procs = []
    for rank in range(world_size):
        p = mp.Process(target=worker, args=(rank, shared_q))
        p.start()
        procs.append(p)

    for p in procs:
        p.join(20)
    # A child still alive after its join timeout is hung. Kill it so it does
    # not outlive the test. Its non-zero exit code then fails the assertion
    # below.
    for p in procs:
        if p.is_alive():
            p.terminate()
            p.join()
    assert [p.exitcode for p in procs] == [0] * world_size
|