You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

test_callback_event.py 6.1 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. import pytest
  2. from functools import reduce
  3. from fastNLP.core.callbacks.callback_event import Event, Filter
  4. class TestFilter:
  5. def test_every_filter(self):
  6. # every = 10
  7. @Filter(every=10)
  8. def _fn(data):
  9. return data
  10. _res = []
  11. for i in range(100):
  12. cu_res = _fn(i)
  13. if cu_res is not None:
  14. _res.append(cu_res)
  15. assert _res == [w-1 for w in range(10, 101, 10)]
  16. # every = 1
  17. @Filter(every=1)
  18. def _fn(data):
  19. return data
  20. _res = []
  21. for i in range(100):
  22. cu_res = _fn(i)
  23. if cu_res is not None:
  24. _res.append(cu_res)
  25. assert _res == list(range(100))
  26. def test_once_filter(self):
  27. # once = 10
  28. @Filter(once=10)
  29. def _fn(data):
  30. return data
  31. _res = []
  32. for i in range(100):
  33. cu_res = _fn(i)
  34. if cu_res is not None:
  35. _res.append(cu_res)
  36. assert _res == [9]
  37. def test_extract_filter_from_fn(self):
  38. @Filter(every=10)
  39. def _fn(data):
  40. return data
  41. _filter_num_called = []
  42. _filter_num_executed = []
  43. for i in range(100):
  44. cu_res = _fn(i)
  45. _filter = _fn.__fastNLP_filter__
  46. _filter_num_called.append(_filter.num_called)
  47. _filter_num_executed.append(_filter.num_executed)
  48. assert _filter_num_called == list(range(1, 101))
  49. assert _filter_num_executed == [0]*9 + reduce(lambda x, y: x+y, [[w]*10 for w in range(1, 10)]) + [10]
  50. def _fn(data):
  51. return data
  52. assert not hasattr(_fn, "__fastNLP_filter__")
  53. def test_filter_state_dict(self):
  54. # every = 10
  55. @Filter(every=10)
  56. def _fn(data):
  57. return data
  58. _res = []
  59. for i in range(50):
  60. cu_res = _fn(i)
  61. if cu_res is not None:
  62. _res.append(cu_res)
  63. assert _res == [w - 1 for w in range(10, 51, 10)]
  64. # 保存状态
  65. state = _fn.__fastNLP_filter__.state_dict()
  66. # 加载状态
  67. _fn.__fastNLP_filter__.load_state_dict(state)
  68. _res = []
  69. for i in range(50, 100):
  70. cu_res = _fn(i)
  71. if cu_res is not None:
  72. _res.append(cu_res)
  73. assert _res == [w - 1 for w in range(60, 101, 10)]
  74. @pytest.mark.torch
  75. def test_filter_fn_torch():
  76. from torch.optim import SGD
  77. from torch.utils.data import DataLoader
  78. from fastNLP.core.controllers.trainer import Trainer
  79. from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
  80. from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
  81. model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
  82. optimizer = SGD(model.parameters(), lr=0.0001)
  83. dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10)
  84. dataloader = DataLoader(dataset=dataset, batch_size=4)
  85. trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer)
  86. def filter_fn(filter, trainer):
  87. if trainer.__heihei_test__ == 10:
  88. return True
  89. return False
  90. @Filter(filter_fn=filter_fn)
  91. def _fn(trainer, data):
  92. return data
  93. _res = []
  94. for i in range(100):
  95. trainer.__heihei_test__ = i
  96. cu_res = _fn(trainer, i)
  97. if cu_res is not None:
  98. _res.append(cu_res)
  99. assert _res == [10]
  100. class TestCallbackEvents:
  101. def test_every(self):
  102. # 这里是什么样的事件是不影响的,因为我们是与 Trainer 拆分开了进行测试;
  103. event_state = Event.on_train_begin() # 什么都不输入是应当默认 every=1;
  104. @Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn)
  105. def _fn(data):
  106. return data
  107. _res = []
  108. for i in range(100):
  109. cu_res = _fn(i)
  110. if cu_res is not None:
  111. _res.append(cu_res)
  112. assert _res == list(range(100))
  113. event_state = Event.on_train_begin(every=10)
  114. @Filter(every=event_state.every, once=event_state.once, filter_fn=event_state.filter_fn)
  115. def _fn(data):
  116. return data
  117. _res = []
  118. for i in range(100):
  119. cu_res = _fn(i)
  120. if cu_res is not None:
  121. _res.append(cu_res)
  122. assert _res == [w - 1 for w in range(10, 101, 10)]
  123. def test_once(self):
  124. event_state = Event.on_train_begin(once=10)
  125. @Filter(once=event_state.once)
  126. def _fn(data):
  127. return data
  128. _res = []
  129. for i in range(100):
  130. cu_res = _fn(i)
  131. if cu_res is not None:
  132. _res.append(cu_res)
  133. assert _res == [9]
  134. @pytest.mark.torch
  135. def test_callback_events_torch():
  136. from torch.optim import SGD
  137. from torch.utils.data import DataLoader
  138. from fastNLP.core.controllers.trainer import Trainer
  139. from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
  140. from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
  141. model = TorchNormalModel_Classification_1(num_labels=3, feature_dimension=10)
  142. optimizer = SGD(model.parameters(), lr=0.0001)
  143. dataset = TorchNormalDataset_Classification(num_labels=3, feature_dimension=10)
  144. dataloader = DataLoader(dataset=dataset, batch_size=4)
  145. trainer = Trainer(model=model, driver="torch", device="cpu", train_dataloader=dataloader, optimizers=optimizer)
  146. def filter_fn(filter, trainer):
  147. if trainer.__heihei_test__ == 10:
  148. return True
  149. return False
  150. event_state = Event.on_train_begin(filter_fn=filter_fn)
  151. @Filter(filter_fn=event_state.filter_fn)
  152. def _fn(trainer, data):
  153. return data
  154. _res = []
  155. for i in range(100):
  156. trainer.__heihei_test__ = i
  157. cu_res = _fn(trainer, i)
  158. if cu_res is not None:
  159. _res.append(cu_res)
  160. assert _res == [10]