You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_fieldarray.py 5.9 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. import unittest
  2. import numpy as np
  3. from fastNLP.core.fieldarray import FieldArray
  4. class TestFieldArray(unittest.TestCase):
  5. def test(self):
  6. fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True)
  7. self.assertEqual(len(fa), 5)
  8. fa.append(6)
  9. self.assertEqual(len(fa), 6)
  10. self.assertEqual(fa[-1], 6)
  11. self.assertEqual(fa[0], 1)
  12. fa[-1] = 60
  13. self.assertEqual(fa[-1], 60)
  14. self.assertEqual(fa.get(0), 1)
  15. self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray))
  16. self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3])
  17. def test_type_conversion(self):
  18. fa = FieldArray("x", [1.2, 2.2, 3, 4, 5], is_input=True)
  19. self.assertEqual(fa.pytype, float)
  20. self.assertEqual(fa.dtype, np.float64)
  21. fa = FieldArray("x", [1, 2, 3, 4, 5], is_input=True)
  22. fa.append(1.3333)
  23. self.assertEqual(fa.pytype, float)
  24. self.assertEqual(fa.dtype, np.float64)
  25. fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True)
  26. fa.append(10)
  27. self.assertEqual(fa.pytype, float)
  28. self.assertEqual(fa.dtype, np.float64)
  29. fa = FieldArray("y", ["a", "b", "c", "d"], is_input=True)
  30. fa.append("e")
  31. self.assertEqual(fa.dtype, np.str)
  32. self.assertEqual(fa.pytype, str)
  33. def test_support_np_array(self):
  34. fa = FieldArray("y", [np.array([1.1, 2.2, 3.3, 4.4, 5.5])], is_input=True)
  35. self.assertEqual(fa.dtype, np.ndarray)
  36. self.assertEqual(fa.pytype, np.ndarray)
  37. fa.append(np.array([1.1, 2.2, 3.3, 4.4, 5.5]))
  38. self.assertEqual(fa.dtype, np.ndarray)
  39. self.assertEqual(fa.pytype, np.ndarray)
  40. fa = FieldArray("my_field", np.random.rand(3, 5), is_input=True)
  41. # in this case, pytype is actually a float. We do not care about it.
  42. self.assertEqual(fa.dtype, np.float64)
  43. def test_nested_list(self):
  44. fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.1, 2.2, 3.3, 4.4, 5.5]], is_input=True)
  45. self.assertEqual(fa.pytype, float)
  46. self.assertEqual(fa.dtype, np.float64)
  47. def test_getitem_v1(self):
  48. fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
  49. self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5])
  50. ans = fa[[0, 1]]
  51. self.assertTrue(isinstance(ans, np.ndarray))
  52. self.assertTrue(isinstance(ans[0], np.ndarray))
  53. self.assertEqual(ans[0].tolist(), [1.1, 2.2, 3.3, 4.4, 5.5])
  54. self.assertEqual(ans[1].tolist(), [1, 2, 3, 4, 5])
  55. self.assertEqual(ans.dtype, np.float64)
  56. def test_getitem_v2(self):
  57. x = np.random.rand(10, 5)
  58. fa = FieldArray("my_field", x, is_input=True)
  59. indices = [0, 1, 3, 4, 6]
  60. for a, b in zip(fa[indices], x[indices]):
  61. self.assertListEqual(a.tolist(), b.tolist())
  62. def test_append(self):
  63. with self.assertRaises(Exception):
  64. fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
  65. fa.append(0)
  66. with self.assertRaises(Exception):
  67. fa = FieldArray("y", [1.1, 2.2, 3.3, 4.4, 5.5], is_input=True)
  68. fa.append([1, 2, 3, 4, 5])
  69. with self.assertRaises(Exception):
  70. fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
  71. fa.append([])
  72. with self.assertRaises(Exception):
  73. fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
  74. fa.append(["str", 0, 0, 0, 1.89])
  75. fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1, 2, 3, 4, 5]], is_input=True)
  76. fa.append([1.2, 2.3, 3.4, 4.5, 5.6])
  77. self.assertEqual(len(fa), 3)
  78. self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6])
  79. class TestPadder(unittest.TestCase):
  80. def test01(self):
  81. """
  82. 测试AutoPadder能否正常工作
  83. :return:
  84. """
  85. from fastNLP.core.fieldarray import AutoPadder
  86. padder = AutoPadder()
  87. content = ['This is a str', 'this is another str']
  88. self.assertListEqual(content, padder(content, None, np.str).tolist())
  89. content = [1, 2]
  90. self.assertListEqual(content, padder(content, None, np.int64).tolist())
  91. content = [[1,2], [3], [4]]
  92. self.assertListEqual([[1,2], [3, 0], [4, 0]],
  93. padder(content, None, np.int64).tolist())
  94. contents = [
  95. [[1, 2, 3], [4, 5], [7,8,9,10]],
  96. [[1]]
  97. ]
  98. print(padder(contents, None, np.int64))
  99. def test02(self):
  100. """
  101. 测试EngChar2DPadder能不能正确使用
  102. :return:
  103. """
  104. from fastNLP.core.fieldarray import EngChar2DPadder
  105. padder = EngChar2DPadder(pad_length=0)
  106. contents = [1, 2]
  107. # 不能是1维
  108. with self.assertRaises(ValueError):
  109. padder(contents, None, np.int64)
  110. contents = [[1, 2]]
  111. # 不能是2维
  112. with self.assertRaises(ValueError):
  113. padder(contents, None, np.int64)
  114. contents = [[[[1, 2]]]]
  115. # 不能是3维以上
  116. with self.assertRaises(ValueError):
  117. padder(contents, None, np.int64)
  118. contents = [
  119. [[1, 2, 3], [4, 5], [7,8,9,10]],
  120. [[1]]
  121. ]
  122. self.assertListEqual([[[1, 2, 3, 0], [4, 5, 0, 0], [7, 8, 9, 10]], [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]]],
  123. padder(contents, None, np.int64).tolist())
  124. padder = EngChar2DPadder(pad_length=5, pad_val=-100)
  125. self.assertListEqual(
  126. [[[1, 2, 3, -100, -100], [4, 5, -100, -100, -100], [7, 8, 9, 10, -100]],
  127. [[1, -100, -100, -100, -100], [-100, -100, -100, -100, -100], [-100, -100, -100, -100, -100]]],
  128. padder(contents, None, np.int64).tolist()
  129. )