You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_functional_distributed.py 11 kB

Line numbers 1–381
  1. # -*- coding: utf-8 -*-
  2. # MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
  3. #
  4. # Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
  5. #
  6. # Unless required by applicable law or agreed to in writing,
  7. # software distributed under the License is distributed on an
  8. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  9. import platform
  10. import numpy as np
  11. import pytest
  12. import megengine as mge
  13. import megengine.distributed as dist
  14. from megengine import Parameter, tensor
  15. from megengine.core._imperative_rt.core2 import sync
  16. from megengine.device import get_default_device, set_default_device
  17. from megengine.functional.distributed import (
  18. all_gather,
  19. all_reduce_max,
  20. all_reduce_min,
  21. all_reduce_sum,
  22. all_to_all,
  23. broadcast,
  24. gather,
  25. reduce_scatter_sum,
  26. reduce_sum,
  27. remote_recv,
  28. remote_send,
  29. scatter,
  30. )
  def run_reduce_sum(shape, dtype):
      """Check ``reduce_sum`` across two GPUs.

      Rank 0 must receive the elementwise sum of both ranks' inputs; rank 1
      must receive ``None``.

      Args:
          shape: shape of each rank's input array.
          dtype: numpy dtype name used for the inputs.
      """

      @dist.launcher(n_gpus=2)
      def worker(data, expect):
          rank = dist.get_rank()
          inp = tensor(data[rank])
          output = reduce_sum(inp)
          if rank == 0:
              assert np.allclose(output.numpy(), expect[rank])
          else:
              assert output is None

      # Sample from [0, 50) rather than [0, 1): casting values below 1 to an
      # integer dtype gives all zeros, so the integer cases would compare
      # nothing but zeros. With the bound of 50, x + y stays within the
      # int8/uint8 ranges.
      x = np.random.uniform(0, 50, size=shape).astype(dtype)
      y = np.random.uniform(0, 50, size=shape).astype(dtype)
      z = x + y
      data = (x, y)
      expect = (z, None)
      worker(data, expect)
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "shape",
      [(), (1,), (2, 3), (8, 10), (99, 77)],
      ids=str,
  )
  @pytest.mark.isolated_distributed
  def test_reduce_sum_multishape(shape):
      """Run the two-GPU reduce_sum check over several float32 shapes, including 0-d."""
      run_reduce_sum(shape=shape, dtype="float32")
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "dtype",
      ["float32", "int32", "int8", "uint8"],
      ids=str,
  )
  @pytest.mark.isolated_distributed
  def test_reduce_sum_multidtype(dtype):
      """Run the two-GPU reduce_sum check on an (8, 10) input for each parametrized dtype."""
      run_reduce_sum(shape=(8, 10), dtype=dtype)
  def run_broadcast(shape, dtype):
      """Check ``broadcast`` across two GPUs.

      Each rank passes its own array, and both ranks must get back rank 0's
      array. Rank 1's input is ``x + 1``, so an output that merely echoes the
      local input fails the check on rank 1.

      Args:
          shape: shape of each rank's input array.
          dtype: numpy dtype name used for the inputs.
      """

      @dist.launcher(n_gpus=2)
      def worker(data, expect):
          rank = dist.get_rank()
          inp = tensor(data[rank])
          output = broadcast(inp)
          assert np.allclose(output.numpy(), expect[rank])

      # Sample from [0, 50) rather than [0, 1): casting values below 1 to an
      # integer dtype gives all zeros, so the integer cases would carry no
      # real data. x + 1 still fits in int8/uint8.
      x = np.random.uniform(0, 50, size=shape).astype(dtype)
      y = x + 1
      data = (x, y)
      expect = (x, x)
      worker(data, expect)
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "shape",
      [(), (1,), (2, 3), (8, 10), (99, 77)],
      ids=str,
  )
  @pytest.mark.isolated_distributed
  def test_broadcast_multishape(shape):
      """Run the two-GPU broadcast check over several float32 shapes, including 0-d."""
      run_broadcast(shape=shape, dtype="float32")
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "dtype",
      ["float32", "int32", "int8", "uint8"],
      ids=str,
  )
  @pytest.mark.isolated_distributed
  def test_broadcast_multidtype(dtype):
      """Run the two-GPU broadcast check on an (8, 10) input for each parametrized dtype."""
      run_broadcast(shape=(8, 10), dtype=dtype)
  def run_all_gather(shape, dtype):
      """Check ``all_gather`` across two GPUs.

      Both ranks must receive rank 0's and rank 1's inputs concatenated
      along axis 0, in rank order.

      Args:
          shape: shape of each rank's input array (must have at least one
              dimension, since the expectation uses ``np.concatenate``).
          dtype: numpy dtype name used for the inputs.
      """

      @dist.launcher(n_gpus=2)
      def worker(data, expect):
          rank = dist.get_rank()
          inp = tensor(data[rank])
          output = all_gather(inp)
          assert np.allclose(output.numpy(), expect[rank])

      # Sample from [0, 50) rather than [0, 1): casting values below 1 to an
      # integer dtype gives all zeros, which would hide ordering bugs in the
      # integer cases.
      x = np.random.uniform(0, 50, size=shape).astype(dtype)
      y = np.random.uniform(0, 50, size=shape).astype(dtype)
      z = np.concatenate((x, y))
      data = (x, y)
      expect = (z, z)
      worker(data, expect)
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "shape",
      [(1,), (2, 3), (8, 10), (99, 77)],
      ids=str,
  )
  @pytest.mark.isolated_distributed
  def test_all_gather_multishape(shape):
      """Run the two-GPU all_gather check over several float32 shapes (no 0-d case)."""
      run_all_gather(shape=shape, dtype="float32")
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "dtype",
      ["float32", "int32", "int8", "uint8"],
      ids=str,
  )
  @pytest.mark.isolated_distributed
  def test_all_gather_multidtype(dtype):
      """Run the two-GPU all_gather check on an (8, 10) input for each parametrized dtype."""
      run_all_gather(shape=(8, 10), dtype=dtype)
  def run_reduce_scatter_sum(shape, dtype):
      """Check ``reduce_scatter_sum`` across two GPUs.

      The elementwise sum of both inputs is split along axis 0: rank 0 must
      receive the first half and rank 1 the second half.

      Args:
          shape: shape of each rank's input array; the split uses
              ``shape[0] // 2``, so an even first dimension is expected.
          dtype: numpy dtype name used for the inputs.
      """

      @dist.launcher(n_gpus=2)
      def worker(data, expect):
          rank = dist.get_rank()
          inp = tensor(data[rank])
          output = reduce_scatter_sum(inp)
          assert np.allclose(output.numpy(), expect[rank])

      # Sample from [0, 50) rather than [0, 1): casting values below 1 to an
      # integer dtype gives all zeros, so the integer cases would compare
      # nothing but zeros. With the bound of 50, x + y stays within the
      # int8/uint8 ranges.
      x = np.random.uniform(0, 50, size=shape).astype(dtype)
      y = np.random.uniform(0, 50, size=shape).astype(dtype)
      z = x + y
      data = (x, y)
      expect = (z[: shape[0] // 2], z[shape[0] // 2 :])
      worker(data, expect)
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "shape",
      [(2, 3), (8, 10), (88, 44)],
      ids=str,
  )
  @pytest.mark.isolated_distributed
  def test_reduce_scatter_sum_multishape(shape):
      """Run the two-GPU reduce_scatter_sum check over float32 shapes with an even first dim."""
      run_reduce_scatter_sum(shape=shape, dtype="float32")
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "dtype",
      ["float32", "int32", "int8", "uint8"],
      ids=str,
  )
  @pytest.mark.isolated_distributed
  def test_reduce_scatter_sum_multidtype(dtype):
      """Run the two-GPU reduce_scatter_sum check on an (8, 10) input for each parametrized dtype."""
      run_reduce_scatter_sum(shape=(8, 10), dtype=dtype)
  def run_all_reduce_sum(shape, dtype):
      """Check ``all_reduce_sum`` across two GPUs.

      Both ranks must receive the elementwise sum of the two inputs.

      Args:
          shape: shape of each rank's input array.
          dtype: numpy dtype name used for the inputs.
      """

      @dist.launcher(n_gpus=2)
      def worker(data, expect):
          rank = dist.get_rank()
          inp = tensor(data[rank])
          output = all_reduce_sum(inp)
          assert np.allclose(output.numpy(), expect[rank])

      # Sample from [0, 50) rather than [0, 1): casting values below 1 to an
      # integer dtype gives all zeros, so the integer cases would compare
      # nothing but zeros. With the bound of 50, x + y stays within the
      # int8/uint8 ranges.
      x = np.random.uniform(0, 50, size=shape).astype(dtype)
      y = np.random.uniform(0, 50, size=shape).astype(dtype)
      z = x + y
      data = (x, y)
      expect = (z, z)
      worker(data, expect)
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "shape",
      [(), (1,), (2, 3), (8, 10), (99, 77)],
      ids=str,
  )
  @pytest.mark.isolated_distributed
  def test_all_reduce_sum_multishape(shape):
      """Run the two-GPU all_reduce_sum check over several float32 shapes, including 0-d."""
      run_all_reduce_sum(shape=shape, dtype="float32")
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "dtype",
      ["float32", "int32", "int8", "uint8"],
      ids=str,
  )
  @pytest.mark.isolated_distributed
  def test_all_reduce_sum_multidtype(dtype):
      """Run the two-GPU all_reduce_sum check on an (8, 10) input for each parametrized dtype."""
      run_all_reduce_sum(shape=(8, 10), dtype=dtype)
  def run_all_reduce_max(shape, dtype):
      """Check ``all_reduce_max`` across two GPUs.

      Both ranks must receive the elementwise maximum of the two inputs.

      Args:
          shape: shape of each rank's input array.
          dtype: numpy dtype name used for the inputs.
      """

      @dist.launcher(n_gpus=2)
      def worker(data, expect):
          rank = dist.get_rank()
          inp = tensor(data[rank])
          output = all_reduce_max(inp)
          assert np.allclose(output.numpy(), expect[rank])

      # Sample from [0, 50) rather than [0, 1): casting values below 1 to an
      # integer dtype gives all zeros, and max(0, 0) would not exercise the
      # reduction at all for the integer cases.
      x = np.random.uniform(0, 50, size=shape).astype(dtype)
      y = np.random.uniform(0, 50, size=shape).astype(dtype)
      z = np.maximum(x, y)
      data = (x, y)
      expect = (z, z)
      worker(data, expect)
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "shape",
      [(), (1,), (2, 3), (8, 10), (99, 77)],
      ids=str,
  )
  @pytest.mark.isolated_distributed
  def test_all_reduce_max_multishape(shape):
      """Run the two-GPU all_reduce_max check over several float32 shapes, including 0-d."""
      run_all_reduce_max(shape=shape, dtype="float32")
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "dtype",
      ["float32", "int32", "int8", "uint8"],
      ids=str,
  )
  @pytest.mark.isolated_distributed
  def test_all_reduce_max_multidtype(dtype):
      """Run the two-GPU all_reduce_max check on an (8, 10) input for each parametrized dtype."""
      run_all_reduce_max(shape=(8, 10), dtype=dtype)
  def run_all_reduce_min(shape, dtype):
      """Check ``all_reduce_min`` across two GPUs.

      Both ranks must receive the elementwise minimum of the two inputs.

      Args:
          shape: shape of each rank's input array.
          dtype: numpy dtype name used for the inputs.
      """

      @dist.launcher(n_gpus=2)
      def worker(data, expect):
          rank = dist.get_rank()
          inp = tensor(data[rank])
          output = all_reduce_min(inp)
          assert np.allclose(output.numpy(), expect[rank])

      # Sample from [0, 50) rather than [0, 1): casting values below 1 to an
      # integer dtype gives all zeros, and min(0, 0) would not exercise the
      # reduction at all for the integer cases.
      x = np.random.uniform(0, 50, size=shape).astype(dtype)
      y = np.random.uniform(0, 50, size=shape).astype(dtype)
      z = np.minimum(x, y)
      data = (x, y)
      expect = (z, z)
      worker(data, expect)
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "shape",
      [(), (1,), (2, 3), (8, 10), (99, 77)],
      ids=str,
  )
  @pytest.mark.isolated_distributed
  def test_all_reduce_min_multishape(shape):
      """Run the two-GPU all_reduce_min check over several float32 shapes, including 0-d."""
      run_all_reduce_min(shape=shape, dtype="float32")
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "dtype",
      ["float32", "int32", "int8", "uint8"],
      ids=str,
  )
  @pytest.mark.isolated_distributed
  def test_all_reduce_min_multidtype(dtype):
      """Run the two-GPU all_reduce_min check on an (8, 10) input for each parametrized dtype."""
      run_all_reduce_min(shape=(8, 10), dtype=dtype)
  def run_gather(shape, dtype):
      """Check ``gather`` across two GPUs.

      Rank 0 must receive both inputs concatenated along axis 0 in rank
      order; rank 1 must receive ``None``.

      Args:
          shape: shape of each rank's input array (must have at least one
              dimension, since the expectation uses ``np.concatenate``).
          dtype: numpy dtype name used for the inputs.
      """

      @dist.launcher(n_gpus=2)
      def worker(data, expect):
          rank = dist.get_rank()
          inp = tensor(data[rank])
          output = gather(inp)
          if rank == 0:
              assert np.allclose(output.numpy(), expect[rank])
          else:
              assert output is None

      # Sample from [0, 50) rather than [0, 1): casting values below 1 to an
      # integer dtype gives all zeros, which would hide ordering bugs in the
      # integer cases.
      x = np.random.uniform(0, 50, size=shape).astype(dtype)
      y = np.random.uniform(0, 50, size=shape).astype(dtype)
      z = np.concatenate((x, y))
      data = (x, y)
      expect = (z, None)
      worker(data, expect)
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "shape",
      [(2, 3), (8, 10), (99, 77)],
      ids=str,
  )
  @pytest.mark.isolated_distributed
  def test_gather_multishape(shape):
      """Run the two-GPU gather check over several 2-D float32 shapes."""
      run_gather(shape=shape, dtype="float32")
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "dtype",
      ["float32", "int32", "int8", "uint8"],
      ids=str,
  )
  @pytest.mark.isolated_distributed
  def test_gather_multidtype(dtype):
      """Run the two-GPU gather check on an (8, 10) input for each parametrized dtype."""
      run_gather(shape=(8, 10), dtype=dtype)
  def run_scatter(shape, dtype):
      """Check ``scatter`` across two GPUs.

      Rank 0's array is split along axis 0: rank 0 must receive the first
      half and rank 1 the second half. Rank 1 passes ``x + 1``, which must
      not leak into either result.

      Args:
          shape: shape of each rank's input array; the split uses
              ``shape[0] // 2``, so an even first dimension is expected.
          dtype: numpy dtype name used for the inputs.
      """

      @dist.launcher(n_gpus=2)
      def worker(data, expect):
          rank = dist.get_rank()
          inp = tensor(data[rank])
          output = scatter(inp)
          assert np.allclose(output.numpy(), expect[rank])

      # Sample from [0, 50) rather than [0, 1): casting values below 1 to an
      # integer dtype gives all zeros, which would hide split-order bugs in
      # the integer cases. x + 1 still fits in int8/uint8.
      x = np.random.uniform(0, 50, size=shape).astype(dtype)
      y = x + 1
      data = (x, y)
      expect = (x[: shape[0] // 2], x[shape[0] // 2 :])
      worker(data, expect)
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "shape",
      [(2, 3), (8, 10), (100, 77)],
      ids=str,
  )
  @pytest.mark.isolated_distributed
  def test_scatter_multishape(shape):
      """Run the two-GPU scatter check over float32 shapes with an even first dim."""
      run_scatter(shape=shape, dtype="float32")
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "dtype",
      ["float32", "int32", "int8", "uint8"],
      ids=str,
  )
  @pytest.mark.isolated_distributed
  def test_scatter_multidtype(dtype):
      """Run the two-GPU scatter check on an (8, 10) input for each parametrized dtype."""
      run_scatter(shape=(8, 10), dtype=dtype)
  def run_all_to_all(shape, dtype):
      """Check ``all_to_all`` across two GPUs.

      Each input is split in half along axis 0. Rank 0 must receive the
      first halves of rank 0's and rank 1's inputs (in that order), and
      rank 1 the second halves.

      Args:
          shape: shape of each rank's input array; the split uses
              ``shape[0] // 2``, so an even first dimension is expected.
          dtype: numpy dtype name used for the inputs.
      """

      @dist.launcher(n_gpus=2)
      def worker(data, expect):
          rank = dist.get_rank()
          inp = tensor(data[rank])
          output = all_to_all(inp)
          assert np.allclose(output.numpy(), expect[rank])

      # Sample from [0, 50) rather than [0, 1): casting values below 1 to an
      # integer dtype gives all zeros, which would make any exchange order
      # pass for the integer cases.
      x = np.random.uniform(0, 50, size=shape).astype(dtype)
      y = np.random.uniform(0, 50, size=shape).astype(dtype)
      a = np.concatenate((x[: shape[0] // 2], y[: shape[0] // 2]))
      b = np.concatenate((x[shape[0] // 2 :], y[shape[0] // 2 :]))
      data = (x, y)
      expect = (a, b)
      worker(data, expect)
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "shape",
      [(2, 3), (8, 10), (100, 77)],
      ids=str,
  )
  @pytest.mark.isolated_distributed
  def test_all_to_all_multishape(shape):
      """Run the two-GPU all_to_all check over float32 shapes with an even first dim."""
      run_all_to_all(shape=shape, dtype="float32")
  @pytest.mark.require_ngpu(2)
  @pytest.mark.parametrize(
      "dtype",
      ["float32", "int32", "int8", "uint8"],
      ids=str,
  )
  @pytest.mark.isolated_distributed
  def test_all_to_all_multidtype(dtype):
      """Run the two-GPU all_to_all check on an (8, 10) input for each parametrized dtype."""
      run_all_to_all(shape=(8, 10), dtype=dtype)
  def run_io_remote(shape, dtype):
      """Check point-to-point ``remote_send``/``remote_recv`` between two GPUs.

      Rank 0 sends ``val`` to rank 1. Rank 1 checks that the received tensor
      is on its default device and that its values match ``val``.

      Args:
          shape: shape of the transferred array.
          dtype: numpy dtype name of the transferred array.
      """

      @dist.launcher(n_gpus=2)
      def worker(val, shape):
          rank = dist.get_rank()
          if rank == 0:  # remote send
              x = tensor(val, device="xpu0")
              remote_send(x, 1)
              # NOTE(review): presumably blocks until the send has completed
              # before the worker returns - confirm.
              sync()
          else:  # remote recv
              y = remote_recv(0)
              assert y.device == get_default_device()
              np.testing.assert_almost_equal(val, y.numpy())

      # Sample from [0, 50) rather than [0, 1): casting values below 1 to an
      # integer dtype gives all zeros, so the integer cases would transfer an
      # all-zero payload.
      val = np.random.uniform(0, 50, size=shape).astype(dtype)
      worker(val, shape)
  @pytest.mark.require_ngpu(2)
  @pytest.mark.isolated_distributed
  @pytest.mark.parametrize(
      "shape",
      [(), (1,), (4, 5)],
      ids=str,
  )
  def test_io_remote_multishape(shape):
      """Run the two-GPU send/recv check over several float32 shapes, including 0-d."""
      run_io_remote(shape=shape, dtype="float32")
  @pytest.mark.require_ngpu(2)
  @pytest.mark.isolated_distributed
  @pytest.mark.parametrize(
      "dtype",
      ["float32", "int32", "int8", "uint8"],
      ids=str,
  )
  def test_io_remote_multidtype(dtype):
      """Run the two-GPU send/recv check on an (8, 10) array for each parametrized dtype."""
      run_io_remote(shape=(8, 10), dtype=dtype)
  @pytest.mark.require_ngpu(2)
  def test_cuda_init_before_fork():
      """Launching workers after using gpu0 in the parent must fail.

      The parent process creates a tensor on ``gpu0`` before calling a
      ``dist.launcher``-wrapped function, and the test expects that call to
      raise ``AssertionError``.
      """
      # Touch gpu0 in the parent before the launcher starts its workers.
      a = mge.tensor(1, device="gpu0")

      @dist.launcher(n_gpus=2)
      def worker():
          # The body must be valid on its own. The previous ``a += 1`` made
          # ``a`` a local name of ``worker``, so it always raised
          # UnboundLocalError. That failure could presumably surface as the
          # same launcher error and mask a missing CUDA-before-fork check.
          b = mge.tensor(2)

      with pytest.raises(AssertionError):
          worker()

MegEngine 安装包中集成了使用 GPU 运行代码所需的 CUDA 环境，不用区分 CPU 和 GPU 版。如果想要运行 GPU 程序，请确保机器本身配有 GPU 硬件设备并安装好驱动。如果你想体验在云端 GPU 算力平台进行深度学习开发的感觉，欢迎访问 MegStudio 平台。