@@ -7,6 +7,7 @@ from typing import Optional, Callable, Union
 from .has_monitor_callback import HasMonitorCallback
 from io import BytesIO
 import shutil
+import pickle
 from fastNLP.envs.env import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK, FASTNLP_BACKEND_LAUNCH
 from fastNLP.core.log import logger
@@ -63,6 +64,7 @@ class LoadBestModelCallback(HasMonitorCallback):
         self.model_save_fn = model_save_fn
         self.model_load_fn = model_load_fn
         self.delete_after_after = delete_after_train
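+        # epoch / global batch at which the best monitored result was observed; reported when the best model is reloaded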
+        self.meta = {'epoch': -1, 'batch': -1}
 
     def prepare_save_folder(self, trainer):
         if not hasattr(self, 'real_save_folder'):
@@ -87,6 +89,7 @@ class LoadBestModelCallback(HasMonitorCallback):
             else:  # no save folder given: keep the best model in an in-memory BytesIO buffer
                 self.real_save_folder = None
                 self.buffer = BytesIO()
 
     def on_after_trainer_initialized(self, trainer, driver):
         super().on_after_trainer_initialized(trainer, driver)
@@ -94,6 +97,8 @@ class LoadBestModelCallback(HasMonitorCallback):
 
     def on_evaluate_end(self, trainer, results):
         if self.is_better_results(results, keep_if_better=True):
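+            # remember when this best result was obtained so it can be reported on reload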
+            self.meta['epoch'] = trainer.cur_epoch_idx
+            self.meta['batch'] = trainer.global_forward_batches
             self.prepare_save_folder(trainer)
             if self.real_save_folder:
                 trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
@@ -102,17 +107,17 @@ class LoadBestModelCallback(HasMonitorCallback):
                 self.buffer.seek(0)
                 with all_rank_call_context():
                     trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict)
 
     def on_train_end(self, trainer):
         if abs(self.monitor_value) != float('inf'):  # still inf means evaluation never ran
             # if running distributed and an exception occurred, skip loading to avoid barrier issues
             if not (trainer.driver.is_distributed() and self.encounter_exception):
                 if self.real_save_folder:
logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value}...")
logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value} (achieved in Epoch:{self.meta['epoch']}, Global Batch:{self.meta['batch']}) ...")
                     trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
                                        model_load_fn=self.model_load_fn)
                 else:
logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value}...")
logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value} (achieved in Epoch:{self.meta['epoch']}, Global Batch:{self.meta['batch']}) ...")
                     self.buffer.seek(0)
                     trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)
         if self.delete_after_after: