
test_classification.py

import os

import pytest

from fastNLP.io import DataBundle
from fastNLP.io.pipe.classification import SSTPipe, SST2Pipe, IMDBPipe, YelpFullPipe, YelpPolarityPipe, \
    AGsNewsPipe, DBPediaPipe
from fastNLP.io.pipe.classification import ChnSentiCorpPipe, THUCNewsPipe, WeiboSenti100kPipe
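

# NOTE: the classes gated by the `download` skipif below call
# process_from_file() without a path, which fetches the full datasets;
# set a `download` environment variable to enable them.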
@pytest.mark.skipif('download' not in os.environ, reason="Skip download")
class TestClassificationPipe:

    def test_process_from_file(self):
        for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]:
            print(pipe)
            data_bundle = pipe(tokenizer='raw', num_proc=0).process_from_file()
            print(data_bundle)

    def test_process_from_file_proc(self, num_proc=2):
        for pipe in [YelpPolarityPipe, SST2Pipe, IMDBPipe, YelpFullPipe, SSTPipe]:
            print(pipe)
            data_bundle = pipe(tokenizer='raw', num_proc=num_proc).process_from_file()
            print(data_bundle)


class TestRunPipe:

    def test_load(self):
        for pipe in [IMDBPipe]:
            data_bundle = pipe(tokenizer='raw', num_proc=0).process_from_file('tests/data_for_tests/io/imdb')
            print(data_bundle)

    def test_load_proc(self):
        for pipe in [IMDBPipe]:
            data_bundle = pipe(tokenizer='raw', num_proc=2).process_from_file('tests/data_for_tests/io/imdb')
            print(data_bundle)


@pytest.mark.skipif('download' not in os.environ, reason="Skip download")
class TestCNClassificationPipe:

    def test_process_from_file(self):
        for pipe in [ChnSentiCorpPipe]:
            data_bundle = pipe(bigrams=True, trigrams=True).process_from_file()
            print(data_bundle)


# @pytest.mark.skipif('download' not in os.environ, reason="Skip download")
class TestRunClassificationPipe:

    def test_process_from_file(self):
        data_set_dict = {
            'yelp.p': ('tests/data_for_tests/io/yelp_review_polarity', YelpPolarityPipe,
                       {'train': 6, 'dev': 6, 'test': 6}, {'words': 1176, 'target': 2},
                       False),
            'yelp.f': ('tests/data_for_tests/io/yelp_review_full', YelpFullPipe,
                       {'train': 6, 'dev': 6, 'test': 6}, {'words': 1166, 'target': 5},
                       False),
            'sst-2': ('tests/data_for_tests/io/SST-2', SST2Pipe,
                      {'train': 5, 'dev': 5, 'test': 5}, {'words': 139, 'target': 2},
                      True),
            'sst': ('tests/data_for_tests/io/SST', SSTPipe,
                    {'train': 354, 'dev': 6, 'test': 6}, {'words': 232, 'target': 5},
                    False),
            'imdb': ('tests/data_for_tests/io/imdb', IMDBPipe,
                     {'train': 6, 'dev': 6, 'test': 6}, {'words': 1670, 'target': 2},
                     False),
            'ag': ('tests/data_for_tests/io/ag', AGsNewsPipe,
                   {'train': 4, 'test': 5}, {'words': 257, 'target': 4},
                   False),
            'dbpedia': ('tests/data_for_tests/io/dbpedia', DBPediaPipe,
                        {'train': 14, 'test': 5}, {'words': 496, 'target': 14},
                        False),
            'ChnSentiCorp': ('tests/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe,
                             {'train': 6, 'dev': 6, 'test': 6},
                             {'chars': 529, 'bigrams': 1296, 'trigrams': 1483, 'target': 2},
                             False),
            'Chn-THUCNews': ('tests/data_for_tests/io/THUCNews', THUCNewsPipe,
                             {'train': 9, 'dev': 9, 'test': 9}, {'chars': 1864, 'target': 9},
                             False),
            'Chn-WeiboSenti100k': ('tests/data_for_tests/io/WeiboSenti100k', WeiboSenti100kPipe,
                                   {'train': 6, 'dev': 6, 'test': 7}, {'chars': 452, 'target': 2},
                                   False),
        }
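        # Each entry above maps a dataset key to (fixture path, pipe class,
        # expected split sizes, expected vocab sizes, expect-a-warning flag);
        # the loop below unpacks and checks them in turn.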
        for k, v in data_set_dict.items():
            path, pipe, data_set, vocab, warns = v
            if 'Chn' not in k:
                if warns:
                    # Fixtures flagged with warns=True (here only SST-2) are
                    # expected to emit a warning during processing.
                    with pytest.warns(Warning):
                        data_bundle = pipe(tokenizer='raw', num_proc=0).process_from_file(path)
                else:
                    data_bundle = pipe(tokenizer='raw', num_proc=0).process_from_file(path)
            else:
                data_bundle = pipe(bigrams=True, trigrams=True).process_from_file(path)

            assert isinstance(data_bundle, DataBundle)
            assert len(data_set) == data_bundle.num_dataset
            for name, dataset in data_bundle.iter_datasets():
                assert name in data_set
                assert data_set[name] == len(dataset)

            assert len(vocab) == data_bundle.num_vocab
            for name, vocabs in data_bundle.iter_vocabs():
                assert name in vocab
                assert vocab[name] == len(vocabs)

    def test_process_from_file_proc(self):
        data_set_dict = {
            'yelp.p': ('tests/data_for_tests/io/yelp_review_polarity', YelpPolarityPipe,
                       {'train': 6, 'dev': 6, 'test': 6}, {'words': 1176, 'target': 2},
                       False),
            'yelp.f': ('tests/data_for_tests/io/yelp_review_full', YelpFullPipe,
                       {'train': 6, 'dev': 6, 'test': 6}, {'words': 1166, 'target': 5},
                       False),
            'sst-2': ('tests/data_for_tests/io/SST-2', SST2Pipe,
                      {'train': 5, 'dev': 5, 'test': 5}, {'words': 139, 'target': 2},
                      True),
            'sst': ('tests/data_for_tests/io/SST', SSTPipe,
                    {'train': 354, 'dev': 6, 'test': 6}, {'words': 232, 'target': 5},
                    False),
            'imdb': ('tests/data_for_tests/io/imdb', IMDBPipe,
                     {'train': 6, 'dev': 6, 'test': 6}, {'words': 1670, 'target': 2},
                     False),
            'ag': ('tests/data_for_tests/io/ag', AGsNewsPipe,
                   {'train': 4, 'test': 5}, {'words': 257, 'target': 4},
                   False),
            'dbpedia': ('tests/data_for_tests/io/dbpedia', DBPediaPipe,
                        {'train': 14, 'test': 5}, {'words': 496, 'target': 14},
                        False),
            'ChnSentiCorp': ('tests/data_for_tests/io/ChnSentiCorp', ChnSentiCorpPipe,
                             {'train': 6, 'dev': 6, 'test': 6},
                             {'chars': 529, 'bigrams': 1296, 'trigrams': 1483, 'target': 2},
                             False),
            'Chn-THUCNews': ('tests/data_for_tests/io/THUCNews', THUCNewsPipe,
                             {'train': 9, 'dev': 9, 'test': 9}, {'chars': 1864, 'target': 9},
                             False),
            'Chn-WeiboSenti100k': ('tests/data_for_tests/io/WeiboSenti100k', WeiboSenti100kPipe,
                                   {'train': 6, 'dev': 6, 'test': 7}, {'chars': 452, 'target': 2},
                                   False),
        }
        for k, v in data_set_dict.items():
            path, pipe, data_set, vocab, warns = v
            if 'Chn' not in k:
                if warns:
                    with pytest.warns(Warning):
                        data_bundle = pipe(tokenizer='raw', num_proc=2).process_from_file(path)
                else:
                    data_bundle = pipe(tokenizer='raw', num_proc=2).process_from_file(path)
            else:
                data_bundle = pipe(bigrams=True, trigrams=True, num_proc=2).process_from_file(path)

            assert isinstance(data_bundle, DataBundle)
            assert len(data_set) == data_bundle.num_dataset
            for name, dataset in data_bundle.iter_datasets():
                assert name in data_set
                assert data_set[name] == len(dataset)

            assert len(vocab) == data_bundle.num_vocab
            for name, vocabs in data_bundle.iter_vocabs():
                assert name in vocab
                assert vocab[name] == len(vocabs)
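

# A minimal standalone sketch for a quick manual smoke run outside pytest.
# It assumes the bundled IMDB fixture path used above; every call it makes
# (IMDBPipe, process_from_file, iter_datasets, iter_vocabs) also appears in
# the tests themselves.
if __name__ == '__main__':
    bundle = IMDBPipe(tokenizer='raw', num_proc=0).process_from_file('tests/data_for_tests/io/imdb')
    assert isinstance(bundle, DataBundle)
    for name, dataset in bundle.iter_datasets():  # e.g. 'train', 'dev', 'test'
        print(name, len(dataset))
    for name, vocab in bundle.iter_vocabs():  # e.g. 'words', 'target'
        print(name, len(vocab))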