You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number; they can include dashes ('-') and can be up to 35 characters long.

test_trainer.py 8.8 kB

6 years ago
6 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268
  1. import time
  2. import unittest
  3. import numpy as np
  4. import torch.nn.functional as F
  5. from torch import nn
  6. from fastNLP.core.dataset import DataSet
  7. from fastNLP.core.instance import Instance
  8. from fastNLP.core.losses import BCELoss
  9. from fastNLP.core.losses import CrossEntropyLoss
  10. from fastNLP.core.metrics import AccuracyMetric
  11. from fastNLP.core.optimizer import SGD
  12. from fastNLP.core.trainer import Trainer
  13. from fastNLP.models.base_model import NaiveClassifier
  14. def prepare_fake_dataset():
  15. mean = np.array([-3, -3])
  16. cov = np.array([[1, 0], [0, 1]])
  17. class_A = np.random.multivariate_normal(mean, cov, size=(1000,))
  18. mean = np.array([3, 3])
  19. cov = np.array([[1, 0], [0, 1]])
  20. class_B = np.random.multivariate_normal(mean, cov, size=(1000,))
  21. data_set = DataSet([Instance(x=[float(item[0]), float(item[1])], y=[0.0]) for item in class_A] +
  22. [Instance(x=[float(item[0]), float(item[1])], y=[1.0]) for item in class_B])
  23. return data_set
  24. def prepare_fake_dataset2(*args, size=100):
  25. ys = np.random.randint(4, size=100, dtype=np.int64)
  26. data = {'y': ys}
  27. for arg in args:
  28. data[arg] = np.random.randn(size, 5)
  29. return DataSet(data=data)
  30. class TrainerTestGround(unittest.TestCase):
  31. def test_case(self):
  32. data_set = prepare_fake_dataset()
  33. data_set.set_input("x", flag=True)
  34. data_set.set_target("y", flag=True)
  35. train_set, dev_set = data_set.split(0.3)
  36. model = NaiveClassifier(2, 1)
  37. trainer = Trainer(train_set, model,
  38. loss=BCELoss(pred="predict", target="y"),
  39. metrics=AccuracyMetric(pred="predict", target="y"),
  40. n_epochs=10,
  41. batch_size=32,
  42. print_every=50,
  43. validate_every=-1,
  44. dev_data=dev_set,
  45. optimizer=SGD(lr=0.1),
  46. check_code_level=2,
  47. use_tqdm=True,
  48. save_path=None)
  49. trainer.train()
  50. """
  51. # 应该正确运行
  52. """
  53. def test_trainer_suggestion1(self):
  54. # 检查报错提示能否正确提醒用户。
  55. # 这里没有传入forward需要的数据。需要trainer提醒用户如何设置。
  56. dataset = prepare_fake_dataset2('x')
  57. class Model(nn.Module):
  58. def __init__(self):
  59. super().__init__()
  60. self.fc = nn.Linear(5, 4)
  61. def forward(self, x1, x2, y):
  62. x1 = self.fc(x1)
  63. x2 = self.fc(x2)
  64. x = x1 + x2
  65. loss = F.cross_entropy(x, y)
  66. return {'loss': loss}
  67. model = Model()
  68. with self.assertRaises(RuntimeError):
  69. trainer = Trainer(
  70. train_data=dataset,
  71. model=model
  72. )
  73. """
  74. # 应该获取到的报错提示
  75. NameError:
  76. The following problems occurred when calling Model.forward(self, x1, x2, y)
  77. missing param: ['y', 'x1', 'x2']
  78. Suggestion: (1). You might need to set ['y'] as input.
  79. (2). You need to provide ['x1', 'x2'] in DataSet and set it as input.
  80. """
  81. def test_trainer_suggestion2(self):
  82. # 检查报错提示能否正确提醒用户
  83. # 这里传入forward需要的数据,看是否可以运行
  84. dataset = prepare_fake_dataset2('x1', 'x2')
  85. dataset.set_input('x1', 'x2', 'y', flag=True)
  86. class Model(nn.Module):
  87. def __init__(self):
  88. super().__init__()
  89. self.fc = nn.Linear(5, 4)
  90. def forward(self, x1, x2, y):
  91. x1 = self.fc(x1)
  92. x2 = self.fc(x2)
  93. x = x1 + x2
  94. loss = F.cross_entropy(x, y)
  95. return {'loss': loss}
  96. model = Model()
  97. trainer = Trainer(
  98. train_data=dataset,
  99. model=model,
  100. use_tqdm=False,
  101. print_every=2
  102. )
  103. trainer.train()
  104. """
  105. # 应该正确运行
  106. """
  107. def test_trainer_suggestion3(self):
  108. # 检查报错提示能否正确提醒用户
  109. # 这里传入forward需要的数据,但是forward没有返回loss这个key
  110. dataset = prepare_fake_dataset2('x1', 'x2')
  111. dataset.set_input('x1', 'x2', 'y', flag=True)
  112. class Model(nn.Module):
  113. def __init__(self):
  114. super().__init__()
  115. self.fc = nn.Linear(5, 4)
  116. def forward(self, x1, x2, y):
  117. x1 = self.fc(x1)
  118. x2 = self.fc(x2)
  119. x = x1 + x2
  120. loss = F.cross_entropy(x, y)
  121. return {'wrong_loss_key': loss}
  122. model = Model()
  123. with self.assertRaises(NameError):
  124. trainer = Trainer(
  125. train_data=dataset,
  126. model=model,
  127. use_tqdm=False,
  128. print_every=2
  129. )
  130. trainer.train()
  131. def test_trainer_suggestion4(self):
  132. # 检查报错提示能否正确提醒用户
  133. # 这里传入forward需要的数据,是否可以正确提示unused
  134. dataset = prepare_fake_dataset2('x1', 'x2')
  135. dataset.set_input('x1', 'x2', 'y', flag=True)
  136. class Model(nn.Module):
  137. def __init__(self):
  138. super().__init__()
  139. self.fc = nn.Linear(5, 4)
  140. def forward(self, x1, x2, y):
  141. x1 = self.fc(x1)
  142. x2 = self.fc(x2)
  143. x = x1 + x2
  144. loss = F.cross_entropy(x, y)
  145. return {'losses': loss}
  146. model = Model()
  147. with self.assertRaises(NameError):
  148. trainer = Trainer(
  149. train_data=dataset,
  150. model=model,
  151. use_tqdm=False,
  152. print_every=2
  153. )
  154. def test_trainer_suggestion5(self):
  155. # 检查报错提示能否正确提醒用户
  156. # 这里传入多余参数,让其duplicate, 但这里因为y不会被调用,所以其实不会报错
  157. dataset = prepare_fake_dataset2('x1', 'x_unused')
  158. dataset.rename_field('x_unused', 'x2')
  159. dataset.set_input('x1', 'x2', 'y')
  160. dataset.set_target('y')
  161. class Model(nn.Module):
  162. def __init__(self):
  163. super().__init__()
  164. self.fc = nn.Linear(5, 4)
  165. def forward(self, x1, x2, y):
  166. x1 = self.fc(x1)
  167. x2 = self.fc(x2)
  168. x = x1 + x2
  169. loss = F.cross_entropy(x, y)
  170. return {'loss': loss}
  171. model = Model()
  172. trainer = Trainer(
  173. train_data=dataset,
  174. model=model,
  175. use_tqdm=False,
  176. print_every=2
  177. )
  178. def test_trainer_suggestion6(self):
  179. # 检查报错提示能否正确提醒用户
  180. # 这里传入多余参数,让其duplicate
  181. dataset = prepare_fake_dataset2('x1', 'x_unused')
  182. dataset.rename_field('x_unused', 'x2')
  183. dataset.set_input('x1', 'x2')
  184. dataset.set_target('y', 'x1')
  185. class Model(nn.Module):
  186. def __init__(self):
  187. super().__init__()
  188. self.fc = nn.Linear(5, 4)
  189. def forward(self, x1, x2):
  190. x1 = self.fc(x1)
  191. x2 = self.fc(x2)
  192. x = x1 + x2
  193. time.sleep(0.1)
  194. # loss = F.cross_entropy(x, y)
  195. return {'preds': x}
  196. model = Model()
  197. with self.assertRaises(NameError):
  198. trainer = Trainer(
  199. train_data=dataset,
  200. model=model,
  201. dev_data=dataset,
  202. loss=CrossEntropyLoss(),
  203. metrics=AccuracyMetric(),
  204. use_tqdm=False,
  205. print_every=2)
  206. """
  207. def test_trainer_multiprocess(self):
  208. dataset = prepare_fake_dataset2('x1', 'x2')
  209. dataset.set_input('x1', 'x2', 'y', flag=True)
  210. class Model(nn.Module):
  211. def __init__(self):
  212. super().__init__()
  213. self.fc = nn.Linear(5, 4)
  214. def forward(self, x1, x2, y):
  215. x1 = self.fc(x1)
  216. x2 = self.fc(x2)
  217. x = x1 + x2
  218. loss = F.cross_entropy(x, y)
  219. return {'loss': loss}
  220. model = Model()
  221. trainer = Trainer(
  222. train_data=dataset,
  223. model=model,
  224. use_tqdm=True,
  225. print_every=2,
  226. num_workers=2,
  227. pin_memory=False,
  228. timeout=0,
  229. )
  230. trainer.train()
  231. """