You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_prepare_dataloader.py 867 B

12345678910111213141516171819202122232425262728
import pytest

from fastNLP import prepare_dataloader
from fastNLP import DataSet
from fastNLP.io import DataBundle
  5. @pytest.mark.torch
  6. def test_torch():
  7. import torch
  8. ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
  9. dl = prepare_dataloader(ds, batch_size=2, shuffle=True)
  10. for batch in dl:
  11. assert isinstance(batch['x'], torch.Tensor)
  12. @pytest.mark.torch
  13. def test_torch_data_bundle():
  14. import torch
  15. ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10})
  16. dl = DataBundle()
  17. dl.set_dataset(dataset=ds, name='train')
  18. dl.set_dataset(dataset=ds, name='test')
  19. dls = prepare_dataloader(dl, batch_size=2, shuffle=True)
  20. for dl in dls.values():
  21. for batch in dl:
  22. assert isinstance(batch['x'], torch.Tensor)
  23. assert batch['x'].size(0) == 2