You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_matching.py 5.0 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  1. import unittest
  2. import os
  3. from fastNLP.io import DataBundle
  4. from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, QNLIPipe, QuoraPipe, MNLIPipe, \
  5. CNXNLIPipe, BQCorpusPipe, LCQMCPipe
  6. from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, QuoraBertPipe, MNLIBertPipe, \
  7. CNXNLIBertPipe, BQCorpusBertPipe, LCQMCBertPipe
  @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  class TestMatchingPipe(unittest.TestCase):
      """Smoke-test the word-level matching pipes via ``process_from_file()``.

      No path is passed, so each pipe presumably falls back to its default data
      location (NOTE(review): confirm; this may need network access or cached
      data). The whole class is skipped when ``TRAVIS`` is set in the environment.
      """

      def test_process_from_file(self):
          """Each pipe should process its data and return a ``DataBundle``."""
          for pipe in [SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe]:
              with self.subTest(pipe=pipe):
                  print(pipe)
                  data_bundle = pipe(tokenizer='raw').process_from_file()
                  print(data_bundle)
                  # Assert the result type so an unexpected return value fails the
                  # test instead of merely being printed.
                  self.assertIsInstance(data_bundle, DataBundle)
  @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
  class TestMatchingBertPipe(unittest.TestCase):
      """Smoke-test the BERT-style matching pipes via ``process_from_file()``.

      No path is passed, so each pipe presumably falls back to its default data
      location (NOTE(review): confirm; this may need network access or cached
      data). The whole class is skipped when ``TRAVIS`` is set in the environment.
      """

      def test_process_from_file(self):
          """Each BERT pipe should process its data and return a ``DataBundle``."""
          for pipe in [SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe]:
              with self.subTest(pipe=pipe):
                  print(pipe)
                  data_bundle = pipe(tokenizer='raw').process_from_file()
                  print(data_bundle)
                  # Assert the result type so an unexpected return value fails the
                  # test instead of merely being printed.
                  self.assertIsInstance(data_bundle, DataBundle)
  class TestRunMatchingPipe(unittest.TestCase):
      """Run matching pipes on the small fixtures under ``test/data_for_tests/io``.

      Every case processes one fixture directory with a word-level pipe and with
      its BERT counterpart. It then checks the number of datasets, the length of
      each dataset, and the size of each vocabulary in both resulting bundles.
      """

      @staticmethod
      def _load_bundles(path, pipe1, pipe2, tokenizer):
          """Process ``path`` with both pipes using ``tokenizer``; return both bundles."""
          data_bundle1 = pipe1(tokenizer=tokenizer).process_from_file(path)
          data_bundle2 = pipe2(tokenizer=tokenizer).process_from_file(path)
          return data_bundle1, data_bundle2

      def _assert_bundles(self, data_set, vocab, data_bundle1, data_bundle2):
          """Check both bundles against the expected dataset and vocabulary sizes.

          Args:
              data_set: expected length of each dataset, in ``iter_datasets()`` order
                  (the same for both bundles).
              vocab: expected size of each vocabulary, in ``iter_vocabs()`` order.
                  For ``data_bundle2`` the vocabulary named ``'words'`` is expected
                  to hold one more entry than the corresponding value.
              data_bundle1: bundle produced by the word-level pipe.
              data_bundle2: bundle produced by the BERT pipe.
          """
          for data_bundle in (data_bundle1, data_bundle2):
              self.assertIsInstance(data_bundle, DataBundle)
              self.assertEqual(len(data_set), data_bundle.num_dataset)
              self.assertEqual(list(data_set),
                               [len(dataset) for _, dataset in data_bundle.iter_datasets()])
              self.assertEqual(len(vocab), data_bundle.num_vocab)

          self.assertEqual(list(vocab),
                           [len(vocabs) for _, vocabs in data_bundle1.iter_vocabs()])
          # The BERT bundle's vocabularies must be checked against data_bundle2
          # itself; iterating data_bundle1 here would skip this check entirely.
          # NOTE(review): the +1 for 'words' presumably accounts for a separator
          # token the BERT pipe adds to the joined sentence pair -- confirm.
          for x, (name, vocabs) in zip(vocab, data_bundle2.iter_vocabs()):
              self.assertEqual(x + 1 if name == 'words' else x, len(vocabs))

      def test_load(self):
          """Process every fixture with ``tokenizer='raw'`` and verify the sizes.

          Each entry is ``(path, pipe, bert_pipe, dataset_sizes, vocab_sizes,
          warns)``. When ``warns`` is true, processing with the two pipes must emit
          at least one ``Warning``.
          """
          data_set_dict = {
              'RTE': ('test/data_for_tests/io/RTE', RTEPipe, RTEBertPipe, (5, 5, 5), (449, 2), True),
              'SNLI': ('test/data_for_tests/io/SNLI', SNLIPipe, SNLIBertPipe, (5, 5, 5), (110, 3), False),
              'QNLI': ('test/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True),
              'MNLI': ('test/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True),
              'BQCorpus': ('test/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False),
              'XNLI': ('test/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 8, 6), (39, 3), False),
              'LCQMC': ('test/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (5, 6, 6), (36, 2), False),
          }
          for k, (path, pipe1, pipe2, data_set, vocab, warns) in data_set_dict.items():
              # subTest keeps one failing fixture from hiding results of the others.
              with self.subTest(dataset=k):
                  if warns:
                      with self.assertWarns(Warning):
                          data_bundle1, data_bundle2 = self._load_bundles(path, pipe1, pipe2, 'raw')
                  else:
                      data_bundle1, data_bundle2 = self._load_bundles(path, pipe1, pipe2, 'raw')
                  print(k)
                  print(data_bundle1)
                  print(data_bundle2)
                  self._assert_bundles(data_set, vocab, data_bundle1, data_bundle2)

      def test_spacy(self):
          """Process fixtures with ``tokenizer='spacy'`` and verify the sizes.

          Each entry is ``(path, pipe, bert_pipe, dataset_sizes, vocab_sizes)``.
          """
          data_set_dict = {
              'Quora': ('test/data_for_tests/io/Quora', QuoraPipe, QuoraBertPipe, (2, 2, 2), (93, 2)),
          }
          for k, (path, pipe1, pipe2, data_set, vocab) in data_set_dict.items():
              with self.subTest(dataset=k):
                  data_bundle1, data_bundle2 = self._load_bundles(path, pipe1, pipe2, 'spacy')
                  print(k)
                  print(data_bundle1)
                  print(data_bundle2)
                  self._assert_bundles(data_set, vocab, data_bundle1, data_bundle2)