You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_matching_loader.py 2.1 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. import unittest
  2. import os
  3. from fastNLP.io import DataBundle
  4. from fastNLP.io.loader.matching import RTELoader, QNLILoader, SNLILoader, QuoraLoader, MNLILoader, \
  5. BQCorpusLoader, CNXNLILoader, LCQMCLoader
  @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  class TestMatchingDownload(unittest.TestCase):
      """Exercise ``download()`` and argument-less ``load()`` on matching loaders.

      The whole case is skipped whenever the ``TRAVIS`` environment variable
      is present.
      """

      def test_download(self):
          """Call ``download()`` on four loaders; expect ``QuoraLoader().load()`` to raise."""
          loader_classes = (RTELoader, QNLILoader, SNLILoader, MNLILoader)
          for loader_cls in loader_classes:
              loader_obj = loader_cls()
              loader_obj.download()

          # Loading Quora without an explicit path is expected to fail
          # (presumably because it has no downloadable default; verify
          # against the loader implementation).
          with self.assertRaises(Exception):
              QuoraLoader().load()

      def test_load(self):
          """Call ``load()`` with no path on four loaders and print each result."""
          loader_classes = (RTELoader, QNLILoader, SNLILoader, MNLILoader)
          for loader_cls in loader_classes:
              bundle = loader_cls().load()
              print(bundle)
  class TestMatchingLoad(unittest.TestCase):
      """Load each matching dataset from local fixtures and check its shape."""

      def test_load(self):
          """Verify every loader returns a ``DataBundle`` with the expected sizes.

          Each table entry maps a dataset name to a tuple of
          ``(fixture path, loader class, expected dataset lengths, warns)``.
          When ``warns`` is true, ``load(path)`` must emit a ``Warning``.
          """
          data_set_dict = {
              'RTE': ('test/data_for_tests/io/RTE', RTELoader, (5, 5, 5), True),
              'SNLI': ('test/data_for_tests/io/SNLI', SNLILoader, (5, 5, 5), False),
              'QNLI': ('test/data_for_tests/io/QNLI', QNLILoader, (5, 5, 5), True),
              'MNLI': ('test/data_for_tests/io/MNLI', MNLILoader, (5, 5, 5, 5, 6), True),
              'Quora': ('test/data_for_tests/io/Quora', QuoraLoader, (2, 2, 2), False),
              'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusLoader, (5, 5, 5), False),
              'XNLI': ('test/data_for_tests/io/XNLI', CNXNLILoader, (6, 8, 6), False),
              'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCLoader, (5, 6, 6), False),
          }
          for name, (path, loader, expected_sizes, warns) in data_set_dict.items():
              # subTest labels any failure with the dataset name and lets the
              # remaining datasets run instead of stopping at the first failure.
              with self.subTest(dataset=name):
                  if warns:
                      with self.assertWarns(Warning):
                          data_bundle = loader().load(path)
                  else:
                      data_bundle = loader().load(path)

                  self.assertIsInstance(data_bundle, DataBundle)
                  self.assertEqual(len(expected_sizes), data_bundle.num_dataset)
                  # Expected sizes are paired positionally with the bundle's
                  # datasets, so this relies on iter_datasets() order.
                  for expected, (_, dataset) in zip(expected_sizes, data_bundle.iter_datasets()):
                      self.assertEqual(expected, len(dataset))