You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_batch.py 6.8 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. import time
  2. import unittest
  3. import numpy as np
  4. import torch
  5. from fastNLP.core.batch import Batch
  6. from fastNLP.core.dataset import DataSet
  7. from fastNLP.core.dataset import construct_dataset
  8. from fastNLP.core.instance import Instance
  9. from fastNLP.core.sampler import SequentialSampler
  def generate_fake_dataset(num_samples=1000):
      """
      Build a random DataSet for exercising batching.

      The dataset has four fields named '0', '1', '2' and '3'. Every field holds
      ``num_samples`` integer numpy arrays with values drawn from [0, 100) and a
      length drawn from [10, 50). Each field is then randomly marked as either
      an input or a target.

      :param num_samples: number of samples stored in every field
      :return: the populated DataSet
      """
      min_len, max_len = 10, 50
      field_names = [str(idx) for idx in range(4)]

      fields = {}
      for name in field_names:
          # One random length per sample, then one random sequence per length.
          seq_lengths = np.random.randint(min_len, max_len, size=num_samples)
          fields[name] = [np.random.randint(100, size=n) for n in seq_lengths]
      dataset = DataSet(fields)

      for name in field_names:
          # Coin flip: 0 marks the field as input, 1 marks it as target.
          mark = dataset.set_target if np.random.randint(2) else dataset.set_input
          mark(name)
      return dataset
  class TestCase1(unittest.TestCase):
      """
      Tests for iterating a DataSet through Batch with a SequentialSampler.

      Each test builds a small dataset, marks its input/target fields, iterates
      the resulting Batch and checks the (x, y) pairs it yields.
      """

      # 40 instances with batch_size=4 must produce exactly 10 batches.
      def test_simple(self):
          dataset = construct_dataset(
              [["FastNLP", "is", "the", "most", "beautiful", "tool", "in", "the", "world"] for _ in range(40)])
          dataset.set_target()
          batch = Batch(dataset, batch_size=4, sampler=SequentialSampler(), as_numpy=True)

          cnt = 0
          for _, _ in batch:
              cnt += 1
          self.assertEqual(cnt, 10)

      # Equal-length fields: batches are numpy arrays of 4 rows whose last row
      # matches the source data.
      def test_dataset_batching(self):
          ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40})
          ds.set_input("x")
          ds.set_target("y")
          batches = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
          for x, y in batches:
              self.assertTrue(isinstance(x["x"], np.ndarray) and isinstance(y["y"], np.ndarray))
              self.assertEqual(len(x["x"]), 4)
              self.assertEqual(len(y["y"]), 4)
              self.assertListEqual(list(x["x"][-1]), [1, 2, 3, 4])
              self.assertListEqual(list(y["y"][-1]), [5, 6])

      # Variable-length Python lists: every batch is expected to come out as (4, 4).
      def test_list_padding(self):
          ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10,
                        "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10})
          ds.set_input("x")
          ds.set_target("y")
          batches = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
          for x, y in batches:
              self.assertEqual(x["x"].shape, (4, 4))
              self.assertEqual(y["y"].shape, (4, 4))

      # Same as test_list_padding, but the fields are handed over as numpy arrays.
      def test_numpy_padding(self):
          # dtype=object is required for ragged input on NumPy >= 1.24 (which
          # otherwise raises ValueError); older versions produced this same
          # 1-D object array implicitly.
          ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10, dtype=object),
                        "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10, dtype=object)})
          ds.set_input("x")
          ds.set_target("y")
          batches = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=True)
          for x, y in batches:
              self.assertEqual(x["x"].shape, (4, 4))
              self.assertEqual(y["y"].shape, (4, 4))

      # as_numpy=False: variable-length lists must come back as (4, 4) torch tensors.
      def test_list_to_tensor(self):
          ds = DataSet({"x": [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10,
                        "y": [[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10})
          ds.set_input("x")
          ds.set_target("y")
          batches = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
          for x, y in batches:
              self.assertTrue(isinstance(x["x"], torch.Tensor))
              self.assertEqual(tuple(x["x"].shape), (4, 4))
              self.assertTrue(isinstance(y["y"], torch.Tensor))
              self.assertEqual(tuple(y["y"].shape), (4, 4))

      # as_numpy=False with ragged numpy fields: expects (4, 4) torch tensors.
      def test_numpy_to_tensor(self):
          # dtype=object: see test_numpy_padding for why it is spelled out.
          ds = DataSet({"x": np.array([[1], [1, 2], [1, 2, 3], [1, 2, 3, 4]] * 10, dtype=object),
                        "y": np.array([[4, 3, 2, 1], [3, 2, 1], [2, 1], [1]] * 10, dtype=object)})
          ds.set_input("x")
          ds.set_target("y")
          batches = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
          for x, y in batches:
              self.assertTrue(isinstance(x["x"], torch.Tensor))
              self.assertEqual(tuple(x["x"].shape), (4, 4))
              self.assertTrue(isinstance(y["y"], torch.Tensor))
              self.assertEqual(tuple(y["y"].shape), (4, 4))

      # DataSet built from Instance objects holding lists of length 2 and 4;
      # the single batch of 4 is expected as (4, 4) torch tensors.
      def test_list_of_list_to_tensor(self):
          ds = DataSet([Instance(x=[1, 2], y=[3, 4]) for _ in range(2)] +
                       [Instance(x=[1, 2, 3, 4], y=[3, 4, 5, 6]) for _ in range(2)])
          ds.set_input("x")
          ds.set_target("y")
          batches = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
          for x, y in batches:
              self.assertTrue(isinstance(x["x"], torch.Tensor))
              self.assertEqual(tuple(x["x"].shape), (4, 4))
              self.assertTrue(isinstance(y["y"], torch.Tensor))
              self.assertEqual(tuple(y["y"].shape), (4, 4))

      # Instance objects holding numpy arrays. No assertions are made here: the
      # test only checks that iteration completes, and prints each batch.
      def test_list_of_numpy_to_tensor(self):
          ds = DataSet([Instance(x=np.array([1, 2]), y=np.array([3, 4])) for _ in range(2)] +
                       [Instance(x=np.array([1, 2, 3, 4]), y=np.array([3, 4, 5, 6])) for _ in range(2)])
          ds.set_input("x")
          ds.set_target("y")
          batches = Batch(ds, batch_size=4, sampler=SequentialSampler(), as_numpy=False)
          for x, y in batches:
              print(x, y)

      # Smoke test: iterate a randomly generated dataset in batches of 32,
      # pausing briefly after each batch. Nothing is asserted.
      def test_sequential_batch(self):
          batch_size = 32
          pause_seconds = 0.01
          num_samples = 1000
          dataset = generate_fake_dataset(num_samples)

          batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler())
          for batch_x, batch_y in batch:
              time.sleep(pause_seconds)

      # The two string blocks below are tests that were disabled by wrapping
      # them in string literals; they are not executed. The note inside the
      # second block says iterating with pin_memory=True ran out of memory.
      """
      def test_multi_workers_batch(self):
          batch_size = 32
          pause_seconds = 0.01
          num_samples = 1000
          dataset = generate_fake_dataset(num_samples)
          num_workers = 1
          batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers)
          for batch_x, batch_y in batch:
              time.sleep(pause_seconds)
          num_workers = 2
          batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers)
          end1 = time.time()
          for batch_x, batch_y in batch:
              time.sleep(pause_seconds)
      """
      """
      def test_pin_memory(self):
          batch_size = 32
          pause_seconds = 0.01
          num_samples = 1000
          dataset = generate_fake_dataset(num_samples)
          batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), pin_memory=True)
          # 这里发生OOM
          # for batch_x, batch_y in batch:
          #     time.sleep(pause_seconds)
          num_workers = 2
          batch = Batch(dataset, batch_size=batch_size, sampler=SequentialSampler(), num_workers=num_workers,
                        pin_memory=True)
          # 这里发生OOM
          # for batch_x, batch_y in batch:
          #     time.sleep(pause_seconds)
      """