You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_processor.py 4.3 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101
  1. import random
  2. import unittest
  3. import numpy as np
  4. from fastNLP import Vocabulary, Instance
  5. from fastNLP.api.processor import FullSpaceToHalfSpaceProcessor, PreAppendProcessor, SliceProcessor, Num2TagProcessor, \
  6. IndexerProcessor, VocabProcessor, SeqLenProcessor, ModelProcessor, Index2WordProcessor, SetTargetProcessor, \
  7. SetInputProcessor, VocabIndexerProcessor
  8. from fastNLP.core.dataset import DataSet
  class TestProcessor(unittest.TestCase):
      """Tests for the DataSet processors in ``fastNLP.api.processor``.

      Each test builds a small ``DataSet``, applies one processor to it and
      checks the field(s) that processor is expected to write.
      """

      def setUp(self):
          # Seed both RNGs used to build fixtures so every run sees the same
          # data and a failing case can be reproduced.
          random.seed(0)
          np.random.seed(0)

      def test_FullSpaceToHalfSpaceProcessor(self):
          # The input uses full-width digits (U+FF10 '0', U+FF11 '1', U+FF12 '2'),
          # written as escapes so the fixture survives re-encoding of this file.
          raw = "\uff10\uff10, u\uff11, u), (u\uff12, u\uff12"
          expected = "00, u1, u), (u2, u2"
          # Guard: if input and expectation were identical, the test could not
          # distinguish a working processor from a no-op.
          self.assertNotEqual(raw, expected)
          ds = DataSet({"word": [raw]})
          proc = FullSpaceToHalfSpaceProcessor("word")
          ds = proc(ds)
          self.assertEqual(ds.field_arrays["word"].content, [expected])

      def test_PreAppendProcessor(self):
          ds = DataSet({"word": [["1234", "3456"], ["8789", "3464"]]})
          proc = PreAppendProcessor(data="abc", field_name="word")
          ds = proc(ds)
          self.assertEqual(ds.field_arrays["word"].content,
                           [["abc", "1234", "3456"], ["abc", "8789", "3464"]])

      def test_SliceProcessor(self):
          # Build independent rows; ``[row] * 40`` would alias one list 40 times.
          ds = DataSet({"xx": [[random.randint(0, 10) for _ in range(30)] for _ in range(40)]})
          proc = SliceProcessor(10, 20, 2, "xx", new_added_field_name="yy")
          ds = proc(ds)
          # Assuming the processor applies slice(10, 20, 2), each 30-item row
          # keeps 5 elements.
          for row in ds.field_arrays["yy"].content:
              self.assertEqual(len(row), 5)

      def test_Num2TagProcessor(self):
          ds = DataSet({"num": [["99.9982", "2134.0"], ["0.002", "234"]]})
          proc = Num2TagProcessor("<num>", "num")
          ds = proc(ds)
          for row in ds.field_arrays["num"].content:
              for token in row:
                  self.assertEqual(token, "<num>")

      def test_VocabProcessor_and_IndexerProcessor(self):
          ds = DataSet({"xx": [[str(random.randint(0, 10)) for _ in range(30)]
                               for _ in range(40)]})
          vocab_proc = VocabProcessor("xx")
          vocab_proc(ds)
          vocab = vocab_proc.vocab
          self.assertIsInstance(vocab, Vocabulary)
          self.assertGreater(len(vocab), 5)
          proc = IndexerProcessor(vocab, "xx", "yy")
          ds = proc(ds)
          # Every token in every row should have been mapped to an int index.
          for row in ds.field_arrays["yy"].content:
              for idx in row:
                  self.assertIsInstance(idx, int)

      def test_SeqLenProcessor(self):
          ds = DataSet({"xx": [[str(random.randint(0, 10)) for _ in range(30)]
                               for _ in range(10)]})
          proc = SeqLenProcessor("xx", "len")
          ds = proc(ds)
          for seq_len in ds.field_arrays["len"].content:
              self.assertEqual(seq_len, 30)

      def test_ModelProcessor(self):
          from fastNLP.models.cnn_text_classification import CNNText
          # NOTE(review): presumably (embed_num, embed_dim, num_classes); token
          # ids below are drawn from [0, 100) to stay inside the first size.
          model = CNNText(100, 100, 5)
          ins_list = []
          for _ in range(64):
              seq_len = np.random.randint(5, 30)
              words = [np.random.randint(0, 100) for _ in range(seq_len)]
              ins_list.append(Instance(word_seq=words, seq_lens=seq_len))
          data_set = DataSet(ins_list)
          data_set.set_input("word_seq", "seq_lens")
          proc = ModelProcessor(model)
          data_set = proc(data_set)
          self.assertIn("pred", data_set)

      def test_Index2WordProcessor(self):
          vocab = Vocabulary()
          vocab.add_word_lst(["a", "b", "c", "d", "e"])
          proc = Index2WordProcessor(vocab, "tag_id", "tag")
          # Indices are drawn from [0, 7). NOTE(review): this assumes the
          # vocabulary holds two special entries besides the 5 words; confirm.
          data_set = DataSet([Instance(tag_id=[np.random.randint(0, 7) for _ in range(32)])])
          data_set = proc(data_set)
          self.assertIn("tag", data_set)

      def test_SetTargetProcessor(self):
          proc = SetTargetProcessor("a", "b", "c")
          data_set = DataSet({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
          data_set = proc(data_set)
          for name in ("a", "b", "c"):
              with self.subTest(field=name):
                  self.assertTrue(data_set[name].is_target)

      def test_SetInputProcessor(self):
          proc = SetInputProcessor("a", "b", "c")
          data_set = DataSet({"a": [1, 2, 3], "b": [1, 2, 3], "c": [1, 2, 3]})
          data_set = proc(data_set)
          for name in ("a", "b", "c"):
              with self.subTest(field=name):
                  self.assertTrue(data_set[name].is_input)

      def test_VocabIndexerProcessor(self):
          proc = VocabIndexerProcessor("word_seq", "word_ids")
          data_set = DataSet([Instance(word_seq=["a", "b", "c", "d", "e"])])
          data_set = proc(data_set)
          self.assertIn("word_ids", data_set)