You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_unrepeated_sampler.py 4.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. from itertools import chain
  2. import pytest
  3. from fastNLP.core.samplers import UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler
  4. class DatasetWithVaryLength:
  5. def __init__(self, num_of_data=100):
  6. self.data = list(range(num_of_data))
  7. def __getitem__(self, item):
  8. return self.data[item]
  9. def __len__(self):
  10. return len(self.data)
  11. class TestUnrepeatedSampler:
  12. @pytest.mark.parametrize('shuffle', [True, False])
  13. def test_single(self, shuffle):
  14. num_of_data = 100
  15. data = DatasetWithVaryLength(num_of_data)
  16. sampler = UnrepeatedRandomSampler(data, shuffle)
  17. indexes = set(sampler)
  18. assert indexes==set(range(num_of_data))
  19. @pytest.mark.parametrize('num_replicas', [2, 3])
  20. @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
  21. @pytest.mark.parametrize('shuffle', [False, True])
  22. def test_multi(self, num_replicas, num_of_data, shuffle):
  23. if num_replicas > num_of_data:
  24. pytest.skip("num_replicas > num_of_data")
  25. data = DatasetWithVaryLength(num_of_data=num_of_data)
  26. samplers = []
  27. for i in range(num_replicas):
  28. sampler = UnrepeatedRandomSampler(dataset=data, shuffle=shuffle)
  29. sampler.set_distributed(num_replicas, rank=i)
  30. samplers.append(sampler)
  31. indexes = list(chain(*samplers))
  32. assert len(indexes) == num_of_data
  33. indexes = set(indexes)
  34. assert indexes==set(range(num_of_data))
  35. class TestUnrepeatedSortedSampler:
  36. def test_single(self):
  37. num_of_data = 100
  38. data = DatasetWithVaryLength(num_of_data)
  39. sampler = UnrepeatedSortedSampler(data, length=data.data)
  40. indexes = list(sampler)
  41. assert indexes==list(range(num_of_data-1, -1, -1))
  42. @pytest.mark.parametrize('num_replicas', [2, 3])
  43. @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
  44. def test_multi(self, num_replicas, num_of_data):
  45. if num_replicas > num_of_data:
  46. pytest.skip("num_replicas > num_of_data")
  47. data = DatasetWithVaryLength(num_of_data=num_of_data)
  48. samplers = []
  49. for i in range(num_replicas):
  50. sampler = UnrepeatedSortedSampler(dataset=data, length=data.data)
  51. sampler.set_distributed(num_replicas, rank=i)
  52. samplers.append(sampler)
  53. # 保证顺序是没乱的
  54. for sampler in samplers:
  55. prev_index = float('inf')
  56. for index in sampler:
  57. assert index <= prev_index
  58. prev_index = index
  59. indexes = list(chain(*samplers))
  60. assert len(indexes) == num_of_data # 不同卡之间没有交叉
  61. indexes = set(indexes)
  62. assert indexes==set(range(num_of_data))
  63. class TestUnrepeatedSequentialSampler:
  64. def test_single(self):
  65. num_of_data = 100
  66. data = DatasetWithVaryLength(num_of_data)
  67. sampler = UnrepeatedSequentialSampler(data, length=data.data)
  68. indexes = list(sampler)
  69. assert indexes==list(range(num_of_data))
  70. @pytest.mark.parametrize('num_replicas', [2, 3])
  71. @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
  72. @pytest.mark.parametrize('chunk_dist', [True, False])
  73. def test_multi(self, num_replicas, num_of_data, chunk_dist):
  74. if num_replicas > num_of_data:
  75. pytest.skip("num_replicas > num_of_data")
  76. data = DatasetWithVaryLength(num_of_data=num_of_data)
  77. samplers = []
  78. for i in range(num_replicas):
  79. sampler = UnrepeatedSequentialSampler(dataset=data, chunk_dist=chunk_dist)
  80. sampler.set_distributed(num_replicas, rank=i)
  81. samplers.append(sampler)
  82. # 保证顺序是没乱的
  83. for sampler in samplers:
  84. prev_index = float('-inf')
  85. for index in sampler:
  86. assert index>=prev_index
  87. prev_index = index
  88. indexes = list(chain(*samplers))
  89. assert len(indexes) == num_of_data
  90. indexes = set(indexes)
  91. assert indexes == set(range(num_of_data))