
test_matching.py

import unittest
import os

from fastNLP.io import DataBundle
from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, QNLIPipe, QuoraPipe, MNLIPipe, \
    CNXNLIPipe, BQCorpusPipe, LCQMCPipe
from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, QuoraBertPipe, MNLIBertPipe, \
    CNXNLIBertPipe, BQCorpusBertPipe, LCQMCBertPipe
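
# Smoke tests for the matching (sentence-pair) pipes. Each plain pipe
# tokenizes the two input sentences; the *BertPipe variants additionally
# build a concatenated BERT-style input, which is why the checks further
# down expect their 'words' vocabulary to hold one extra token (presumably
# the '[SEP]' separator).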

@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestMatchingPipe(unittest.TestCase):

    def test_process_from_file(self):
        # Processes the full datasets, hence the skip on CI.
        for pipe in [SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe]:
            with self.subTest(pipe=pipe):
                print(pipe)
                data_bundle = pipe(tokenizer='raw').process_from_file()
                print(data_bundle)

@unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
class TestMatchingBertPipe(unittest.TestCase):

    def test_process_from_file(self):
        for pipe in [SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe]:
            with self.subTest(pipe=pipe):
                print(pipe)
                data_bundle = pipe(tokenizer='raw').process_from_file()
                print(data_bundle)

class TestRunMatchingPipe(unittest.TestCase):

    def test_load(self):
        # Each entry: (fixture path, raw pipe, Bert pipe, expected size of
        # each dataset split, expected size of each vocabulary, whether
        # processing is expected to emit a warning).
        data_set_dict = {
            'RTE': ('tests/data_for_tests/io/RTE', RTEPipe, RTEBertPipe, (5, 5, 5), (449, 2), True),
            'SNLI': ('tests/data_for_tests/io/SNLI', SNLIPipe, SNLIBertPipe, (5, 5, 5), (110, 3), False),
            'QNLI': ('tests/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True),
            'MNLI': ('tests/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True),
            'BQCorpus': ('tests/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False),
            'XNLI': ('tests/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 6, 8), (39, 3), False),
            'LCQMC': ('tests/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 5, 6), (36, 2), False),
        }
        for k, v in data_set_dict.items():
            path, pipe1, pipe2, data_set, vocab, warns = v
            if warns:
                with self.assertWarns(Warning):
                    data_bundle1 = pipe1(tokenizer='raw').process_from_file(path)
                    data_bundle2 = pipe2(tokenizer='raw').process_from_file(path)
            else:
                data_bundle1 = pipe1(tokenizer='raw').process_from_file(path)
                data_bundle2 = pipe2(tokenizer='raw').process_from_file(path)

            self.assertTrue(isinstance(data_bundle1, DataBundle))
            self.assertEqual(len(data_set), data_bundle1.num_dataset)
            print(k)
            print(data_bundle1)
            print(data_bundle2)

            for x, y in zip(data_set, data_bundle1.iter_datasets()):
                name, dataset = y
                with self.subTest(path=path, split=name):
                    self.assertEqual(x, len(dataset))
            self.assertEqual(len(data_set), data_bundle2.num_dataset)
            for x, y in zip(data_set, data_bundle2.iter_datasets()):
                name, dataset = y
                self.assertEqual(x, len(dataset))

            self.assertEqual(len(vocab), data_bundle1.num_vocab)
            for x, y in zip(vocab, data_bundle1.iter_vocabs()):
                name, vocabs = y
                self.assertEqual(x, len(vocabs))
            self.assertEqual(len(vocab), data_bundle2.num_vocab)
            for x, y in zip(vocab, data_bundle2.iter_vocabs()):
                name, vocabs = y
                # the Bert pipe adds one extra token to the 'words' vocabulary
                self.assertEqual(x + 1 if name == 'words' else x, len(vocabs))

    @unittest.skipIf('TRAVIS' in os.environ, "Skip in travis")
    def test_spacy(self):
        data_set_dict = {
            'Quora': ('tests/data_for_tests/io/Quora', QuoraPipe, QuoraBertPipe, (2, 2, 2), (93, 2)),
        }
        for k, v in data_set_dict.items():
            path, pipe1, pipe2, data_set, vocab = v
            data_bundle1 = pipe1(tokenizer='spacy').process_from_file(path)
            data_bundle2 = pipe2(tokenizer='spacy').process_from_file(path)

            self.assertTrue(isinstance(data_bundle1, DataBundle))
            self.assertEqual(len(data_set), data_bundle1.num_dataset)
            print(k)
            print(data_bundle1)
            print(data_bundle2)

            for x, y in zip(data_set, data_bundle1.iter_datasets()):
                name, dataset = y
                self.assertEqual(x, len(dataset))
            self.assertEqual(len(data_set), data_bundle2.num_dataset)
            for x, y in zip(data_set, data_bundle2.iter_datasets()):
                name, dataset = y
                self.assertEqual(x, len(dataset))

            self.assertEqual(len(vocab), data_bundle1.num_vocab)
            for x, y in zip(vocab, data_bundle1.iter_vocabs()):
                name, vocabs = y
                self.assertEqual(x, len(vocabs))
            self.assertEqual(len(vocab), data_bundle2.num_vocab)
            for x, y in zip(vocab, data_bundle2.iter_vocabs()):
                name, vocabs = y
                self.assertEqual(x + 1 if name == 'words' else x, len(vocabs))
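
# For reference, a minimal sketch of driving one of these pipes outside of
# unittest, assuming fastNLP is installed and using the small RTE fixture
# path from test_load above. It exercises only the DataBundle accessors the
# tests themselves rely on.

if __name__ == '__main__':
    from fastNLP.io.pipe.matching import RTEPipe

    # Build and run the pipe on the fixture, as test_load does.
    data_bundle = RTEPipe(tokenizer='raw').process_from_file('tests/data_for_tests/io/RTE')

    print(data_bundle.num_dataset)  # number of splits, e.g. train/dev/test
    for name, dataset in data_bundle.iter_datasets():
        print(name, len(dataset))   # split name and number of instances
    for name, vocab in data_bundle.iter_vocabs():
        print(name, len(vocab))     # vocab name ('words'/'target') and size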