You cannot select more than 25 topics. Topics must start with a Chinese character, a letter or a number, can include dashes ('-'), and can be up to 35 characters long.

helper_callbacks.py 3.7 kB

3 years ago
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. """
  2. 该模块用于实现一些帮助我们在测试的 callback 类;
  3. """
  4. from fastNLP.core.callbacks.callback import Callback
  5. class RecordLossCallback(Callback):
  6. """
  7. 通过该 callback 来测试模型的训练是否基本正常;
  8. """
  9. def __init__(self, loss_threshold: float):
  10. self.loss = None
  11. self.loss_threshold = loss_threshold
  12. self.loss_begin_value = None
  13. def on_before_backward(self, trainer, outputs):
  14. loss = trainer.extract_loss_from_outputs(outputs)
  15. loss = trainer.driver.tensor_to_numeric(loss)
  16. self.loss = loss
  17. if self.loss_begin_value is None:
  18. self.loss_begin_value = loss
  19. def on_train_end(self, trainer):
  20. # assert self.loss < self.loss_begin_value
  21. if self.loss_threshold is not None:
  22. pass
  23. # assert self.loss < self.loss_threshold
  24. class RecordMetricCallback(Callback):
  25. """
  26. 通过该 callback 来测试带有 metrics 的 Trainer 是否训练测试正确;
  27. """
  28. def __init__(self, monitor: str, metric_threshold: float, larger_better: bool):
  29. self.monitor = monitor
  30. self.larger_better = larger_better
  31. self.metric = None
  32. self.metric_threshold = metric_threshold
  33. self.metric_begin_value = float('-inf') if larger_better else float('inf')
  34. def on_evaluate_end(self, trainer, results):
  35. self.metric = results[self.monitor]
  36. def on_train_end(self, trainer):
  37. if self.larger_better:
  38. assert self.metric >= self.metric_begin_value
  39. assert self.metric > self.metric_threshold
  40. else:
  41. assert self.metric <= self.metric_begin_value
  42. assert self.metric < self.metric_threshold
  43. class RecordTrainerEventTriggerCallback(Callback):
  44. """
  45. 测试每一个 callback 是否在 trainer 中都得到了调用;
  46. """
  47. def on_after_trainer_initialized(self, trainer, driver):
  48. print("on_after_trainer_initialized")
  49. def on_sanity_check_begin(self, trainer):
  50. print("on_sanity_check_begin")
  51. def on_sanity_check_end(self, trainer, sanity_check_res):
  52. print("on_sanity_check_end")
  53. def on_train_begin(self, trainer):
  54. print("on_train_begin")
  55. def on_train_end(self, trainer):
  56. print("on_train_end")
  57. def on_train_epoch_begin(self, trainer):
  58. if trainer.cur_epoch_idx >= 1:
  59. # 触发 on_exception;
  60. raise Exception
  61. print("on_train_epoch_begin")
  62. def on_train_epoch_end(self, trainer):
  63. print("on_train_epoch_end")
  64. def on_fetch_data_begin(self, trainer):
  65. print("on_fetch_data_begin")
  66. def on_fetch_data_end(self, trainer):
  67. print("on_fetch_data_end")
  68. def on_train_batch_begin(self, trainer, batch, indices=None):
  69. print("on_train_batch_begin")
  70. def on_train_batch_end(self, trainer):
  71. print("on_train_batch_end")
  72. def on_exception(self, trainer, exception):
  73. print("on_exception")
  74. def on_before_backward(self, trainer, outputs):
  75. print("on_before_backward")
  76. def on_after_backward(self, trainer):
  77. print("on_after_backward")
  78. def on_before_optimizers_step(self, trainer, optimizers):
  79. print("on_before_optimizers_step")
  80. def on_after_optimizers_step(self, trainer, optimizers):
  81. print("on_after_optimizers_step")
  82. def on_before_zero_grad(self, trainer, optimizers):
  83. print("on_before_zero_grad")
  84. def on_after_zero_grad(self, trainer, optimizers):
  85. print("on_after_zero_grad")
  86. def on_evaluate_begin(self, trainer):
  87. print("on_evaluate_begin")
  88. def on_evaluate_end(self, trainer, results):
  89. print("on_evaluate_end")