You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_mixdataloader.py 22 kB

Lines 1–495
  1. import pytest
  2. from typing import Mapping
  3. from fastNLP.core.dataloaders import MixDataLoader
  4. from fastNLP import DataSet
  5. from fastNLP.core.collators import Collator
  6. from fastNLP.envs.imports import _NEED_IMPORT_TORCH
  7. if _NEED_IMPORT_TORCH:
  8. import torch
  9. from torch.utils.data import default_collate, SequentialSampler, RandomSampler
  10. d1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
  11. d2 = DataSet({'x': [[101, 201], [201, 301, 401], [100]] * 10, 'y': [20, 10, 10] * 10})
  12. d3 = DataSet({'x': [[1000, 2000], [0], [2000, 3000, 4000, 5000]] * 100, 'y': [100, 100, 200] * 100})
  13. def test_pad_val(tensor, val=0):
  14. if isinstance(tensor, torch.Tensor):
  15. tensor = tensor.tolist()
  16. for item in tensor:
  17. if item[-1] > 0:
  18. continue
  19. elif item[-1] != val:
  20. return False
  21. return True
  22. class TestMixDataLoader:
  23. def test_sequential_init(self):
  24. datasets = {'d1': d1, 'd2': d2, 'd3': d3}
  25. # drop_last = True, collate_fn = 'auto
  26. dl = MixDataLoader(datasets=datasets, mode='sequential', collate_fn='auto', drop_last=True)
  27. for idx, batch in enumerate(dl):
  28. if idx == 0:
  29. # d1
  30. assert batch['x'].shape == torch.Size([16, 4])
  31. if idx == 1:
  32. # d2
  33. assert batch['x'].shape == torch.Size([16, 3])
  34. if idx > 1:
  35. # d3
  36. assert batch['x'].shape == torch.Size([16, 4])
  37. assert test_pad_val(batch['x'], val=0)
  38. # collate_fn = Callable
  39. def collate_batch(batch):
  40. new_batch = {'x': [], 'y': []}
  41. for ins in batch:
  42. new_batch['x'].append(ins['x'])
  43. new_batch['y'].append(ins['y'])
  44. return new_batch
  45. dl1 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_batch, drop_last=True)
  46. for idx, batch in enumerate(dl1):
  47. if idx == 0:
  48. # d1
  49. assert [1, 2] in batch['x']
  50. if idx == 1:
  51. # d2
  52. assert [101, 201] in batch['x']
  53. if idx > 1:
  54. # d3
  55. assert [1000, 2000] in batch['x']
  56. assert 'x' in batch and 'y' in batch
  57. collate_fns = {'d1': Collator(backend='auto').set_pad("x", -1),
  58. 'd2': Collator(backend='auto').set_pad("x", -2),
  59. 'd3': Collator(backend='auto').set_pad("x", -3)}
  60. dl2 = MixDataLoader(datasets=datasets, mode='sequential', collate_fn=collate_fns, drop_last=True)
  61. for idx, batch in enumerate(dl2):
  62. if idx == 0:
  63. assert test_pad_val(batch['x'], val=-1)
  64. assert batch['x'].shape == torch.Size([16, 4])
  65. if idx == 1:
  66. assert test_pad_val(batch['x'], val=-2)
  67. assert batch['x'].shape == torch.Size([16, 3])
  68. if idx > 1:
  69. assert test_pad_val(batch['x'], val=-3)
  70. assert batch['x'].shape == torch.Size([16, 4])
  71. # sampler 为 str
  72. dl3 = MixDataLoader(datasets=datasets, mode='sequential', sampler='seq', drop_last=True)
  73. dl4 = MixDataLoader(datasets=datasets, mode='sequential', sampler='rand', drop_last=True)
  74. for idx, batch in enumerate(dl3):
  75. if idx == 0:
  76. # d1
  77. assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
  78. assert batch['x'].shape == torch.Size([16, 4])
  79. if idx == 1:
  80. # d2
  81. assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
  82. assert batch['x'].shape == torch.Size([16, 3])
  83. if idx == 2:
  84. # d3
  85. assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
  86. if idx > 1:
  87. # d3
  88. assert batch['x'].shape == torch.Size([16, 4])
  89. assert test_pad_val(batch['x'], val=0)
  90. for idx, batch in enumerate(dl4):
  91. if idx == 0:
  92. # d1
  93. assert batch['x'][:3].tolist() != [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
  94. assert batch['x'].shape == torch.Size([16, 4])
  95. if idx == 1:
  96. # d2
  97. assert batch['x'][:3].tolist() != [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
  98. assert batch['x'].shape == torch.Size([16, 3])
  99. if idx == 2:
  100. # d3
  101. assert batch['x'][:3].tolist() != [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
  102. if idx > 1:
  103. # d3
  104. assert batch['x'].shape == torch.Size([16, 4])
  105. assert test_pad_val(batch['x'], val=0)
  106. # sampler 为 Dict
  107. samplers = {'d1': SequentialSampler(d1),
  108. 'd2': SequentialSampler(d2),
  109. 'd3': RandomSampler(d3)}
  110. dl5 = MixDataLoader(datasets=datasets, mode='sequential', sampler=samplers, drop_last=True)
  111. for idx, batch in enumerate(dl5):
  112. if idx == 0:
  113. # d1
  114. assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
  115. assert batch['x'].shape == torch.Size([16, 4])
  116. if idx == 1:
  117. # d2
  118. assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
  119. assert batch['x'].shape == torch.Size([16, 3])
  120. if idx > 1:
  121. # d3
  122. assert batch['x'].shape == torch.Size([16, 4])
  123. assert test_pad_val(batch['x'], val=0)
  124. # ds_ratio 为 'truncate_to_least'
  125. dl6 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='truncate_to_least', drop_last=True)
  126. for idx, batch in enumerate(dl6):
  127. if idx == 0:
  128. # d1
  129. assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
  130. assert batch['x'].shape == torch.Size([16, 4])
  131. if idx == 1:
  132. # d2
  133. assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
  134. assert batch['x'].shape == torch.Size([16, 3])
  135. if idx == 2:
  136. # d3
  137. assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
  138. assert batch['x'].shape == torch.Size([16, 4])
  139. assert test_pad_val(batch['x'], val=0)
  140. if idx > 2:
  141. raise ValueError(f"ds_ratio: 'truncate_to_least' error")
  142. # ds_ratio 为 'pad_to_most'
  143. dl7 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio='pad_to_most', drop_last=True)
  144. for idx, batch in enumerate(dl7):
  145. if idx < 18:
  146. # d1
  147. assert batch['x'].shape == torch.Size([16, 4])
  148. if 18 <= idx < 36:
  149. # d2
  150. assert batch['x'].shape == torch.Size([16, 3])
  151. if 36 <= idx < 54:
  152. # d3
  153. assert batch['x'].shape == torch.Size([16, 4])
  154. assert test_pad_val(batch['x'], val=0)
  155. if idx >= 54:
  156. raise ValueError(f"ds_ratio: 'pad_to_most' error")
  157. # ds_ratio 为 Dict[str, float]
  158. ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0}
  159. dl8 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio=ds_ratio, drop_last=True)
  160. for idx, batch in enumerate(dl8):
  161. if idx < 1:
  162. # d1
  163. assert batch['x'].shape == torch.Size([16, 4])
  164. if 1 <= idx < 4:
  165. # d2
  166. assert batch['x'].shape == torch.Size([16, 3])
  167. if 4 <= idx < 41:
  168. # d3
  169. assert batch['x'].shape == torch.Size([16, 4])
  170. assert test_pad_val(batch['x'], val=0)
  171. if idx >= 41:
  172. raise ValueError(f"ds_ratio: 'pad_to_most' error")
  173. ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0}
  174. dl9 = MixDataLoader(datasets=datasets, mode='sequential', ds_ratio=ds_ratio, drop_last=True)
  175. for idx, batch in enumerate(dl9):
  176. if idx < 1:
  177. # d2
  178. assert batch['x'].shape == torch.Size([16, 3])
  179. if 1 <= idx < 19:
  180. # d3
  181. assert batch['x'].shape == torch.Size([16, 4])
  182. assert test_pad_val(batch['x'], val=0)
  183. if idx >= 19:
  184. raise ValueError(f"ds_ratio: 'pad_to_most' error")
  185. def test_mix(self):
  186. datasets = {'d1': d1, 'd2': d2, 'd3': d3}
  187. dl = MixDataLoader(datasets=datasets, mode='mix', collate_fn='auto', drop_last=True)
  188. for idx, batch in enumerate(dl):
  189. assert test_pad_val(batch['x'], val=0)
  190. if idx >= 22:
  191. raise ValueError(f"out of range")
  192. # collate_fn = Callable
  193. def collate_batch(batch):
  194. new_batch = {'x': [], 'y': []}
  195. for ins in batch:
  196. new_batch['x'].append(ins['x'])
  197. new_batch['y'].append(ins['y'])
  198. return new_batch
  199. dl1 = MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_batch, drop_last=True)
  200. for idx, batch in enumerate(dl1):
  201. assert isinstance(batch['x'], list)
  202. assert test_pad_val(batch['x'], val=0)
  203. if idx >= 22:
  204. raise ValueError(f"out of range")
  205. collate_fns = {'d1': Collator(backend='auto').set_pad("x", -1),
  206. 'd2': Collator(backend='auto').set_pad("x", -2),
  207. 'd3': Collator(backend='auto').set_pad("x", -3)}
  208. with pytest.raises(ValueError):
  209. MixDataLoader(datasets=datasets, mode='mix', collate_fn=collate_fns)
  210. # sampler 为 str
  211. dl3 = MixDataLoader(datasets=datasets, mode='mix', sampler='seq', drop_last=True)
  212. for idx, batch in enumerate(dl3):
  213. assert test_pad_val(batch['x'], val=0)
  214. if idx >= 22:
  215. raise ValueError(f"out of range")
  216. dl4 = MixDataLoader(datasets=datasets, mode='mix', sampler='rand', drop_last=True)
  217. for idx, batch in enumerate(dl4):
  218. assert test_pad_val(batch['x'], val=0)
  219. if idx >= 22:
  220. raise ValueError(f"out of range")
  221. # sampler 为 Dict
  222. samplers = {'d1': SequentialSampler(d1),
  223. 'd2': SequentialSampler(d2),
  224. 'd3': RandomSampler(d3)}
  225. dl5 = MixDataLoader(datasets=datasets, mode='mix', sampler=samplers, drop_last=True)
  226. for idx, batch in enumerate(dl5):
  227. assert test_pad_val(batch['x'], val=0)
  228. if idx >= 22:
  229. raise ValueError(f"out of range")
  230. # ds_ratio 为 'truncate_to_least'
  231. dl6 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio='truncate_to_least')
  232. d1_len, d2_len, d3_len = 0, 0, 0
  233. for idx, batch in enumerate(dl6):
  234. for item in batch['y'].tolist():
  235. if item in [1, 0, 1]:
  236. d1_len += 1
  237. elif item in [20, 10, 10]:
  238. d2_len += 1
  239. elif item in [100, 100, 200]:
  240. d3_len += 1
  241. if idx >= 6:
  242. raise ValueError(f"ds_ratio 为 'truncate_to_least'出错了")
  243. assert d1_len == d2_len == d3_len == 30
  244. # ds_ratio 为 'pad_to_most'
  245. dl7 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio='pad_to_most')
  246. d1_len, d2_len, d3_len = 0, 0, 0
  247. for idx, batch in enumerate(dl7):
  248. for item in batch['y'].tolist():
  249. if item in [1, 0, 1]:
  250. d1_len += 1
  251. elif item in [20, 10, 10]:
  252. d2_len += 1
  253. elif item in [100, 100, 200]:
  254. d3_len += 1
  255. if idx >= 57:
  256. raise ValueError(f"ds_ratio 为 'pad_to_most'出错了")
  257. assert d1_len == d2_len == d3_len == 300
  258. # ds_ratio 为 Dict[str, float]
  259. ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0}
  260. dl8 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio=ds_ratio)
  261. d1_len, d2_len, d3_len = 0, 0, 0
  262. for idx, batch in enumerate(dl8):
  263. for item in batch['y'].tolist():
  264. if item in [1, 0, 1]:
  265. d1_len += 1
  266. elif item in [20, 10, 10]:
  267. d2_len += 1
  268. elif item in [100, 100, 200]:
  269. d3_len += 1
  270. if idx >= 44:
  271. raise ValueError(f"ds_ratio 为 'Dict'出错了")
  272. assert d1_len == 30
  273. assert d2_len == 60
  274. assert d3_len == 600
  275. ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0}
  276. dl9 = MixDataLoader(datasets=datasets, mode='mix', ds_ratio=ds_ratio)
  277. d1_len, d2_len, d3_len = 0, 0, 0
  278. for idx, batch in enumerate(dl9):
  279. for item in batch['y'].tolist():
  280. if item in [1, 0, 1]:
  281. d1_len += 1
  282. elif item in [20, 10, 10]:
  283. d2_len += 1
  284. elif item in [100, 100, 200]:
  285. d3_len += 1
  286. if idx >= 21:
  287. raise ValueError(f"ds_ratio 为 'Dict'出错了")
  288. def test_polling(self):
  289. datasets = {'d1': d1, 'd2': d2, 'd3': d3}
  290. dl = MixDataLoader(datasets=datasets, mode='polling', collate_fn='auto', batch_size=18)
  291. for idx, batch in enumerate(dl):
  292. if idx == 0 or idx == 3:
  293. assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
  294. assert batch['x'].shape[1] == 4
  295. elif idx == 1 or idx == 4:
  296. # d2
  297. assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
  298. assert batch['x'].shape[1] == 3
  299. elif idx == 2 or 4 < idx <= 20:
  300. assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
  301. assert batch['x'].shape[1] == 4
  302. if idx > 20:
  303. raise ValueError(f"out of range")
  304. test_pad_val(batch['x'], val=0)
  305. # collate_fn = Callable
  306. def collate_batch(batch):
  307. new_batch = {'x': [], 'y': []}
  308. for ins in batch:
  309. new_batch['x'].append(ins['x'])
  310. new_batch['y'].append(ins['y'])
  311. return new_batch
  312. dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_batch, batch_size=18)
  313. for idx, batch in enumerate(dl1):
  314. if idx == 0 or idx == 3:
  315. assert batch['x'][:3] == [[1, 2], [2, 3, 4], [4, 5, 6, 7]]
  316. elif idx == 1 or idx == 4:
  317. # d2
  318. assert batch['x'][:3] == [[101, 201], [201, 301, 401], [100]]
  319. elif idx == 2 or 4 < idx <= 20:
  320. assert batch['x'][:3] == [[1000, 2000], [0], [2000, 3000, 4000, 5000]]
  321. if idx > 20:
  322. raise ValueError(f"out of range")
  323. collate_fns = {'d1': Collator(backend='auto').set_pad("x", -1),
  324. 'd2': Collator(backend='auto').set_pad("x", -2),
  325. 'd3': Collator(backend='auto').set_pad("x", -3)}
  326. dl1 = MixDataLoader(datasets=datasets, mode='polling', collate_fn=collate_fns, batch_size=18)
  327. for idx, batch in enumerate(dl1):
  328. if idx == 0 or idx == 3:
  329. assert test_pad_val(batch['x'], val=-1)
  330. assert batch['x'][:3].tolist() == [[1, 2, -1, -1], [2, 3, 4, -1], [4, 5, 6, 7]]
  331. assert batch['x'].shape[1] == 4
  332. elif idx == 1 or idx == 4:
  333. # d2
  334. assert test_pad_val(batch['x'], val=-2)
  335. assert batch['x'][:3].tolist() == [[101, 201, -2], [201, 301, 401], [100, -2, -2]]
  336. assert batch['x'].shape[1] == 3
  337. elif idx == 2 or 4 < idx <= 20:
  338. assert test_pad_val(batch['x'], val=-3)
  339. assert batch['x'][:3].tolist() == [[1000, 2000, -3, -3], [0, -3, -3, -3], [2000, 3000, 4000, 5000]]
  340. assert batch['x'].shape[1] == 4
  341. if idx > 20:
  342. raise ValueError(f"out of range")
  343. # sampler 为 str
  344. dl2 = MixDataLoader(datasets=datasets, mode='polling', sampler='seq', batch_size=18)
  345. dl3 = MixDataLoader(datasets=datasets, mode='polling', sampler='rand', batch_size=18)
  346. for idx, batch in enumerate(dl2):
  347. if idx == 0 or idx == 3:
  348. assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
  349. assert batch['x'].shape[1] == 4
  350. elif idx == 1 or idx == 4:
  351. # d2
  352. assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
  353. assert batch['x'].shape[1] == 3
  354. elif idx == 2 or 4 < idx <= 20:
  355. assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
  356. assert batch['x'].shape[1] == 4
  357. if idx > 20:
  358. raise ValueError(f"out of range")
  359. test_pad_val(batch['x'], val=0)
  360. for idx, batch in enumerate(dl3):
  361. if idx == 0 or idx == 3:
  362. assert batch['x'].shape[1] == 4
  363. elif idx == 1 or idx == 4:
  364. # d2
  365. assert batch['x'].shape[1] == 3
  366. elif idx == 2 or 4 < idx <= 20:
  367. assert batch['x'].shape[1] == 4
  368. if idx > 20:
  369. raise ValueError(f"out of range")
  370. test_pad_val(batch['x'], val=0)
  371. # sampler 为 Dict
  372. samplers = {'d1': SequentialSampler(d1),
  373. 'd2': SequentialSampler(d2),
  374. 'd3': RandomSampler(d3)}
  375. dl4 = MixDataLoader(datasets=datasets, mode='polling', sampler=samplers, batch_size=18)
  376. for idx, batch in enumerate(dl4):
  377. if idx == 0 or idx == 3:
  378. assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
  379. assert batch['x'].shape[1] == 4
  380. elif idx == 1 or idx == 4:
  381. # d2
  382. assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
  383. assert batch['x'].shape[1] == 3
  384. elif idx == 2 or 4 < idx <= 20:
  385. assert batch['x'].shape[1] == 4
  386. if idx > 20:
  387. raise ValueError(f"out of range")
  388. test_pad_val(batch['x'], val=0)
  389. # ds_ratio 为 'truncate_to_least'
  390. dl5 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='truncate_to_least', batch_size=18)
  391. for idx, batch in enumerate(dl5):
  392. if idx == 0 or idx == 3:
  393. assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
  394. assert batch['x'].shape[1] == 4
  395. elif idx == 1 or idx == 4:
  396. # d2
  397. assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
  398. assert batch['x'].shape[1] == 3
  399. elif idx == 2 or idx == 5:
  400. assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
  401. assert batch['x'].shape[1] == 4
  402. if idx > 5:
  403. raise ValueError(f"out of range")
  404. test_pad_val(batch['x'], val=0)
  405. # ds_ratio 为 'pad_to_most'
  406. dl6 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio='pad_to_most', batch_size=18)
  407. for idx, batch in enumerate(dl6):
  408. if idx % 3 == 0:
  409. # d1
  410. assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
  411. assert batch['x'].shape[1] == 4
  412. if idx % 3 == 1:
  413. # d2
  414. assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
  415. assert batch['x'].shape[1] == 3
  416. if idx % 3 == 2:
  417. # d3
  418. assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
  419. assert batch['x'].shape[1] == 4
  420. if idx >= 51:
  421. raise ValueError(f"out of range")
  422. test_pad_val(batch['x'], val=0)
  423. # ds_ratio 为 Dict[str, float]
  424. ds_ratio = {'d1': 1.0, 'd2': 2.0, 'd3': 2.0}
  425. dl7 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18)
  426. for idx, batch in enumerate(dl7):
  427. if idx == 0 or idx == 3:
  428. assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
  429. assert batch['x'].shape[1] == 4
  430. elif idx == 1 or idx == 4 or idx == 6 or idx == 8:
  431. # d2
  432. assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
  433. assert batch['x'].shape[1] == 3
  434. elif idx == 2 or idx == 5 or idx == 7 or idx > 8:
  435. assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
  436. assert batch['x'].shape[1] == 4
  437. if idx > 39:
  438. raise ValueError(f"out of range")
  439. test_pad_val(batch['x'], val=0)
  440. ds_ratio = {'d1': 0.1, 'd2': 0.6, 'd3': 1.0}
  441. dl8 = MixDataLoader(datasets=datasets, mode='polling', ds_ratio=ds_ratio, batch_size=18)
  442. for idx, batch in enumerate(dl8):
  443. if idx == 0:
  444. assert batch['x'][:3].tolist() == [[1, 2, 0, 0], [2, 3, 4, 0], [4, 5, 6, 7]]
  445. assert batch['x'].shape[1] == 4
  446. elif idx == 1:
  447. # d2
  448. assert batch['x'][:3].tolist() == [[101, 201, 0], [201, 301, 401], [100, 0, 0]]
  449. assert batch['x'].shape[1] == 3
  450. elif idx > 1:
  451. assert batch['x'][:3].tolist() == [[1000, 2000, 0, 0], [0, 0, 0, 0], [2000, 3000, 4000, 5000]]
  452. assert batch['x'].shape[1] == 4
  453. if idx > 18:
  454. raise ValueError(f"out of range")
  455. test_pad_val(batch['x'], val=0)