You cannot select more than 25 topics. Topics must start with a Chinese character, a letter, or a number, can include dashes ('-'), and can be up to 35 characters long.

helper_callbacks.py 3.6 kB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. """
  2. 该模块用于实现一些帮助我们在测试的 callback 类;
  3. """
  4. from fastNLP.core.callbacks.callback import Callback
  5. class RecordLossCallback(Callback):
  6. """
  7. 通过该 callback 来测试模型的训练是否基本正常;
  8. """
  9. def __init__(self, loss_threshold: float):
  10. self.loss = None
  11. self.loss_threshold = loss_threshold
  12. self.loss_begin_value = None
  13. def on_before_backward(self, trainer, outputs):
  14. loss = trainer.extract_loss_from_outputs(outputs)
  15. loss = trainer.driver.tensor_to_numeric(loss)
  16. self.loss = loss
  17. if self.loss_begin_value is None:
  18. self.loss_begin_value = loss
  19. def on_train_end(self, trainer):
  20. assert self.loss < self.loss_begin_value
  21. if self.loss_threshold is not None:
  22. assert self.loss < self.loss_threshold
  23. class RecordMetricCallback(Callback):
  24. """
  25. 通过该 callback 来测试带有 metrics 的 Trainer 是否训练测试正确;
  26. """
  27. def __init__(self, monitor: str, metric_threshold: float, larger_better: bool):
  28. self.monitor = monitor
  29. self.larger_better = larger_better
  30. self.metric = None
  31. self.metric_threshold = metric_threshold
  32. self.metric_begin_value = None
  33. def on_validate_end(self, trainer, results):
  34. self.metric = results[self.monitor]
  35. if self.metric_begin_value is None:
  36. self.metric_begin_value = self.metric
  37. def on_train_end(self, trainer):
  38. if self.larger_better:
  39. assert self.metric >= self.metric_begin_value
  40. assert self.metric > self.metric_threshold
  41. else:
  42. assert self.metric <= self.metric_begin_value
  43. assert self.metric < self.metric_threshold
  44. class RecordTrainerEventTriggerCallback(Callback):
  45. """
  46. 测试每一个 callback 是否在 trainer 中都得到了调用;
  47. """
  48. def on_after_trainer_initialized(self, trainer, driver):
  49. print("on_after_trainer_initialized")
  50. def on_sanity_check_begin(self, trainer):
  51. print("on_sanity_check_begin")
  52. def on_sanity_check_end(self, trainer, sanity_check_res):
  53. print("on_sanity_check_end")
  54. def on_train_begin(self, trainer):
  55. print("on_train_begin")
  56. def on_train_end(self, trainer):
  57. print("on_train_end")
  58. def on_train_epoch_begin(self, trainer):
  59. if trainer.current_epoch_idx >= 1:
  60. # 触发 on_exception;
  61. raise Exception
  62. print("on_train_epoch_begin")
  63. def on_train_epoch_end(self, trainer):
  64. print("on_train_epoch_end")
  65. def on_fetch_data_begin(self, trainer):
  66. print("on_fetch_data_begin")
  67. def on_fetch_data_end(self, trainer):
  68. print("on_fetch_data_end")
  69. def on_train_batch_begin(self, trainer, batch, indices=None):
  70. print("on_train_batch_begin")
  71. def on_train_batch_end(self, trainer):
  72. print("on_train_batch_end")
  73. def on_exception(self, trainer, exception):
  74. print("on_exception")
  75. def on_before_backward(self, trainer, outputs):
  76. print("on_before_backward")
  77. def on_after_backward(self, trainer):
  78. print("on_after_backward")
  79. def on_before_optimizer_step(self, trainer, optimizers):
  80. print("on_before_optimizer_step")
  81. def on_before_zero_grad(self, trainer, optimizers):
  82. print("on_before_zero_grad")
  83. def on_validate_begin(self, trainer):
  84. print("on_validate_begin")
  85. def on_validate_end(self, trainer, results):
  86. print("on_validate_end")