You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_classification_loader.py 2.8 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. import unittest
  2. import os
  3. from fastNLP.io import DataBundle
  4. from fastNLP.io.loader.classification import YelpFullLoader, YelpPolarityLoader, IMDBLoader, \
  5. SSTLoader, SST2Loader, ChnSentiCorpLoader, THUCNewsLoader, WeiboSenti100kLoader, \
  6. MRLoader, R8Loader, R52Loader, OhsumedLoader, NG20Loader
  7. @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  8. class TestDownload(unittest.TestCase):
  9. def test_download(self):
  10. for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader, ChnSentiCorpLoader]:
  11. loader().download()
  12. def test_load(self):
  13. for loader in [YelpFullLoader, YelpPolarityLoader, IMDBLoader, SST2Loader, SSTLoader, ChnSentiCorpLoader]:
  14. data_bundle = loader().load()
  15. print(data_bundle)
  16. class TestLoad(unittest.TestCase):
  17. def test_process_from_file(self):
  18. data_set_dict = {
  19. 'yelp.p': ('tests/data_for_tests/io/yelp_review_polarity', YelpPolarityLoader, (6, 6, 6), False),
  20. 'yelp.f': ('tests/data_for_tests/io/yelp_review_full', YelpFullLoader, (6, 6, 6), False),
  21. 'sst-2': ('tests/data_for_tests/io/SST-2', SST2Loader, (5, 5, 5), True),
  22. 'sst': ('tests/data_for_tests/io/SST', SSTLoader, (6, 6, 6), False),
  23. 'imdb': ('tests/data_for_tests/io/imdb', IMDBLoader, (6, 6, 6), False),
  24. 'ChnSentiCorp': ('tests/data_for_tests/io/ChnSentiCorp', ChnSentiCorpLoader, (6, 6, 6), False),
  25. 'THUCNews': ('tests/data_for_tests/io/THUCNews', THUCNewsLoader, (9, 9, 9), False),
  26. 'WeiboSenti100k': ('tests/data_for_tests/io/WeiboSenti100k', WeiboSenti100kLoader, (6, 7, 6), False),
  27. 'mr': ('tests/data_for_tests/io/mr', MRLoader, (6, 6, 6), False),
  28. 'R8': ('tests/data_for_tests/io/R8', R8Loader, (6, 6, 6), False),
  29. 'R52': ('tests/data_for_tests/io/R52', R52Loader, (6, 6, 6), False),
  30. 'ohsumed': ('tests/data_for_tests/io/R52', OhsumedLoader, (6, 6, 6), False),
  31. '20ng': ('tests/data_for_tests/io/R52', NG20Loader, (6, 6, 6), False),
  32. }
  33. for k, v in data_set_dict.items():
  34. path, loader, data_set, warns = v
  35. with self.subTest(path=path):
  36. if warns:
  37. with self.assertWarns(Warning):
  38. data_bundle = loader().load(path)
  39. else:
  40. data_bundle = loader().load(path)
  41. self.assertTrue(isinstance(data_bundle, DataBundle))
  42. self.assertEqual(len(data_set), data_bundle.num_dataset)
  43. for x, y in zip(data_set, data_bundle.iter_datasets()):
  44. name, dataset = y
  45. with self.subTest(split=name):
  46. self.assertEqual(x, len(dataset))