You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number; they can include dashes ('-') and can be up to 35 characters long.

test_matching.py 7.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145
  1. import pytest
  2. import os
  3. from fastNLP.io import DataBundle
  4. from fastNLP.io.pipe.matching import SNLIPipe, RTEPipe, QNLIPipe, QuoraPipe, MNLIPipe, \
  5. CNXNLIPipe, BQCorpusPipe, LCQMCPipe
  6. from fastNLP.io.pipe.matching import SNLIBertPipe, RTEBertPipe, QNLIBertPipe, QuoraBertPipe, MNLIBertPipe, \
  7. CNXNLIBertPipe, BQCorpusBertPipe, LCQMCBertPipe
  8. @pytest.mark.skipif('download' not in os.environ, reason="Skip download")
  9. class TestMatchingPipe:
  10. def test_process_from_file(self):
  11. for pipe in [SNLIPipe, RTEPipe, QNLIPipe, MNLIPipe]:
  12. print(pipe)
  13. data_bundle = pipe(tokenizer='raw').process_from_file()
  14. print(data_bundle)
  15. @pytest.mark.skipif('download' not in os.environ, reason="Skip download")
  16. class TestMatchingBertPipe:
  17. def test_process_from_file(self):
  18. for pipe in [SNLIBertPipe, RTEBertPipe, QNLIBertPipe, MNLIBertPipe]:
  19. print(pipe)
  20. data_bundle = pipe(tokenizer='raw').process_from_file()
  21. print(data_bundle)
  22. class TestRunMatchingPipe:
  23. def test_load(self):
  24. data_set_dict = {
  25. 'RTE': ('tests/data_for_tests/io/RTE', RTEPipe, RTEBertPipe, (5, 5, 5), (449, 2), True),
  26. 'SNLI': ('tests/data_for_tests/io/SNLI', SNLIPipe, SNLIBertPipe, (5, 5, 5), (110, 3), False),
  27. 'QNLI': ('tests/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True),
  28. 'MNLI': ('tests/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True),
  29. 'BQCorpus': ('tests/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False),
  30. 'XNLI': ('tests/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 6, 8), (39, 3), False),
  31. 'LCQMC': ('tests/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 5, 6), (36, 2), False),
  32. }
  33. for k, v in data_set_dict.items():
  34. path, pipe1, pipe2, data_set, vocab, warns = v
  35. if warns:
  36. data_bundle1 = pipe1(tokenizer='raw').process_from_file(path)
  37. data_bundle2 = pipe2(tokenizer='raw').process_from_file(path)
  38. else:
  39. data_bundle1 = pipe1(tokenizer='raw').process_from_file(path)
  40. data_bundle2 = pipe2(tokenizer='raw').process_from_file(path)
  41. assert(isinstance(data_bundle1, DataBundle))
  42. assert(len(data_set) == data_bundle1.num_dataset)
  43. print(k)
  44. print(data_bundle1)
  45. print(data_bundle2)
  46. for x, y in zip(data_set, data_bundle1.iter_datasets()):
  47. name, dataset = y
  48. assert(x == len(dataset))
  49. assert(len(data_set) == data_bundle2.num_dataset)
  50. for x, y in zip(data_set, data_bundle2.iter_datasets()):
  51. name, dataset = y
  52. assert(x == len(dataset))
  53. assert(len(vocab) == data_bundle1.num_vocab)
  54. for x, y in zip(vocab, data_bundle1.iter_vocabs()):
  55. name, vocabs = y
  56. assert(x == len(vocabs))
  57. assert(len(vocab) == data_bundle2.num_vocab)
  58. for x, y in zip(vocab, data_bundle1.iter_vocabs()):
  59. name, vocabs = y
  60. assert(x + 1 if name == 'words' else x == len(vocabs))
  61. def test_load_proc(self):
  62. data_set_dict = {
  63. 'RTE': ('tests/data_for_tests/io/RTE', RTEPipe, RTEBertPipe, (5, 5, 5), (449, 2), True),
  64. 'SNLI': ('tests/data_for_tests/io/SNLI', SNLIPipe, SNLIBertPipe, (5, 5, 5), (110, 3), False),
  65. 'QNLI': ('tests/data_for_tests/io/QNLI', QNLIPipe, QNLIBertPipe, (5, 5, 5), (372, 2), True),
  66. 'MNLI': ('tests/data_for_tests/io/MNLI', MNLIPipe, MNLIBertPipe, (5, 5, 5, 5, 6), (459, 3), True),
  67. 'BQCorpus': ('tests/data_for_tests/io/BQCorpus', BQCorpusPipe, BQCorpusBertPipe, (5, 5, 5), (32, 2), False),
  68. 'XNLI': ('tests/data_for_tests/io/XNLI', CNXNLIPipe, CNXNLIBertPipe, (6, 6, 8), (39, 3), False),
  69. 'LCQMC': ('tests/data_for_tests/io/LCQMC', LCQMCPipe, LCQMCBertPipe, (6, 5, 6), (36, 2), False),
  70. }
  71. for k, v in data_set_dict.items():
  72. path, pipe1, pipe2, data_set, vocab, warns = v
  73. if warns:
  74. data_bundle1 = pipe1(tokenizer='raw', num_proc=2).process_from_file(path)
  75. data_bundle2 = pipe2(tokenizer='raw', num_proc=2).process_from_file(path)
  76. else:
  77. data_bundle1 = pipe1(tokenizer='raw', num_proc=2).process_from_file(path)
  78. data_bundle2 = pipe2(tokenizer='raw', num_proc=2).process_from_file(path)
  79. assert (isinstance(data_bundle1, DataBundle))
  80. assert (len(data_set) == data_bundle1.num_dataset)
  81. print(k)
  82. print(data_bundle1)
  83. print(data_bundle2)
  84. for x, y in zip(data_set, data_bundle1.iter_datasets()):
  85. name, dataset = y
  86. assert (x == len(dataset))
  87. assert (len(data_set) == data_bundle2.num_dataset)
  88. for x, y in zip(data_set, data_bundle2.iter_datasets()):
  89. name, dataset = y
  90. assert (x == len(dataset))
  91. assert (len(vocab) == data_bundle1.num_vocab)
  92. for x, y in zip(vocab, data_bundle1.iter_vocabs()):
  93. name, vocabs = y
  94. assert (x == len(vocabs))
  95. assert (len(vocab) == data_bundle2.num_vocab)
  96. for x, y in zip(vocab, data_bundle1.iter_vocabs()):
  97. name, vocabs = y
  98. assert (x + 1 if name == 'words' else x == len(vocabs))
  99. @pytest.mark.skipif('download' not in os.environ, reason="Skip download")
  100. def test_spacy(self):
  101. data_set_dict = {
  102. 'Quora': ('tests/data_for_tests/io/Quora', QuoraPipe, QuoraBertPipe, (2, 2, 2), (93, 2)),
  103. }
  104. for k, v in data_set_dict.items():
  105. path, pipe1, pipe2, data_set, vocab = v
  106. data_bundle1 = pipe1(tokenizer='spacy').process_from_file(path)
  107. data_bundle2 = pipe2(tokenizer='spacy').process_from_file(path)
  108. assert(isinstance(data_bundle1, DataBundle))
  109. assert(len(data_set) == data_bundle1.num_dataset)
  110. print(k)
  111. print(data_bundle1)
  112. print(data_bundle2)
  113. for x, y in zip(data_set, data_bundle1.iter_datasets()):
  114. name, dataset = y
  115. assert(x == len(dataset))
  116. assert(len(data_set) == data_bundle2.num_dataset)
  117. for x, y in zip(data_set, data_bundle2.iter_datasets()):
  118. name, dataset = y
  119. assert(x == len(dataset))
  120. assert(len(vocab) == data_bundle1.num_vocab)
  121. for x, y in zip(vocab, data_bundle1.iter_vocabs()):
  122. name, vocabs = y
  123. assert(x == len(vocabs))
  124. assert(len(vocab) == data_bundle2.num_vocab)
  125. for x, y in zip(vocab, data_bundle1.iter_vocabs()):
  126. name, vocabs = y
  127. assert(x + 1 if name == 'words' else x == len(vocabs))