You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_functional.py 14 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434
  1. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  2. #
  3. # Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
  4. #
  5. # Unless required by applicable law or agreed to in writing,
  6. # software distributed under the License is distributed on an
  7. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  8. import multiprocessing as mp
  9. import platform
  10. import numpy as np
  11. import pytest
  12. import megengine as mge
  13. import megengine.distributed as dist
  14. from megengine.core import Parameter, tensor
  15. def _init_process_group_wrapper(world_size, rank, dev, backend, q):
  16. if rank == 0:
  17. dist.init_process_group("localhost", 0, world_size, rank, dev, backend)
  18. q.put(dist.get_master_port())
  19. else:
  20. port = q.get()
  21. dist.init_process_group("localhost", port, world_size, rank, dev, backend)
  22. @pytest.mark.skipif(
  23. platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
  24. )
  25. @pytest.mark.isolated_distributed
  26. def test_reduce_sum():
  27. world_size = 2
  28. def worker(rank, data, backend, expect, port_queue):
  29. if mge.get_device_count("gpu") < world_size:
  30. return
  31. _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
  32. inp = tensor(data)
  33. output = dist.functional.reduce_sum(inp)
  34. if rank == 0:
  35. assert np.allclose(output.numpy(), expect)
  36. else:
  37. assert np.allclose(output.numpy(), 0)
  38. def check(shape, backend):
  39. port_queue = mp.Queue()
  40. x = np.random.rand(*shape).astype("float32")
  41. y = np.random.rand(*shape).astype("float32")
  42. z = x + y
  43. p0 = mp.Process(target=worker, args=(0, x, backend, z, port_queue))
  44. p1 = mp.Process(target=worker, args=(1, y, backend, None, port_queue))
  45. p0.start()
  46. p1.start()
  47. p0.join(10)
  48. p1.join(10)
  49. assert p0.exitcode == 0 and p1.exitcode == 0
  50. for shape in [(2, 3), (8, 10), (99, 77)]:
  51. for backend in ["nccl", "ucx"]:
  52. check(shape, backend)
  53. @pytest.mark.skipif(
  54. platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
  55. )
  56. @pytest.mark.isolated_distributed
  57. def test_gather():
  58. world_size = 2
  59. def worker(rank, data, backend, expect, port_queue):
  60. if mge.get_device_count("gpu") < world_size:
  61. return
  62. _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
  63. inp = tensor(data)
  64. output = dist.functional.gather(inp)
  65. if rank == 0:
  66. assert np.allclose(output.numpy(), expect)
  67. else:
  68. assert np.allclose(output.numpy(), 0)
  69. def check(shape, backend):
  70. port_queue = mp.Queue()
  71. x = np.random.rand(*shape).astype("float32")
  72. y = np.random.rand(*shape).astype("float32")
  73. z = np.concatenate((x, y))
  74. p0 = mp.Process(target=worker, args=(0, x, backend, z, port_queue))
  75. p1 = mp.Process(target=worker, args=(1, y, backend, None, port_queue))
  76. p0.start()
  77. p1.start()
  78. p0.join(10)
  79. p1.join(10)
  80. assert p0.exitcode == 0 and p1.exitcode == 0
  81. for shape in [(2, 3), (8, 10), (99, 77)]:
  82. for backend in ["nccl", "ucx"]:
  83. check(shape, backend)
  84. @pytest.mark.skipif(
  85. platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
  86. )
  87. @pytest.mark.isolated_distributed
  88. def test_broadcast():
  89. world_size = 2
  90. def worker(rank, data, backend, expect, port_queue):
  91. if mge.get_device_count("gpu") < world_size:
  92. return
  93. _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
  94. inp = tensor(data)
  95. output = dist.functional.broadcast(inp)
  96. assert np.allclose(output.numpy(), expect)
  97. def check(shape, backend):
  98. port_queue = mp.Queue()
  99. x = np.random.rand(*shape).astype("float32")
  100. y = x + 1
  101. p0 = mp.Process(target=worker, args=(0, x, backend, x, port_queue))
  102. p1 = mp.Process(target=worker, args=(1, y, backend, x, port_queue))
  103. p0.start()
  104. p1.start()
  105. p0.join(10)
  106. p1.join(10)
  107. assert p0.exitcode == 0 and p1.exitcode == 0
  108. for shape in [(2, 3), (8, 10), (99, 77)]:
  109. for backend in ["nccl", "ucx"]:
  110. check(shape, backend)
  111. @pytest.mark.skipif(
  112. platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
  113. )
  114. @pytest.mark.isolated_distributed
  115. def test_scatter():
  116. world_size = 2
  117. def worker(rank, data, backend, expect, port_queue):
  118. if mge.get_device_count("gpu") < world_size:
  119. return
  120. _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
  121. inp = tensor(data)
  122. output = dist.functional.scatter(inp)
  123. assert np.allclose(output.numpy(), expect)
  124. def check(shape, backend):
  125. port_queue = mp.Queue()
  126. x = np.random.rand(*shape).astype("float32")
  127. y = x + 1
  128. p0 = mp.Process(
  129. target=worker, args=(0, x, backend, x[: shape[0] // 2], port_queue)
  130. )
  131. p1 = mp.Process(
  132. target=worker, args=(1, y, backend, x[shape[0] // 2 :], port_queue)
  133. )
  134. p0.start()
  135. p1.start()
  136. p0.join(10)
  137. p1.join(10)
  138. assert p0.exitcode == 0 and p1.exitcode == 0
  139. for shape in [(2, 3), (8, 10), (100, 77)]:
  140. for backend in ["nccl", "ucx"]:
  141. check(shape, backend)
  142. @pytest.mark.skipif(
  143. platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
  144. )
  145. @pytest.mark.isolated_distributed
  146. def test_all_to_all():
  147. world_size = 2
  148. def worker(rank, data, backend, expect, port_queue):
  149. if mge.get_device_count("gpu") < world_size:
  150. return
  151. _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
  152. inp = tensor(data)
  153. output = dist.functional.all_to_all(inp)
  154. assert np.allclose(output.numpy(), expect)
  155. def check(shape, backend):
  156. port_queue = mp.Queue()
  157. x = np.random.rand(*shape).astype("float32")
  158. y = np.random.rand(*shape).astype("float32")
  159. a = np.concatenate((x[: shape[0] // 2], y[: shape[0] // 2]))
  160. b = np.concatenate((x[shape[0] // 2 :], y[shape[0] // 2 :]))
  161. p0 = mp.Process(target=worker, args=(0, x, backend, a, port_queue))
  162. p1 = mp.Process(target=worker, args=(1, y, backend, b, port_queue))
  163. p0.start()
  164. p1.start()
  165. p0.join(10)
  166. p1.join(10)
  167. assert p0.exitcode == 0 and p1.exitcode == 0
  168. for shape in [(2, 3), (8, 10), (100, 77)]:
  169. for backend in ["nccl", "ucx"]:
  170. check(shape, backend)
  171. @pytest.mark.skipif(
  172. platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
  173. )
  174. @pytest.mark.isolated_distributed
  175. def test_all_gather():
  176. world_size = 2
  177. def worker(rank, data, backend, expect, port_queue):
  178. if mge.get_device_count("gpu") < world_size:
  179. return
  180. _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
  181. inp = tensor(data)
  182. output = dist.functional.all_gather(inp)
  183. assert np.allclose(output.numpy(), expect)
  184. def check(shape, backend):
  185. port_queue = mp.Queue()
  186. x = np.random.rand(*shape).astype("float32")
  187. y = np.random.rand(*shape).astype("float32")
  188. z = np.concatenate((x, y))
  189. p0 = mp.Process(target=worker, args=(0, x, backend, z, port_queue))
  190. p1 = mp.Process(target=worker, args=(1, y, backend, z, port_queue))
  191. p0.start()
  192. p1.start()
  193. p0.join(10)
  194. p1.join(10)
  195. assert p0.exitcode == 0 and p1.exitcode == 0
  196. for shape in [(2, 3), (8, 10), (99, 77)]:
  197. for backend in ["nccl", "ucx"]:
  198. check(shape, backend)
  199. @pytest.mark.skipif(
  200. platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
  201. )
  202. @pytest.mark.isolated_distributed
  203. def test_reduce_scatter_sum():
  204. world_size = 2
  205. def worker(rank, data, backend, expect, port_queue):
  206. if mge.get_device_count("gpu") < world_size:
  207. return
  208. _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
  209. inp = tensor(data)
  210. output = dist.functional.reduce_scatter_sum(inp)
  211. assert np.allclose(output.numpy(), expect)
  212. def check(shape, backend):
  213. port_queue = mp.Queue()
  214. x = np.random.rand(*shape).astype("float32")
  215. y = np.random.rand(*shape).astype("float32")
  216. z = x + y
  217. p0 = mp.Process(
  218. target=worker, args=(0, x, backend, z[: shape[0] // 2], port_queue)
  219. )
  220. p1 = mp.Process(
  221. target=worker, args=(1, y, backend, z[shape[0] // 2 :], port_queue)
  222. )
  223. p0.start()
  224. p1.start()
  225. p0.join(10)
  226. p1.join(10)
  227. assert p0.exitcode == 0 and p1.exitcode == 0
  228. for shape in [(2, 4), (8, 10), (88, 44)]:
  229. for backend in ["nccl", "ucx"]:
  230. check(shape, backend)
  231. @pytest.mark.skipif(
  232. platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
  233. )
  234. @pytest.mark.isolated_distributed
  235. def test_all_reduce_sum():
  236. world_size = 2
  237. def worker(rank, data, backend, expect, port_queue):
  238. if mge.get_device_count("gpu") < world_size:
  239. return
  240. _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
  241. inp = tensor(data)
  242. output = dist.functional.all_reduce_sum(inp)
  243. assert np.allclose(output.numpy(), expect)
  244. def check(shape, backend):
  245. port_queue = mp.Queue()
  246. x = np.random.rand(*shape).astype("float32")
  247. y = np.random.rand(*shape).astype("float32")
  248. z = x + y
  249. p0 = mp.Process(target=worker, args=(0, x, backend, z, port_queue))
  250. p1 = mp.Process(target=worker, args=(1, y, backend, z, port_queue))
  251. p0.start()
  252. p1.start()
  253. p0.join(10)
  254. p1.join(10)
  255. assert p0.exitcode == 0 and p1.exitcode == 0
  256. for shape in [(2, 3), (8, 10), (99, 77)]:
  257. for backend in ["nccl", "ucx"]:
  258. check(shape, backend)
  259. @pytest.mark.skipif(
  260. platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
  261. )
  262. @pytest.mark.isolated_distributed
  263. def test_all_reduce_max():
  264. world_size = 2
  265. def worker(rank, data, backend, expect, port_queue):
  266. if mge.get_device_count("gpu") < world_size:
  267. return
  268. _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
  269. inp = tensor(data)
  270. output = dist.functional.all_reduce_max(inp)
  271. assert np.allclose(output.numpy(), expect)
  272. def check(shape, backend):
  273. port_queue = mp.Queue()
  274. x = np.random.rand(*shape).astype("float32")
  275. y = np.random.rand(*shape).astype("float32")
  276. z = np.maximum(x, y)
  277. p0 = mp.Process(target=worker, args=(0, x, backend, z, port_queue))
  278. p1 = mp.Process(target=worker, args=(1, y, backend, z, port_queue))
  279. p0.start()
  280. p1.start()
  281. p0.join(10)
  282. p1.join(10)
  283. assert p0.exitcode == 0 and p1.exitcode == 0
  284. for shape in [(2, 3), (8, 10), (99, 77)]:
  285. for backend in ["nccl", "ucx"]:
  286. check(shape, backend)
  287. @pytest.mark.skipif(
  288. platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
  289. )
  290. @pytest.mark.isolated_distributed
  291. def test_all_reduce_min():
  292. world_size = 2
  293. def worker(rank, data, backend, expect, port_queue):
  294. if mge.get_device_count("gpu") < world_size:
  295. return
  296. _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
  297. inp = tensor(data)
  298. output = dist.functional.all_reduce_min(inp)
  299. assert np.allclose(output.numpy(), expect)
  300. def check(shape, backend):
  301. port_queue = mp.Queue()
  302. x = np.random.rand(*shape).astype("float32")
  303. y = np.random.rand(*shape).astype("float32")
  304. z = np.minimum(x, y)
  305. p0 = mp.Process(target=worker, args=(0, x, backend, z, port_queue))
  306. p1 = mp.Process(target=worker, args=(1, y, backend, z, port_queue))
  307. p0.start()
  308. p1.start()
  309. p0.join(10)
  310. p1.join(10)
  311. assert p0.exitcode == 0 and p1.exitcode == 0
  312. for shape in [(2, 3), (8, 10), (99, 77)]:
  313. for backend in ["nccl", "ucx"]:
  314. check(shape, backend)
  315. @pytest.mark.skipif(
  316. platform.system() == "Darwin", reason="do not imp GPU mode at macos now"
  317. )
  318. @pytest.mark.isolated_distributed
  319. def test_bcast_param():
  320. world_size = 2
  321. def worker(rank, data, backend, expect, port_queue):
  322. if mge.get_device_count("gpu") < world_size:
  323. return
  324. _init_process_group_wrapper(world_size, rank, rank, backend, port_queue)
  325. inp = Parameter(data)
  326. dist.functional.bcast_param(inp)
  327. assert np.allclose(inp.numpy(), expect)
  328. def check(shape, backend):
  329. port_queue = mp.Queue()
  330. x = np.random.rand(*shape).astype("float32")
  331. y = x + 1
  332. p0 = mp.Process(target=worker, args=(0, x, backend, x, port_queue))
  333. p1 = mp.Process(target=worker, args=(1, y, backend, x, port_queue))
  334. p0.start()
  335. p1.start()
  336. p0.join(10)
  337. p1.join(10)
  338. assert p0.exitcode == 0 and p1.exitcode == 0
  339. for shape in [(2, 3), (8, 10), (99, 77)]:
  340. for backend in ["nccl", "ucx"]:
  341. check(shape, backend)

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境,不用区分 CPU 和 GPU 版。 如果想要运行 GPU 程序,请确保机器本身配有 GPU 硬件设备并安装好驱动。 如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉,欢迎访问 MegStudio 平台