
test_distributed.py 7.5 kB

# -*- coding: utf-8 -*-
import multiprocessing as mp
import platform
import queue

import numpy as np
import pytest

import megengine as mge
import megengine.distributed as dist
from megengine.core.ops.builtin import CollectiveComm, ParamPackConcat, ParamPackSplit
from megengine.device import get_default_device
from megengine.distributed.helper import param_pack_concat, param_pack_split


def _assert_q_empty(q):
    try:
        res = q.get(timeout=1)
    except Exception as e:
        assert isinstance(e, queue.Empty)
    else:
        assert False, "queue is not empty"


def _assert_q_val(q, val):
    ret = q.get()
    assert ret == val
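

# test_init_process_group: each spawned worker initializes the process group
# and checks that rank, world size, backend and the reported py/mm server
# addresses match what the launching process set up.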
@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("backend", ["nccl"])
@pytest.mark.isolated_distributed
def test_init_process_group(backend):
    world_size = 2
    server = dist.Server()
    port = server.py_server_port

    def worker(rank):
        dist.init_process_group("localhost", port, world_size, rank, rank, backend)
        assert dist.is_distributed() == True
        assert dist.get_rank() == rank
        assert dist.get_world_size() == world_size
        assert dist.get_backend() == backend

        py_server_addr = dist.get_py_server_addr()
        assert py_server_addr[0] == "localhost"
        assert py_server_addr[1] == port

        mm_server_addr = dist.get_mm_server_addr()
        assert mm_server_addr[0] == "localhost"
        assert mm_server_addr[1] > 0

        assert isinstance(dist.get_client(), dist.Client)

    procs = []
    for rank in range(world_size):
        p = mp.Process(target=worker, args=(rank,))
        p.start()
        procs.append(p)

    for p in procs:
        p.join(20)
        assert p.exitcode == 0
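

# test_new_group: on a 3-GPU launch, ranks 2 and 0 form a sub-group; the test
# checks the group's size, key, per-group rank and computing node string.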
@pytest.mark.require_ngpu(3)
@pytest.mark.isolated_distributed
def test_new_group():
    world_size = 3
    ranks = [2, 0]

    @dist.launcher
    def worker():
        rank = dist.get_rank()
        if rank in ranks:
            group = dist.new_group(ranks)
            assert group.size == 2
            assert group.key == "2,0"
            assert group.rank == ranks.index(rank)
            dt = get_default_device()[:-1]
            assert group.comp_node == "{}{}:2".format(dt, rank)

    worker()
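

# test_group_barrier: uses a multiprocessing queue to verify that
# dist.group_barrier() really blocks -- rank 1 must not observe rank 0's
# q.put(0) before both ranks have reached the second barrier.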
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_group_barrier():
    world_size = 2
    server = dist.Server()
    port = server.py_server_port

    def worker(rank, q):
        dist.init_process_group("localhost", port, world_size, rank, rank)
        dist.group_barrier()
        if rank == 0:
            dist.group_barrier()
            q.put(0)  # to be observed in rank 1
        else:
            _assert_q_empty(q)  # q.put(0) is not executed in rank 0
            dist.group_barrier()
            _assert_q_val(q, 0)  # q.put(0) executed in rank 0

    Q = mp.Queue()
    procs = []
    for rank in range(world_size):
        p = mp.Process(target=worker, args=(rank, Q))
        p.start()
        procs.append(p)

    for p in procs:
        p.join(20)
        assert p.exitcode == 0
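

# test_synchronized: checks the ordering guarantee of @dist.synchronized --
# rank 0's plain q.put(2) is only observed after rank 1 has also called the
# decorated func, as asserted through the shared queue below.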
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_synchronized():
    world_size = 2
    server = dist.Server()
    port = server.py_server_port

    @dist.synchronized
    def func(rank, q):
        q.put(rank)

    def worker(rank, q):
        dist.init_process_group("localhost", port, world_size, rank, rank)
        dist.group_barrier()
        if rank == 0:
            func(0, q)  # q.put(0)
            q.put(2)
        else:
            _assert_q_val(q, 0)  # func executed in rank 0
            _assert_q_empty(q)  # q.put(2) is not executed
            func(1, q)
            _assert_q_val(
                q, 1
            )  # func in rank 1 executed earlier than q.put(2) in rank 0
            _assert_q_val(q, 2)  # q.put(2) executed in rank 0

    Q = mp.Queue()
    procs = []
    for rank in range(world_size):
        p = mp.Process(target=worker, args=(rank, Q))
        p.start()
        procs.append(p)

    for p in procs:
        p.join(20)
        assert p.exitcode == 0
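

# test_user_set_get: every rank writes the same key through the client's
# user_set and reads it back with user_get.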
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_user_set_get():
    @dist.launcher
    def worker():
        # set in race condition
        dist.get_client().user_set("foo", 1)
        # get in race condition
        ret = dist.get_client().user_get("foo")
        assert ret == 1

    worker()
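

# test_oprmm_hashable: two freshly constructed tuples of CollectiveComm,
# ParamPackConcat and ParamPackSplit ops compare equal and share the same hash.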
def test_oprmm_hashable():
    lhs = (CollectiveComm(), ParamPackConcat(), ParamPackSplit())
    rhs = (CollectiveComm(), ParamPackConcat(), ParamPackSplit())
    assert lhs == rhs
    assert hash(lhs) == hash(rhs)
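

# test_param_pack_split: splits a flat 10-element tensor into a (1,) and a
# (3, 3) tensor according to the offsets [0, 1, 1, 10].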
def test_param_pack_split():
    a = mge.Tensor(np.ones((10,), np.int32))
    b, c = param_pack_split(a, [0, 1, 1, 10], [(1,), (3, 3)])
    assert np.allclose(b.numpy(), a.numpy()[1])
    assert np.allclose(c.numpy(), a.numpy()[1:].reshape(3, 3))
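

# test_param_pack_concat: the inverse operation -- packs a (1,) and a (3, 3)
# tensor back into one flat tensor and compares against np.concatenate.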
def test_param_pack_concat():
    a = mge.Tensor(np.ones((1,), np.int32))
    b = mge.Tensor(np.ones((3, 3), np.int32))
    offsets_val = [0, 1, 1, 10]
    offsets = mge.Tensor(offsets_val, np.int32)
    c = param_pack_concat([a, b], offsets, offsets_val)
    assert np.allclose(np.concatenate([a.numpy(), b.numpy().flatten()]), c.numpy())
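

# test_collect_results: dist.launcher collects each worker's return value into
# a per-rank list; a worker that exits early is expected to contribute None.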
@pytest.mark.require_ngpu(2)
@pytest.mark.parametrize("early_return", [False, True], ids=["common", "early_return"])
@pytest.mark.parametrize("output_size", [10, 10000], ids=["small_size", "large_size"])
@pytest.mark.isolated_distributed
def test_collect_results(early_return, output_size):
    @dist.launcher
    def worker():
        if early_return:
            exit(0)
        return [dist.get_rank()] * output_size

    results = worker()

    world_size = len(results)
    assert world_size > 0
    expects = (
        [None] * world_size
        if early_return
        else [[dev] * output_size for dev in range(world_size)]
    )
    assert results == expects
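

# test_user_set_pop: same key-value exchange as above, but rank 1 consumes the
# value with user_pop instead of user_get.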
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_user_set_pop():
    @dist.launcher
    def worker():
        # set in race condition
        dist.get_client().user_set("foo", 1)
        if dist.get_rank() == 1:
            ret = dist.get_client().user_pop("foo")
            assert ret == 1

    worker()
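

# test_get_cuda_compute_capability: the reported compute capability of each
# visible GPU is positive, both in the main process and inside launched workers.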
@pytest.mark.require_ngpu(2)
@pytest.mark.isolated_distributed
def test_get_cuda_compute_capability():
    assert mge.device.get_cuda_compute_capability(0) > 0
    assert mge.device.get_cuda_compute_capability(1) > 0

    @dist.launcher
    def worker():
        x = mge.tensor([1.0])
        assert mge.device.get_cuda_compute_capability(dist.get_rank()) > 0

    worker()
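

# test_batch_send_recv: batches point-to-point sends/receives between
# dist.group_start() and dist.group_end(), then checks the tensor received
# from the previous rank in the last iteration (which was scaled by 2).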
@pytest.mark.require_ngpu(3)
@pytest.mark.isolated_distributed
def test_batch_send_recv():
    import megengine.distributed.functional as DF

    @dist.launcher(n_gpus=3)
    def worker():
        rank = dist.get_rank()
        dist.group_start()
        for i in range(3):
            tensor = mge.tensor(np.ones(10000)) * rank
            if i == 2:
                tensor *= i
            DF._remote_send_nobackward(tensor, (rank + 1) % 3)
            DF._remote_recv_nobackward(
                src_rank=(rank + 1) % 3, dtype="float32", shape=(10000,)
            )
            DF._remote_send_nobackward(tensor, (rank - 1) % 3)
            recv = DF._remote_recv_nobackward(
                src_rank=(rank - 1) % 3, dtype="float32", shape=(10000,)
            )
            if i == 2:
                recv2 = recv
        dist.group_end()
        np.testing.assert_equal(recv2.numpy(), (rank - 1) % 3 * 2 * np.ones(10000))

    worker()