You can not select more than 25 topics Topics must start with a chinese character,a letter or number, can include dashes ('-') and can be up to 35 characters long.

fitlog_callback.py 4.0 kB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. __all__ = [
  2. 'FitlogCallback'
  3. ]
  4. import os
  5. from .has_monitor_callback import HasMonitorCallback
  6. from ...envs import _module_available
  7. from ...envs import get_global_rank
  8. from ..log import logger
  9. if _module_available('fitlog'):
  10. import fitlog
  11. class FitlogCallback(HasMonitorCallback):
  12. """
  13. 自动记录 ``evaluation`` 结果到 ``fitlog`` 中。会自动记录每一次 ``evaluate`` 后的结果;同时会根据
  14. ``monitor`` 记录最好的结果。另外,会自动将非 ``rank 0`` 上的 ``fitlog`` 设置为 ``debug`` 状态。同时还会在 ``fitlog`` 的
  15. ``other`` 列中记录一个 ``launch_time`` ,可以通过这个数值找到当前这个脚本的在 save_folder (如果有使用其它需要保存模型的
  16. ``Callback`` ,例如 :class:`~fastNLP.core.callbacks.CheckpointCallback` )下的文件夹名称。
  17. :param monitor: 监控的 metric 值。
  18. * 为 ``None``
  19. 将尝试使用 :class:`~fastNLP.core.controllers.Trainer` 中设置 `monitor` 值(如果有设置)。
  20. * 为 ``str``
  21. 尝试直接使用该名称从 ``evaluation`` 结果中寻找,如果在 ``evaluation`` 结果中没有找到完全一致的名称,将
  22. 使用 最长公共字符串算法 从 ``evaluation`` 结果中找到最匹配的那个作为 ``monitor`` 。
  23. * 为 :class:`Callable`
  24. 接受参数为 ``evaluation`` 的结果(字典类型),返回一个 ``float`` 值作为 ``monitor`` 的结果,如果当前结果中没有相关
  25. 的 ``monitor`` 值请返回 ``None`` 。
  26. :param larger_better: 是否是越大越好。
  27. :param log_exception: 是否记录 ``exception`` 。
  28. :param log_loss_every: 多少个 ``batch`` 记录一次 loss 到 ``fitlog`` 中。
  29. """
  30. def __init__(self, monitor=None, larger_better: bool = True, log_exception:bool=True, log_loss_every:int=0):
  31. assert _module_available('fitlog'), "fitlog is not installed."
  32. super().__init__(monitor=monitor, larger_better=larger_better)
  33. self.log_exception = log_exception
  34. self.log_loss_every = log_loss_every
  35. self.avg_loss = 0
  36. def on_after_trainer_initialized(self, trainer, driver):
  37. if get_global_rank() != 0: # 如果不是 global rank 为 0 ,需要关闭 fitlog
  38. fitlog.debug()
  39. super().on_after_trainer_initialized(trainer, driver)
  40. fitlog.add_other(name='launch_time', value=os.environ['FASTNLP_LAUNCH_TIME'])
  41. def on_sanity_check_end(self, trainer, sanity_check_res):
  42. super(FitlogCallback, self).on_sanity_check_end(trainer, sanity_check_res)
  43. if self.monitor is None:
  44. logger.rank_zero_warning(f"No monitor set for {self.log_name}. Therefore, no best metric will "
  45. f"be logged.")
  46. def on_evaluate_end(self, trainer, results):
  47. results = self.itemize_results(results)
  48. fitlog.add_metric(results, step=trainer.global_forward_batches, epoch=trainer.cur_epoch_idx)
  49. if self.is_better_results(results, keep_if_better=True):
  50. results['step'] = trainer.global_forward_batches
  51. results['epoch'] = trainer.cur_epoch_idx
  52. fitlog.add_best_metric(results)
  53. def on_before_backward(self, trainer, outputs):
  54. if self.log_loss_every > 0:
  55. loss = trainer.extract_loss_from_outputs(outputs)
  56. self.avg_loss += loss.item()
  57. if trainer.global_forward_batches % self.log_loss_every == 0:
  58. fitlog.add_loss(self.avg_loss / self.log_loss_every * trainer.accumulation_steps, name='loss',
  59. step=trainer.global_forward_batches,
  60. epoch=trainer.cur_epoch_idx)
  61. self.avg_loss = 0
  62. def on_train_end(self, trainer):
  63. fitlog.finish()
  64. def on_exception(self, trainer, exception):
  65. fitlog.finish(status=1)
  66. if self.log_exception:
  67. fitlog.add_other(repr(exception), name='except_info')