
test_utils.py 2.5 kB

from functools import reduce

# TODO: _TruncatedDataLoader has been modified; remember to update these tests accordingly.
from fastNLP.core.controllers.utils.utils import _TruncatedDataLoader
from tests.helpers.datasets.normal_data import NormalIterator


class Test_WrapDataLoader:

    def test_normal_generator(self):
        # Wrapping a plain iterator should stop iteration after exactly `num_batches` steps.
        all_sanity_batches = [4, 20, 100]
        for sanity_batches in all_sanity_batches:
            data = NormalIterator(num_of_data=1000)
            wrapper = _TruncatedDataLoader(num_batches=sanity_batches)
            dataloader = iter(wrapper(dataloader=data))
            mark = 0
            while True:
                try:
                    _data = next(dataloader)
                except StopIteration:
                    break
                mark += 1
            assert mark == sanity_batches

    def test_torch_dataloader(self):
        # Wrapping a torch DataLoader should yield exactly `sanity_batches` batches,
        # i.e. bs * sanity_batches samples in total.
        from tests.helpers.datasets.torch_data import TorchNormalDataset
        from torch.utils.data import DataLoader

        bses = [8, 16, 40]
        all_sanity_batches = [4, 7, 10]
        for bs in bses:
            for sanity_batches in all_sanity_batches:
                dataset = TorchNormalDataset(num_of_data=1000)
                dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)
                wrapper = _TruncatedDataLoader(num_batches=sanity_batches)
                dataloader = iter(wrapper(dataloader))
                all_supposed_running_data_num = 0
                while True:
                    try:
                        _data = next(dataloader)
                    except StopIteration:
                        break
                    all_supposed_running_data_num += _data.shape[0]
                assert all_supposed_running_data_num == bs * sanity_batches

    def test_len(self):
        # len() of the wrapped dataloader should equal the requested number of batches.
        from tests.helpers.datasets.torch_data import TorchNormalDataset
        from torch.utils.data import DataLoader

        bses = [8, 16, 40]
        all_sanity_batches = [4, 7, 10]
        length = []
        for bs in bses:
            for sanity_batches in all_sanity_batches:
                dataset = TorchNormalDataset(num_of_data=1000)
                dataloader = DataLoader(dataset, batch_size=bs, shuffle=True)
                wrapper = _TruncatedDataLoader(num_batches=sanity_batches)
                dataloader = wrapper(dataloader)
                length.append(len(dataloader))
        # Expected: [4, 7, 10] repeated once per batch size, concatenated in order.
        assert length == reduce(lambda x, y: x + y, [all_sanity_batches for _ in range(len(bses))])
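
These tests pin down the wrapper's contract without showing its implementation. For reference, the following is a minimal sketch of a class that would satisfy all three tests; it is inferred purely from the assertions above and is not the actual fastNLP `_TruncatedDataLoader` (which, per the TODO, has since been modified):

    class _TruncatedDataLoader:
        """Hypothetical sketch: wrap a dataloader so iteration stops after `num_batches` batches."""

        def __init__(self, num_batches):
            self.num_batches = num_batches
            self.dataloader = None

        def __call__(self, dataloader):
            # Bind the underlying dataloader/iterable and return self,
            # matching the `wrapper(dataloader=...)` call pattern in the tests.
            self.dataloader = dataloader
            return self

        def __iter__(self):
            self._iter = iter(self.dataloader)
            self._count = 0
            return self

        def __next__(self):
            # Raise StopIteration once the batch budget is exhausted.
            if self._count >= self.num_batches:
                raise StopIteration
            self._count += 1
            return next(self._iter)

        def __len__(self):
            return self.num_batches

A production implementation would likely also cap `__len__` at the underlying dataloader's length so truncation never over-reports; these tests only exercise the case where `num_batches` is smaller than the number of available batches.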