You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_dataset.py 7.9 kB

6 years ago
6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200
import os
import tempfile
import unittest

from fastNLP.core.dataset import DataSet
from fastNLP.core.fieldarray import FieldArray
from fastNLP.core.instance import Instance


  6. class TestDataSet(unittest.TestCase):
  7. def test_init_v1(self):
  8. ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40)
  9. self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays)
  10. self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40)
  11. self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40)
  12. def test_init_v2(self):
  13. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  14. self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays)
  15. self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40)
  16. self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40)
  17. def test_init_assert(self):
  18. with self.assertRaises(AssertionError):
  19. _ = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 100})
  20. with self.assertRaises(AssertionError):
  21. _ = DataSet([[1, 2, 3, 4]] * 10)
  22. with self.assertRaises(ValueError):
  23. _ = DataSet(0.00001)
  24. def test_append(self):
  25. dd = DataSet()
  26. for _ in range(3):
  27. dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6]))
  28. self.assertEqual(len(dd), 3)
  29. self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3)
  30. self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3)
  31. def test_add_append(self):
  32. dd = DataSet()
  33. dd.add_field("x", [[1, 2, 3]] * 10)
  34. dd.add_field("y", [[1, 2, 3, 4]] * 10)
  35. dd.add_field("z", [[5, 6]] * 10)
  36. self.assertEqual(len(dd), 10)
  37. self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3]] * 10)
  38. self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10)
  39. self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10)
  40. with self.assertRaises(RuntimeError):
  41. dd.add_field("??", [[1, 2]] * 40)
  42. def test_delete_field(self):
  43. dd = DataSet()
  44. dd.add_field("x", [[1, 2, 3]] * 10)
  45. dd.add_field("y", [[1, 2, 3, 4]] * 10)
  46. dd.delete_field("x")
  47. self.assertFalse("x" in dd.field_arrays)
  48. self.assertTrue("y" in dd.field_arrays)
  49. def test_getitem(self):
  50. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  51. ins_1, ins_0 = ds[0], ds[1]
  52. self.assertTrue(isinstance(ins_1, Instance) and isinstance(ins_0, Instance))
  53. self.assertEqual(ins_1["x"], [1, 2, 3, 4])
  54. self.assertEqual(ins_1["y"], [5, 6])
  55. self.assertEqual(ins_0["x"], [1, 2, 3, 4])
  56. self.assertEqual(ins_0["y"], [5, 6])
  57. sub_ds = ds[:10]
  58. self.assertTrue(isinstance(sub_ds, DataSet))
  59. self.assertEqual(len(sub_ds), 10)
  60. def test_get_item_error(self):
  61. with self.assertRaises(RuntimeError):
  62. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  63. _ = ds[40:]
  64. with self.assertRaises(KeyError):
  65. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  66. _ = ds["kom"]
  67. def test_len_(self):
  68. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  69. self.assertEqual(len(ds), 40)
  70. ds = DataSet()
  71. self.assertEqual(len(ds), 0)
  72. def test_apply(self):
  73. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  74. ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx")
  75. self.assertTrue("rx" in ds.field_arrays)
  76. self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1])
  77. ds.apply(lambda ins: len(ins["y"]), new_field_name="y")
  78. self.assertEqual(ds.field_arrays["y"].content[0], 2)
  79. res = ds.apply(lambda ins: len(ins["x"]))
  80. self.assertTrue(isinstance(res, list) and len(res) > 0)
  81. self.assertTrue(res[0], 4)
  82. def test_drop(self):
  83. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20})
  84. ds.drop(lambda ins: len(ins["y"]) < 3)
  85. self.assertEqual(len(ds), 20)
  86. def test_contains(self):
  87. ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
  88. self.assertTrue("x" in ds)
  89. self.assertTrue("y" in ds)
  90. self.assertFalse("z" in ds)
  91. def test_rename_field(self):
  92. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  93. ds.rename_field("x", "xx")
  94. self.assertTrue("xx" in ds)
  95. self.assertFalse("x" in ds)
  96. with self.assertRaises(KeyError):
  97. ds.rename_field("yyy", "oo")
  98. def test_input_target(self):
  99. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  100. ds.set_input("x")
  101. ds.set_target("y")
  102. self.assertTrue(ds.field_arrays["x"].is_input)
  103. self.assertTrue(ds.field_arrays["y"].is_target)
  104. with self.assertRaises(KeyError):
  105. ds.set_input("xxx")
  106. with self.assertRaises(KeyError):
  107. ds.set_input("yyy")
  108. def test_get_input_name(self):
  109. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  110. self.assertEqual(ds.get_input_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_input])
  111. def test_get_target_name(self):
  112. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  113. self.assertEqual(ds.get_target_name(), [_ for _ in ds.field_arrays if ds.field_arrays[_].is_target])
  114. def test_apply2(self):
  115. def split_sent(ins):
  116. return ins['raw_sentence'].split()
  117. dataset = DataSet.read_csv('test/data_for_tests/tutorial_sample_dataset.csv', headers=('raw_sentence', 'label'),
  118. sep='\t')
  119. dataset.drop(lambda x: len(x['raw_sentence'].split()) == 0)
  120. dataset.apply(split_sent, new_field_name='words', is_input=True)
  121. # print(dataset)
  122. def test_add_field(self):
  123. ds = DataSet({"x": [3, 4]})
  124. ds.add_field('y', [['hello', 'world'], ['this', 'is', 'a', 'test']], is_input=True, is_target=True)
  125. # ds.apply(lambda x:[x['x']]*3, is_input=True, is_target=True, new_field_name='y')
  126. print(ds)
  127. def test_save_load(self):
  128. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  129. ds.save("./my_ds.pkl")
  130. self.assertTrue(os.path.exists("./my_ds.pkl"))
  131. ds_1 = DataSet.load("./my_ds.pkl")
  132. os.remove("my_ds.pkl")
  133. def test_get_all_fields(self):
  134. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  135. ans = ds.get_all_fields()
  136. self.assertEqual(ans["x"].content, [[1, 2, 3, 4]] * 10)
  137. self.assertEqual(ans["y"].content, [[5, 6]] * 10)
  138. def test_get_field(self):
  139. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  140. ans = ds.get_field("x")
  141. self.assertTrue(isinstance(ans, FieldArray))
  142. self.assertEqual(ans.content, [[1, 2, 3, 4]] * 10)
  143. ans = ds.get_field("y")
  144. self.assertTrue(isinstance(ans, FieldArray))
  145. self.assertEqual(ans.content, [[5, 6]] * 10)
  146. def test_reader(self):
  147. # 跑通即可
  148. ds = DataSet().read_naive("test/data_for_tests/tutorial_sample_dataset.csv")
  149. self.assertTrue(isinstance(ds, DataSet))
  150. self.assertTrue(len(ds) > 0)
  151. ds = DataSet().read_rawdata("test/data_for_tests/people_daily_raw.txt")
  152. self.assertTrue(isinstance(ds, DataSet))
  153. self.assertTrue(len(ds) > 0)
  154. ds = DataSet().read_pos("test/data_for_tests/people.txt")
  155. self.assertTrue(isinstance(ds, DataSet))
  156. self.assertTrue(len(ds) > 0)
  157. class TestDataSetIter(unittest.TestCase):
  158. def test__repr__(self):
  159. ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10})
  160. for iter in ds:
  161. self.assertEqual(iter.__repr__(), "{'x': [1, 2, 3, 4] type=list,\n'y': [5, 6] type=list}")