You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_functional_distributed_axis.py 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135
  1. # -*- coding: utf-8 -*-
  2. import numpy as np
  3. import pytest
  4. import megengine as mge
  5. import megengine.distributed as dist
  6. from megengine import tensor
  7. from megengine.distributed.functional import (
  8. all_gather,
  9. all_to_all,
  10. gather,
  11. reduce_scatter_sum,
  12. scatter,
  13. )
  14. from megengine.jit import trace
  15. @pytest.mark.require_ngpu(2)
  16. @pytest.mark.parametrize("shape", [(2, 3), (8, 10), (99, 77), (2, 2, 2, 2)], ids=str)
  17. @pytest.mark.parametrize("symbolic", [False, True], ids=str)
  18. @pytest.mark.parametrize("axis", [0, 1], ids=str)
  19. @pytest.mark.isolated_distributed
  20. def test_all_gather(shape, symbolic, axis):
  21. @dist.launcher(n_gpus=2)
  22. def worker(data, expect):
  23. rank = dist.get_rank()
  24. inp = tensor(data[rank])
  25. def func():
  26. output = all_gather(inp, axis=axis)
  27. return output
  28. func = trace(symbolic=symbolic)(func)
  29. output = func()
  30. assert np.allclose(output.numpy(), expect[rank])
  31. x = np.random.random_sample(shape).astype("float32")
  32. y = np.random.random_sample(shape).astype("float32")
  33. z = np.concatenate((x, y), axis=axis)
  34. data = (x, y)
  35. expect = (z, z)
  36. worker(data, expect)
  37. @pytest.mark.require_ngpu(2)
  38. @pytest.mark.parametrize(
  39. "shape,symbolic", [((2, 4, 6, 8), False), ((2, 4, 6, 8), True)], ids=str
  40. )
  41. @pytest.mark.parametrize("axis", [1, 0, 2, 3], ids=str)
  42. @pytest.mark.isolated_distributed
  43. def test_reduce_scatter_sum(shape, symbolic, axis):
  44. @dist.launcher(n_gpus=2)
  45. def worker(data, expect):
  46. rank = dist.get_rank()
  47. inp = tensor(data[rank])
  48. def func():
  49. output = reduce_scatter_sum(inp, axis=axis)
  50. return output
  51. func = trace(symbolic=symbolic)(func)
  52. output = func()
  53. assert np.allclose(output.numpy(), expect[rank])
  54. x = np.random.random_sample(shape).astype("float32")
  55. y = np.random.random_sample(shape).astype("float32")
  56. z = x + y
  57. data = (x, y)
  58. z = np.split(z, 2, axis=axis)
  59. z = np.concatenate(z, axis=0)
  60. expect = (z[: z.shape[0] // 2], z[z.shape[0] // 2 :])
  61. worker(data, expect)
  62. @pytest.mark.require_ngpu(2)
  63. @pytest.mark.parametrize(
  64. "shape,symbolic", [((2, 4, 6, 8), True), ((2, 4, 6, 8), False)], ids=str
  65. )
  66. @pytest.mark.parametrize("axis", [1, 0, 2, 3], ids=str)
  67. @pytest.mark.isolated_distributed
  68. def test_scatter(shape, symbolic, axis):
  69. @dist.launcher(n_gpus=2)
  70. def worker(data, expect):
  71. rank = dist.get_rank()
  72. inp = tensor(data[rank])
  73. def func():
  74. output = scatter(inp, axis=axis)
  75. return output
  76. func = trace(symbolic=symbolic)(func)
  77. output = func()
  78. assert np.allclose(output.numpy(), expect[rank])
  79. x = np.random.random_sample(shape).astype("float32")
  80. y = x + 1
  81. data = (x, y)
  82. _x = np.split(x, 2, axis=axis)
  83. _x = np.concatenate(_x, axis=0)
  84. expect = (_x[: _x.shape[0] // 2], _x[_x.shape[0] // 2 :])
  85. worker(data, expect)
  86. @pytest.mark.require_ngpu(2)
  87. @pytest.mark.parametrize("shape", [(2, 4, 6, 8)], ids=str)
  88. @pytest.mark.parametrize("symbolic", [False, True], ids=str)
  89. @pytest.mark.parametrize(
  90. "split_axis,concat_axis", [(0, 1), (1, 0), (2, 0), (0, 2), (2, 3)], ids=str
  91. )
  92. @pytest.mark.isolated_distributed
  93. def test_all_to_all(shape, symbolic, split_axis, concat_axis):
  94. @dist.launcher(n_gpus=2)
  95. def worker(data):
  96. rank = dist.get_rank()
  97. inp = tensor(data[rank])
  98. def func():
  99. all_to_all_output = all_to_all(
  100. inp, split_axis=split_axis, concat_axis=concat_axis
  101. )
  102. gather_C = gather(inp, axis=concat_axis)
  103. gather_B = gather(all_to_all_output, axis=split_axis)
  104. if rank == 0:
  105. return gather_B, gather_C
  106. return all_to_all_output
  107. func = trace(symbolic=symbolic)(func)
  108. ret = func()
  109. if rank == 0:
  110. assert np.allclose(ret[0], ret[1])
  111. x = np.random.random_sample(shape).astype("float32")
  112. y = np.random.random_sample(shape).astype("float32")
  113. data = (x, y)
  114. worker(data)