From 91b56cb3638a7acbd126c51f2a48de81064458bf Mon Sep 17 00:00:00 2001 From: szhang0381 Date: Thu, 20 Oct 2022 18:38:53 +0800 Subject: [PATCH] =?UTF-8?q?LoadBestModelCallback=E5=A2=9E=E5=8A=A0epoch?= =?UTF-8?q?=EF=BC=8Cbatch=E8=AE=B0=E5=BD=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/load_best_model_callback.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py index 2bd41b5a..73b11b9b 100644 --- a/fastNLP/core/callbacks/load_best_model_callback.py +++ b/fastNLP/core/callbacks/load_best_model_callback.py @@ -7,6 +7,7 @@ from typing import Optional, Callable, Union from .has_monitor_callback import HasMonitorCallback from io import BytesIO import shutil +import pickle from fastNLP.envs.env import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK, FASTNLP_BACKEND_LAUNCH from fastNLP.core.log import logger @@ -63,6 +64,7 @@ class LoadBestModelCallback(HasMonitorCallback): self.model_save_fn = model_save_fn self.model_load_fn = model_load_fn self.delete_after_after = delete_after_train + self.meta = {'epoch': -1, 'batch': -1} def prepare_save_folder(self, trainer): if not hasattr(self, 'real_save_folder'): @@ -87,6 +89,7 @@ class LoadBestModelCallback(HasMonitorCallback): else: # 创建出一个 stringio self.real_save_folder = None self.buffer = BytesIO() + def on_after_trainer_initialized(self, trainer, driver): super().on_after_trainer_initialized(trainer, driver) @@ -94,6 +97,8 @@ class LoadBestModelCallback(HasMonitorCallback): def on_evaluate_end(self, trainer, results): if self.is_better_results(results, keep_if_better=True): + self.meta['epoch'] = trainer.cur_epoch_idx + self.meta['batch'] = trainer.global_forward_batches self.prepare_save_folder(trainer) if self.real_save_folder: trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, @@ -102,17 +107,17 @@ class LoadBestModelCallback(HasMonitorCallback): self.buffer.seek(0) with all_rank_call_context(): trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict) - + def on_train_end(self, trainer): if abs(self.monitor_value) != float('inf'): # 如果是 inf 说明从来没有运行过。 # 如果是分布式且报错了,就不要加载了,防止barrier的问题 if not (trainer.driver.is_distributed() and self.encounter_exception): if self.real_save_folder: - logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value}...") + logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value} (achieved in Epoch:{self.meta['epoch']}, Global Batch:{self.meta['batch']})...") trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, model_load_fn=self.model_load_fn) else: - logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value}...") + logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value} (achieved in Epoch:{self.meta['epoch']}, Global Batch:{self.meta['batch']})...") self.buffer.seek(0) trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) if self.delete_after_after: