You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

test_fieldarray.py 8.1 kB


  1. import unittest
  2. import numpy as np
  3. from fastNLP.core.fieldarray import FieldArray
  4. class TestFieldArrayInit(unittest.TestCase):
  5. """
  6. 1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray:
  7. 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]})
  8. 1.2) 二维array DataSet({"x": np.array([[1, 2], [3, 4]])})
  9. 1.3) 三维list DataSet({"x": [[[1, 2], [3, 4]], [[1, 2], [3, 4]]]})
  10. 2) 如果DataSet使用list of Instance 初始化,那么在append中会先对第一个样本初始化FieldArray;
  11. 然后后面的样本使用FieldArray.append进行添加。
  12. 2.1) 一维list DataSet([Instance(x=[1, 2, 3, 4])])
  13. 2.2) 一维array DataSet([Instance(x=np.array([1, 2, 3, 4]))])
  14. 2.3) 二维list DataSet([Instance(x=[[1, 2], [3, 4]])])
  15. 2.4) 二维array DataSet([Instance(x=np.array([[1, 2], [3, 4]]))])
  16. """
  17. def test_init_v1(self):
  18. # 二维list
  19. fa = FieldArray("x", [[1, 2], [3, 4]] * 5, is_input=True)
  20. def test_init_v2(self):
  21. # 二维array
  22. fa = FieldArray("x", np.array([[1, 2], [3, 4]] * 5), is_input=True)
  23. def test_init_v3(self):
  24. # 三维list
  25. fa = FieldArray("x", [[[1, 2], [3, 4]], [[1, 2], [3, 4]]], is_input=True)
  26. def test_init_v7(self):
  27. # list of array
  28. fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])], is_input=True)
  29. self.assertEqual(fa.pytype, int)
  30. self.assertEqual(fa.dtype, np.int)
  31. def test_init_v4(self):
  32. # 一维list
  33. val = [1, 2, 3, 4]
  34. fa = FieldArray("x", [val], is_input=True)
  35. fa.append(val)
  36. def test_init_v5(self):
  37. # 一维array
  38. val = np.array([1, 2, 3, 4])
  39. fa = FieldArray("x", [val], is_input=True)
  40. fa.append(val)
  41. def test_init_v6(self):
  42. # 二维array
  43. val = [[1, 2], [3, 4]]
  44. fa = FieldArray("x", [val], is_input=True)
  45. fa.append(val)
  46. def test_init_v7(self):
  47. # 二维list
  48. val = np.array([[1, 2], [3, 4]])
  49. fa = FieldArray("x", [val], is_input=True)
  50. fa.append(val)
  51. class TestFieldArray(unittest.TestCase):
  52. def test_main(self):
  53. fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True)
  54. self.assertEqual(len(fa), 5)
  55. fa.append(6)
  56. self.assertEqual(len(fa), 6)
  57. self.assertEqual(fa[-1], 6)
  58. self.assertEqual(fa[0], 1)
  59. fa[-1] = 60
  60. self.assertEqual(fa[-1], 60)
  61. self.assertEqual(fa.get(0), 1)
  62. self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray))
  63. self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3])
  64. def test_type_conversion(self):
  65. fa = FieldArray("x", [1.2, 2.2, 3, 4, 5], is_input=True)
  66. self.assertEqual(fa.pytype, float)
  67. self.assertEqual(fa.dtype, np.float64)
  68. fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True)
  69. fa.append(1.3333)
  70. self.assertEqual(fa.pytype, float)
  71. self.assertEqual(fa.dtype, np.float64)
  72. fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True)
  73. fa.append(10)
  74. self.assertEqual(fa.pytype, float)
  75. self.assertEqual(fa.dtype, np.float64)
  76. fa = FieldArray("y", ["a", "b", "c", "d"], is_input=True)
  77. fa.append("e")
  78. self.assertEqual(fa.dtype, np.str)
  79. self.assertEqual(fa.pytype, str)
  80. def test_support_np_array(self):
  81. fa = FieldArray("y", np.array([[1.1, 2.2, 3.3, 4.4, 5.5]]), is_input=True)
  82. self.assertEqual(fa.dtype, np.float64)
  83. self.assertEqual(fa.pytype, float)
  84. fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5]))
  85. self.assertEqual(fa.dtype, np.float64)
  86. self.assertEqual(fa.pytype, float)
  87. fa = FieldArray("my_field", np.random.rand(3, 5), is_input=True)
  88. # in this case, pytype is actually a float. We do not care about it.
  89. self.assertEqual(fa.dtype, np.float64)
  90. def test_nested_list(self):
  91. fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=True)
  92. self.assertEqual(fa.pytype, float)
  93. self.assertEqual(fa.dtype, np.float64)
  94. def test_getitem_v1(self):
  95. fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
  96. self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5])
  97. ans = fa[[0, 1]]
  98. self.assertTrue(isinstance(ans, np.ndarray))
  99. self.assertTrue(isinstance(ans[0], np.ndarray))
  100. self.assertEqual(ans[0].tolist(), [1.1, 2.2, 3.3, 4.4, 5.5])
  101. self.assertEqual(ans[1].tolist(), [1, 2, 3, 4, 5])
  102. self.assertEqual(ans.dtype, np.float64)
  103. def test_getitem_v2(self):
  104. x = np.random.rand(10, 5)
  105. fa = FieldArray("my_field", x, is_input=True)
  106. indices = [0, 1, 3, 4, 6]
  107. for a, b in zip(fa[indices], x[indices]):
  108. self.assertListEqual(a.tolist(), b.tolist())
  109. def test_append(self):
  110. with self.assertRaises(Exception):
  111. fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
  112. fa.append(0)
  113. with self.assertRaises(Exception):
  114. fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True)
  115. fa.append([1, 2, 3, 4, 5])
  116. with self.assertRaises(Exception):
  117. fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
  118. fa.append([])
  119. with self.assertRaises(Exception):
  120. fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
  121. fa.append(["str", 0, 0, 0, 1.89])
  122. fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
  123. fa.append([1.2, 2.3, 3.4, 4.5, 5.6])
  124. self.assertEqual(len(fa), 3)
  125. self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6])
  126. class TestPadder(unittest.TestCase):
  127. def test01(self):
  128. """
  129. 测试AutoPadder能否正常工作
  130. :return:
  131. """
  132. from fastNLP.core.fieldarray import AutoPadder
  133. padder = AutoPadder()
  134. content = ['This is a str', 'this is another str']
  135. self.assertListEqual(content, padder(content, None, np.str).tolist())
  136. content = [1, 2]
  137. self.assertListEqual(content, padder(content, None, np.int64).tolist())
  138. content = [[1,2], [3], [4]]
  139. self.assertListEqual([[1,2], [3, 0], [4, 0]],
  140. padder(content, None, np.int64).tolist())
  141. content = [
  142. [[1, 2, 3], [4, 5], [7,8,9,10]],
  143. [[1]]
  144. ]
  145. self.assertListEqual(content,
  146. padder(content, None, np.int64).tolist())
  147. def test02(self):
  148. """
  149. 测试EngChar2DPadder能不能正确使用
  150. :return:
  151. """
  152. from fastNLP.core.fieldarray import EngChar2DPadder
  153. padder = EngChar2DPadder(pad_length=0)
  154. contents = [1, 2]
  155. # 不能是1维
  156. with self.assertRaises(ValueError):
  157. padder(contents, None, np.int64)
  158. contents = [[1, 2]]
  159. # 不能是2维
  160. with self.assertRaises(ValueError):
  161. padder(contents, None, np.int64)
  162. contents = [[[[1, 2]]]]
  163. # 不能是3维以上
  164. with self.assertRaises(ValueError):
  165. padder(contents, None, np.int64)
  166. contents = [
  167. [[1, 2, 3], [4, 5], [7,8,9,10]],
  168. [[1]]
  169. ]
  170. self.assertListEqual([[[1, 2, 3, 0], [4, 5, 0, 0], [7, 8, 9, 10]], [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]],
  171. padder(contents, None, np.int64).tolist())
  172. padder = EngChar2DPadder(pad_length=5, pad_val=-100)
  173. self.assertListEqual(
  174. [[[1, 2, 3, -100, -100], [4, 5, -100, -100, -100], [7, 8, 9, 10, -100]],
  175. [[1, -100, -100, -100, -100], [-100, -100, -100, -100, -100], [-100, -100, -100, -100, -100]]],
  176. padder(contents, None, np.int64).tolist()
  177. )