You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_reproducible_sampler.py 30 kB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744
  1. import numpy as np
  2. import pytest
  3. from functools import partial
  4. from itertools import chain
  5. from copy import deepcopy
  6. from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler
  7. from tests.helpers.datasets.torch_data import TorchNormalDataset
  8. class TestRandomSamplerYh:
  9. def test_init(self):
  10. # 测试能否正确初始化
  11. dataset = TorchNormalDataset(num_of_data=100)
  12. sampler = RandomSampler(dataset)
  13. for i in sampler:
  14. pass
  15. def test_during_iter(self):
  16. dataset = TorchNormalDataset(num_of_data=100)
  17. sampler = RandomSampler(dataset)
  18. for i in sampler:
  19. with pytest.raises(AssertionError):
  20. sampler.set_distributed(1, 0)
  21. break
  22. # should not raise
  23. for i in sampler:
  24. pass
  25. sampler.set_distributed(1, 0)
  26. def test_set_distributed(self):
  27. dataset = TorchNormalDataset(num_of_data=100)
  28. sampler = RandomSampler(dataset, shuffle=False)
  29. sampler.set_distributed(num_replicas=2, rank=0, pad=False)
  30. assert len(sampler)==50
  31. count = 0
  32. for i in sampler:
  33. assert i%2==0
  34. count += 1
  35. assert count == 50
  36. sampler.set_distributed(num_replicas=2, rank=1, pad=False)
  37. assert len(sampler)==50
  38. count = 0
  39. for i in sampler:
  40. assert i%2==1
  41. count += 1
  42. assert count==50
  43. dataset = TorchNormalDataset(num_of_data=101)
  44. sampler = RandomSampler(dataset, shuffle=False)
  45. sampler.set_distributed(num_replicas=2, rank=0, pad=True)
  46. assert len(sampler)==51
  47. count = 0
  48. for i in sampler:
  49. assert i%2==0
  50. count += 1
  51. assert count == 51
  52. sampler.set_distributed(num_replicas=2, rank=1, pad=True)
  53. assert len(sampler) == 51
  54. count = 0
  55. for i in sampler:
  56. if i!=0:
  57. assert i%2==1
  58. count += 1
  59. assert count == 51
  60. def test_state_dict_check_length(self):
  61. dataset = TorchNormalDataset(num_of_data=100)
  62. sampler = RandomSampler(dataset, shuffle=False)
  63. states = sampler.state_dict()
  64. new_ds = TorchNormalDataset(num_of_data=10)
  65. with pytest.raises(AssertionError):
  66. new_sampler = RandomSampler(new_ds)
  67. new_sampler.load_state_dict(states)
  68. new_ds = TorchNormalDataset(num_of_data=100)
  69. new_sampler = RandomSampler(new_ds)
  70. new_sampler.load_state_dict(states)
  71. @pytest.mark.parametrize('pad', [True, False])
  72. @pytest.mark.parametrize('pre_shuffle', [True, False])
  73. @pytest.mark.parametrize('post_shuffle', [True, False])
  74. @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist())
  75. def test_state_dict(self, pad, pre_shuffle, post_shuffle, num_consumed_samples):
  76. num_samples = 100
  77. dataset = TorchNormalDataset(num_of_data=num_samples)
  78. # 测试使用 前后shuffle不一致的load操作
  79. sampler = RandomSampler(dataset, shuffle=pre_shuffle)
  80. sampler.set_epoch(0)
  81. already_numbers = set()
  82. if num_consumed_samples>0:
  83. for i, j in enumerate(sampler, start=1):
  84. already_numbers.add(j)
  85. if i == num_consumed_samples:
  86. break
  87. assert len(already_numbers) == num_consumed_samples
  88. states = sampler.state_dict()
  89. new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
  90. new_sampler.load_state_dict(states)
  91. new_sampler.set_epoch(0)
  92. for i in new_sampler:
  93. assert i not in already_numbers
  94. # 测试切换成多卡也没有问题
  95. other_rank_number = set()
  96. for rank in range(3):
  97. new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
  98. new_sampler.load_state_dict(states)
  99. new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
  100. new_sampler.set_epoch(0)
  101. count = 0
  102. seen = 0
  103. seen_in_other_rank = 0
  104. for i in new_sampler:
  105. seen_in_other_rank += int(i in other_rank_number)
  106. other_rank_number.add(i)
  107. seen += int(i in already_numbers)
  108. count += 1
  109. assert seen <= 1 if pad else seen == 0
  110. assert seen_in_other_rank<=1 # 因为pad可能重复
  111. @pytest.mark.parametrize('pad', [True, False])
  112. @pytest.mark.parametrize('pre_shuffle', [True, False])
  113. @pytest.mark.parametrize('post_shuffle', [True, False])
  114. @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist())
  115. def test_state_dict_2(self, pad, pre_shuffle, post_shuffle, num_consumed_samples):
  116. # 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡
  117. num_samples = 100
  118. dataset = TorchNormalDataset(num_of_data=num_samples)
  119. # 测试使用 前后shuffle不一致的load操作
  120. # lst = [30]
  121. already_numbers = set()
  122. sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0)
  123. sampler.set_distributed(num_replicas=2, rank=0)
  124. sampler.set_epoch(0)
  125. if num_consumed_samples>0:
  126. for i, j in enumerate(sampler, start=1):
  127. already_numbers.add(j)
  128. if i == num_consumed_samples:
  129. break
  130. sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0)
  131. sampler.set_epoch(0)
  132. sampler.set_distributed(num_replicas=2, rank=1)
  133. if num_consumed_samples>0:
  134. for i, j in enumerate(sampler, start=1):
  135. already_numbers.add(j)
  136. if i == num_consumed_samples:
  137. break
  138. assert len(already_numbers) == num_consumed_samples*2
  139. states = sampler.state_dict()
  140. new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
  141. new_sampler.load_state_dict(states)
  142. new_sampler.set_epoch(0)
  143. for i in new_sampler:
  144. assert i not in already_numbers
  145. # 测试切换成多卡也没有问题
  146. other_rank_number = set()
  147. for rank in range(3):
  148. new_sampler = RandomSampler(dataset, shuffle=post_shuffle)
  149. new_sampler.load_state_dict(states)
  150. new_sampler.set_epoch(0)
  151. new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
  152. count = 0
  153. seen = 0
  154. seen_in_other_rank = 0
  155. for i in new_sampler:
  156. seen_in_other_rank += int(i in other_rank_number)
  157. other_rank_number.add(i)
  158. seen += int(i in already_numbers)
  159. count += 1
  160. assert seen <= 1 if pad else seen == 0
  161. assert seen_in_other_rank<=1 # 因为pad可能重复
  162. @pytest.mark.parametrize('shuffle', [True, False])
  163. @pytest.mark.parametrize('pad', [True, False])
  164. @pytest.mark.parametrize('num_samples', [13, 100, 623, 1000])
  165. @pytest.mark.parametrize('num_replicas', [1, 2, 3])
  166. def test_num_consumed_samples_array(self, shuffle, pad, num_samples, num_replicas):
  167. # def test_num_consumed_samples_array(self, shuffle=True, pad=True, num_samples=100, num_replicas=2):
  168. # 测试在 sampler 多生成的时候,可以仍然可以恢复
  169. dataset = DatasetWithVaryLength(num_of_data=num_samples)
  170. samplers = []
  171. num_consumed_samples_array = list(range(0, len(dataset)+num_replicas, num_replicas))
  172. for i in range(num_replicas):
  173. sampler = RandomSampler(dataset, shuffle=shuffle)
  174. sampler.set_epoch(0)
  175. sampler.set_distributed(num_replicas=num_replicas, rank=i, pad=pad)
  176. samplers.append(sampler)
  177. count = 0
  178. already_seen_sets = [set()]
  179. already_seen_set = set()
  180. for idxes in zip(*samplers):
  181. already_seen_set.update(idxes)
  182. already_seen_sets.append(deepcopy(already_seen_set))
  183. count += 1
  184. if count > 3:
  185. break
  186. states = samplers[0].state_dict()
  187. for i in range(len(already_seen_sets)):
  188. states['num_consumed_samples'] = num_consumed_samples_array[i]
  189. sampler = RandomSampler(dataset, shuffle=shuffle)
  190. already_seen_set = deepcopy(already_seen_sets[i])
  191. for batch in sampler:
  192. already_seen_set.add(batch)
  193. assert len(already_seen_set) == len(dataset)
  194. # 测试保存之后再次保存
  195. sampler = RandomSampler(dataset, shuffle=shuffle)
  196. sampler.set_epoch(0)
  197. if len(already_seen_sets)<3:
  198. return
  199. already_seen_set = already_seen_sets[2]
  200. count = 0
  201. num_consumed_samples_array = list(range(0, num_samples))
  202. for idx in sampler:
  203. already_seen_set.add(idx)
  204. count += 1
  205. if count > 6:
  206. break
  207. states = sampler.state_dict()
  208. states['num_consumed_samples'] = num_consumed_samples_array[count]
  209. sampler = RandomSampler(dataset, shuffle=shuffle)
  210. sampler.load_state_dict(states)
  211. sampler.set_epoch(0)
  212. for idx in sampler:
  213. already_seen_set.add(idx)
  214. assert len(already_seen_set)==len(dataset)
  215. class TestRandomSampler:
  216. # 测试单卡;
  217. def test_seed_work_when_shuffle_is_true(self):
  218. data_length = 100
  219. torch_normal_data = TorchNormalDataset(num_of_data=data_length)
  220. for shuffle in [True, False]:
  221. iterable = RandomSampler(dataset=torch_normal_data, shuffle=shuffle)
  222. # 迭代一些数据,但是不迭代完;
  223. iterable.set_epoch(1)
  224. iterator = iter(iterable)
  225. pre_data = []
  226. forward_steps = 30
  227. for _ in range(forward_steps):
  228. pre_data.append(next(iterator))
  229. # 看重新生成迭代器是否能够完全重置状态;
  230. iterator = iter(iterable)
  231. res = []
  232. for _ in range(forward_steps):
  233. res.append(next(iterator))
  234. assert pre_data == res
  235. # 测试断点重训;
  236. # 如果 shuffle,那么下一轮的数据应当与前一轮不一样;并且如果是断点重训,两次的下一轮应当是一样的;
  237. def test_2(self):
  238. data_length = 100
  239. torch_normal_data = TorchNormalDataset(num_of_data=data_length)
  240. random_sampler_1 = RandomSampler(dataset=torch_normal_data, shuffle=True)
  241. iterator = iter(random_sampler_1)
  242. # 第一轮
  243. random_sampler_1.set_epoch(0)
  244. first_epoch = []
  245. forward_steps = 30
  246. for _ in range(forward_steps):
  247. first_epoch.append(next(iterator))
  248. # 先提前保存断点重训的结果;
  249. state = random_sampler_1.state_dict()
  250. # 保存第一个 epoch 的之后的结果,用于查看断点重训是否正确;
  251. first_left_data = []
  252. while True:
  253. try:
  254. first_left_data.append(next(iterator))
  255. except StopIteration:
  256. break
  257. # 第二轮
  258. random_sampler_1.set_epoch(1)
  259. iterator = iter(random_sampler_1)
  260. second_epoch = []
  261. for _ in range(forward_steps):
  262. second_epoch.append(next(iterator))
  263. assert first_epoch != second_epoch
  264. # 重新加载第一轮的状态,查看断点重训是否正确;
  265. random_sampler_2 = RandomSampler(dataset=torch_normal_data, shuffle=True)
  266. random_sampler_2.load_state_dict(state)
  267. random_sampler_2.set_epoch(0)
  268. iterator = iter(random_sampler_2)
  269. re_first_epoch = []
  270. while True:
  271. try:
  272. re_first_epoch.append(next(iterator))
  273. except StopIteration:
  274. break
  275. assert re_first_epoch == first_left_data
  276. # 查看第二轮的结果是否也是和第一次的第二轮完全一致;
  277. random_sampler_2.set_epoch(1)
  278. iterator = iter(random_sampler_2)
  279. re_second_epoch = []
  280. for _ in range(forward_steps):
  281. re_second_epoch.append(next(iterator))
  282. assert re_second_epoch == second_epoch
  283. # 多卡;
  284. # 如果一个 sampler 还没有迭代完,我们又直接 iter(sampler) 那么是否正确(应当生成一个全新的 sampler)?
  285. def test_3(self):
  286. data_length = 100
  287. torch_normal_data = TorchNormalDataset(num_of_data=data_length)
  288. random_sampler_1 = partial(RandomSampler, dataset=torch_normal_data, shuffle=False)
  289. random_sampler_2 = partial(RandomSampler, dataset=torch_normal_data, shuffle=True)
  290. iterable_items = [random_sampler_1, random_sampler_2]
  291. world_size = 3
  292. for pad in {True, False}:
  293. for iterable in iterable_items:
  294. for rank in range(world_size):
  295. each_rank_iterable = iterable()
  296. each_rank_iterable.set_epoch(0)
  297. each_rank_iterable.set_distributed(num_replicas=world_size, rank=rank, pad=pad)
  298. # 迭代一些数据,但是不迭代完;
  299. iterator = iter(each_rank_iterable)
  300. pre_data = []
  301. forward_steps = 10
  302. for _ in range(forward_steps):
  303. pre_data.append(next(iterator))
  304. # 看重新生成迭代器是否能够完全重置状态;
  305. iterator = iter(each_rank_iterable)
  306. res = []
  307. for _ in range(forward_steps):
  308. res.append(next(iterator))
  309. assert res == pre_data
  310. # 测试断点重训;
  311. # 如果 shuffle,那么下一轮的数据应当与前一轮不一样;并且如果是断点重训,两次的下一轮应当是一样的;
  312. def test_4(self):
  313. data_length = 100
  314. torch_normal_data = TorchNormalDataset(num_of_data=data_length)
  315. random_sampler_1 = partial(RandomSampler, dataset=torch_normal_data, shuffle=True)
  316. world_size_1 = 2
  317. forward_steps = 10
  318. for pad in {True, False}:
  319. all_rank_state = {}
  320. all_rank_first_left_data = {}
  321. all_rank_second_epoch = {}
  322. for rank in range(world_size_1):
  323. each_rank_iterable = random_sampler_1()
  324. each_rank_iterable.set_distributed(num_replicas=world_size_1, rank=rank, pad=pad)
  325. iterator = iter(each_rank_iterable)
  326. # 第一轮
  327. each_rank_iterable.set_epoch(0)
  328. first_epoch = []
  329. for _ in range(forward_steps):
  330. first_epoch.append(next(iterator))
  331. # 先提前保存断点重训的结果;
  332. all_rank_state[rank] = each_rank_iterable.state_dict()
  333. # 保存第一个 epoch 的之后的结果,用于查看断点重训是否正确;
  334. first_left_data = []
  335. while True:
  336. try:
  337. first_left_data.append(next(iterator))
  338. except StopIteration:
  339. break
  340. all_rank_first_left_data[rank] = first_left_data
  341. # 第二轮
  342. each_rank_iterable.set_epoch(1)
  343. iterator = iter(each_rank_iterable)
  344. second_epoch = []
  345. for _ in range(forward_steps):
  346. second_epoch.append(next(iterator))
  347. all_rank_second_epoch[rank] = second_epoch
  348. assert first_epoch != second_epoch
  349. # 重新加载第一轮的状态,查看断点重训是否正确;
  350. random_sampler_2 = partial(RandomSampler, dataset=torch_normal_data, shuffle=True)
  351. for rank in range(world_size_1):
  352. each_rank_iterable = random_sampler_2()
  353. each_rank_iterable.set_distributed(num_replicas=world_size_1, rank=rank, pad=pad)
  354. each_rank_iterable.load_state_dict(all_rank_state[rank])
  355. each_rank_iterable.set_epoch(0)
  356. iterator = iter(each_rank_iterable)
  357. re_first_epoch = []
  358. while True:
  359. try:
  360. re_first_epoch.append(next(iterator))
  361. except StopIteration:
  362. break
  363. assert re_first_epoch == all_rank_first_left_data[rank]
  364. # 查看第二轮的结果是否也是和第一次的第二轮完全一致;
  365. each_rank_iterable.set_epoch(1)
  366. iterator = iter(each_rank_iterable)
  367. re_second_epoch = []
  368. for _ in range(forward_steps):
  369. re_second_epoch.append(next(iterator))
  370. assert re_second_epoch == all_rank_second_epoch[rank]
  371. # todo 测试 ddp 时 world_size 改变的断点重训;
  372. def test_5(self):
  373. ...
  374. class DatasetWithVaryLength:
  375. def __init__(self, num_of_data=100, reverse=False):
  376. self.data = np.arange(num_of_data)
  377. if reverse:
  378. self.data = self.data[::-1]
  379. def __getitem__(self, item):
  380. return self.data[item]
  381. def __len__(self):
  382. return len(self.data)
  383. class TestSortedSampler:
  384. def test_single(self):
  385. num_of_data = 100
  386. data = DatasetWithVaryLength(num_of_data)
  387. sampler = SortedSampler(data, length=data.data)
  388. indexes = list(sampler)
  389. assert indexes==list(range(num_of_data-1, -1, -1))
  390. @pytest.mark.parametrize('pad', [True, False])
  391. @pytest.mark.parametrize('num_replicas', [2, 3])
  392. @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
  393. def test_multi(self, pad, num_replicas, num_of_data):
  394. data = DatasetWithVaryLength(num_of_data=num_of_data)
  395. samplers = []
  396. for i in range(num_replicas):
  397. sampler = SortedSampler(dataset=data, length=data.data)
  398. sampler.set_distributed(num_replicas, rank=i, pad=pad)
  399. samplers.append(sampler)
  400. # 保证顺序是没乱的
  401. already_seen_index = set()
  402. for sampler in samplers:
  403. larger_count = 0 # 这里为 0 就可以,因为最后补充的index一定是比较大的数。
  404. prev_index = float('inf')
  405. cur_set = set()
  406. seen_in_other_rank = 0
  407. for index in sampler:
  408. seen_in_other_rank += int(index in already_seen_index) # 不同的卡不交叉
  409. cur_set.add(index)
  410. larger_count += int(index <= prev_index)
  411. prev_index = index
  412. assert larger_count+1 >= len(sampler) # 除了最后一个可能乱掉,其它都必须要保持这个顺序
  413. assert seen_in_other_rank <= 1 if pad else seen_in_other_rank == 0
  414. already_seen_index.update(cur_set)
  415. indexes = list(chain(*samplers))
  416. indexes = set(indexes)
  417. if pad:
  418. assert indexes == set(range(num_of_data))
  419. else:
  420. assert len(indexes) <= num_of_data
  421. @pytest.mark.parametrize('pad', [True, False])
  422. @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist())
  423. def test_state_dict(self, pad, num_consumed_samples):
  424. num_samples = 100
  425. dataset = DatasetWithVaryLength(num_of_data=num_samples)
  426. # 测试使用 前后shuffle不一致的load操作
  427. sampler = SortedSampler(dataset, length=dataset.data)
  428. sampler.set_epoch(0)
  429. already_numbers = set()
  430. if num_consumed_samples>0:
  431. for i, j in enumerate(sampler, start=1):
  432. if already_numbers:
  433. assert j<max(already_numbers)
  434. already_numbers.add(j)
  435. if i == num_consumed_samples:
  436. break
  437. assert len(already_numbers) == num_consumed_samples
  438. states = sampler.state_dict()
  439. new_sampler = SortedSampler(dataset, length=dataset.data)
  440. new_sampler.load_state_dict(states)
  441. new_sampler.set_epoch(0)
  442. for i in new_sampler:
  443. if already_numbers:
  444. assert i < max(already_numbers)
  445. assert i not in already_numbers
  446. # 测试切换成多卡也没有问题
  447. other_rank_number = set()
  448. for rank in range(3):
  449. new_sampler = SortedSampler(dataset, length=dataset.data)
  450. new_sampler.load_state_dict(states)
  451. new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
  452. new_sampler.set_epoch(0)
  453. count = 0
  454. seen = 0
  455. seen_in_other_rank = 0
  456. smaller = 0
  457. for i in new_sampler:
  458. if already_numbers:
  459. smaller += int(i >= max(already_numbers))
  460. seen_in_other_rank += int(i in other_rank_number)
  461. other_rank_number.add(i)
  462. seen += int(i in already_numbers)
  463. count += 1
  464. assert seen <= 1 if pad else seen == 0
  465. assert seen_in_other_rank<=1 # 因为pad可能重复
  466. assert smaller<=1 if pad else smaller==0
  467. @pytest.mark.parametrize('pad', [True, False])
  468. @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist())
  469. def test_state_dict_2(self, pad, num_consumed_samples):
  470. # 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡
  471. num_samples = 100
  472. dataset = DatasetWithVaryLength(num_of_data=num_samples)
  473. # 测试使用 前后shuffle不一致的load操作
  474. # lst = [30]
  475. already_numbers = set()
  476. sampler = SortedSampler(dataset, length=dataset.data)
  477. sampler.set_distributed(num_replicas=2, rank=0)
  478. sampler.set_epoch(0)
  479. if num_consumed_samples>0:
  480. for i, j in enumerate(sampler, start=1):
  481. if already_numbers:
  482. assert j<=max(already_numbers)
  483. already_numbers.add(j)
  484. if i == num_consumed_samples:
  485. break
  486. sampler = SortedSampler(dataset, length=dataset.data)
  487. sampler.set_epoch(0)
  488. sampler.set_distributed(num_replicas=2, rank=1)
  489. if num_consumed_samples>0:
  490. for i, j in enumerate(sampler, start=1):
  491. already_numbers.add(j)
  492. if i == num_consumed_samples:
  493. break
  494. assert len(already_numbers) == num_consumed_samples*2
  495. states = sampler.state_dict()
  496. new_sampler = SortedSampler(dataset, length=dataset.data)
  497. new_sampler.load_state_dict(states)
  498. new_sampler.set_epoch(0)
  499. for i in new_sampler:
  500. if already_numbers:
  501. assert i < max(already_numbers)
  502. assert i not in already_numbers
  503. # 测试切换成多卡也没有问题
  504. other_rank_number = set()
  505. for rank in range(3):
  506. new_sampler = SortedSampler(dataset, length=dataset.data)
  507. new_sampler.load_state_dict(states)
  508. new_sampler.set_epoch(0)
  509. new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
  510. count = 0
  511. seen = 0
  512. seen_in_other_rank = 0
  513. smaller = 0
  514. for i in new_sampler:
  515. if already_numbers:
  516. smaller += int(i>=max(already_numbers))
  517. seen_in_other_rank += int(i in other_rank_number)
  518. other_rank_number.add(i)
  519. seen += int(i in already_numbers)
  520. count += 1
  521. assert seen <= 1 if pad else seen == 0
  522. assert seen_in_other_rank<=1 # 因为pad可能重复
  523. assert smaller <= 1 if pad else smaller == 0
  524. class TestSequentialSampler:
  525. def test_single(self):
  526. num_of_data = 100
  527. data = DatasetWithVaryLength(num_of_data)
  528. sampler = SequentialSampler(data)
  529. indexes = list(sampler)
  530. assert indexes==list(range(num_of_data))
  531. @pytest.mark.parametrize('pad', [True, False])
  532. @pytest.mark.parametrize('num_replicas', [2, 3])
  533. @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100])
  534. def test_multi(self, pad, num_replicas, num_of_data):
  535. data = DatasetWithVaryLength(num_of_data=num_of_data)
  536. samplers = []
  537. for i in range(num_replicas):
  538. sampler = SequentialSampler(dataset=data)
  539. sampler.set_distributed(num_replicas, rank=i, pad=pad)
  540. samplers.append(sampler)
  541. # 保证顺序是没乱的
  542. already_seen_index = set()
  543. for idx, sampler in enumerate(samplers):
  544. larger_count = 1
  545. prev_index = float('inf')
  546. cur_set = set()
  547. seen_in_other_rank = 0
  548. for index in sampler:
  549. seen_in_other_rank += int(index in already_seen_index) # 不同的卡不交叉
  550. cur_set.add(index)
  551. larger_count += int(index >= prev_index)
  552. prev_index = index
  553. assert larger_count+1 >= len(sampler) # 除了最后一个可能乱掉,其它都必须要保持这个顺序
  554. assert seen_in_other_rank <= idx if pad else seen_in_other_rank == 0
  555. already_seen_index.update(cur_set)
  556. indexes = list(chain(*samplers))
  557. indexes = set(indexes)
  558. if pad:
  559. assert indexes == set(range(num_of_data))
  560. else:
  561. assert len(indexes) <= num_of_data
  562. @pytest.mark.parametrize('pad', [True, False])
  563. @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist())
  564. def test_state_dict(self, pad, num_consumed_samples):
  565. num_samples = 100
  566. dataset = DatasetWithVaryLength(num_of_data=num_samples)
  567. # 测试使用 前后shuffle不一致的load操作
  568. sampler = SequentialSampler(dataset=dataset)
  569. sampler.set_epoch(0)
  570. already_numbers = set()
  571. if num_consumed_samples>0:
  572. for i, j in enumerate(sampler, start=1):
  573. if already_numbers:
  574. assert j>max(already_numbers)
  575. already_numbers.add(j)
  576. if i == num_consumed_samples:
  577. break
  578. assert len(already_numbers) == num_consumed_samples
  579. states = sampler.state_dict()
  580. new_sampler = SequentialSampler(dataset=dataset)
  581. new_sampler.load_state_dict(states)
  582. new_sampler.set_epoch(0)
  583. for i in new_sampler:
  584. if already_numbers:
  585. assert i > max(already_numbers)
  586. assert i not in already_numbers
  587. # 测试切换成多卡也没有问题
  588. other_rank_number = set()
  589. for rank in range(3):
  590. new_sampler = SequentialSampler(dataset=dataset)
  591. new_sampler.load_state_dict(states)
  592. new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
  593. new_sampler.set_epoch(0)
  594. count = 0
  595. seen = 0
  596. seen_in_other_rank = 0
  597. smaller = 0
  598. for i in new_sampler:
  599. if already_numbers:
  600. smaller += int(i <= max(already_numbers))
  601. seen_in_other_rank += int(i in other_rank_number)
  602. other_rank_number.add(i)
  603. seen += int(i in already_numbers)
  604. count += 1
  605. assert seen <= 1 if pad else seen == 0
  606. assert seen_in_other_rank<=rank # 因为pad可能重复
  607. assert smaller<=1 if pad else smaller==0
  608. @pytest.mark.parametrize('pad', [True, False])
  609. @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist())
  610. def test_state_dict_2(self, pad, num_consumed_samples):
  611. # 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡
  612. num_samples = 100
  613. dataset = DatasetWithVaryLength(num_of_data=num_samples)
  614. # 测试使用 前后shuffle不一致的load操作
  615. # lst = [30]
  616. already_numbers = set()
  617. sampler = SequentialSampler(dataset=dataset)
  618. sampler.set_distributed(num_replicas=2, rank=0)
  619. sampler.set_epoch(0)
  620. if num_consumed_samples>0:
  621. for i, j in enumerate(sampler, start=1):
  622. if already_numbers:
  623. assert j>max(already_numbers)
  624. already_numbers.add(j)
  625. if i == num_consumed_samples:
  626. break
  627. sampler = SequentialSampler(dataset=dataset)
  628. sampler.set_epoch(0)
  629. sampler.set_distributed(num_replicas=2, rank=1)
  630. if num_consumed_samples>0:
  631. for i, j in enumerate(sampler, start=1):
  632. already_numbers.add(j)
  633. if i == num_consumed_samples:
  634. break
  635. assert len(already_numbers) == num_consumed_samples*2
  636. states = sampler.state_dict()
  637. new_sampler = SequentialSampler(dataset=dataset)
  638. new_sampler.load_state_dict(states)
  639. new_sampler.set_epoch(0)
  640. for i in new_sampler:
  641. if already_numbers:
  642. assert i > max(already_numbers)
  643. assert i not in already_numbers
  644. # 测试切换成多卡也没有问题
  645. other_rank_number = set()
  646. for rank in range(3):
  647. new_sampler = SequentialSampler(dataset=dataset)
  648. new_sampler.load_state_dict(states)
  649. new_sampler.set_epoch(0)
  650. new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad)
  651. count = 0
  652. seen = 0
  653. seen_in_other_rank = 0
  654. smaller = 0
  655. for i in new_sampler:
  656. if already_numbers:
  657. smaller += int(i<max(already_numbers))
  658. seen_in_other_rank += int(i in other_rank_number)
  659. other_rank_number.add(i)
  660. seen += int(i in already_numbers)
  661. count += 1
  662. assert seen <= 1 if pad else seen == 0
  663. assert seen_in_other_rank<=1 # 因为pad可能重复
  664. assert smaller <= rank if pad else smaller == 0