Browse Source

LoadBestModelCallback增加epoch,batch记录

pull/12/head
szhang0381 2 years ago
parent
commit
91b56cb363
1 changed files with 8 additions and 3 deletions
  1. +8
    -3
      fastNLP/core/callbacks/load_best_model_callback.py

+ 8
- 3
fastNLP/core/callbacks/load_best_model_callback.py View File

@@ -7,6 +7,7 @@ from typing import Optional, Callable, Union
from .has_monitor_callback import HasMonitorCallback
from io import BytesIO
import shutil
import pickle

from fastNLP.envs.env import FASTNLP_LAUNCH_TIME, FASTNLP_GLOBAL_RANK, FASTNLP_BACKEND_LAUNCH
from fastNLP.core.log import logger
@@ -63,6 +64,7 @@ class LoadBestModelCallback(HasMonitorCallback):
self.model_save_fn = model_save_fn
self.model_load_fn = model_load_fn
self.delete_after_after = delete_after_train
self.meta = {'epoch': -1, 'batch': -1}

def prepare_save_folder(self, trainer):
if not hasattr(self, 'real_save_folder'):
@@ -87,6 +89,7 @@ class LoadBestModelCallback(HasMonitorCallback):
else: # create an in-memory BytesIO buffer (binary, not a StringIO)
self.real_save_folder = None
self.buffer = BytesIO()

def on_after_trainer_initialized(self, trainer, driver):
super().on_after_trainer_initialized(trainer, driver)
@@ -94,6 +97,8 @@ class LoadBestModelCallback(HasMonitorCallback):

def on_evaluate_end(self, trainer, results):
if self.is_better_results(results, keep_if_better=True):
self.meta['epoch'] = trainer.cur_epoch_idx
self.meta['batch'] = trainer.global_forward_batches
self.prepare_save_folder(trainer)
if self.real_save_folder:
trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
@@ -102,17 +107,17 @@ class LoadBestModelCallback(HasMonitorCallback):
self.buffer.seek(0)
with all_rank_call_context():
trainer.save_model(folder=self.buffer, only_state_dict=self.only_state_dict)
def on_train_end(self, trainer):
if abs(self.monitor_value) != float('inf'): # a value of inf means it has never been run.
# If training is distributed and an exception occurred, do not load, to avoid barrier problems
if not (trainer.driver.is_distributed() and self.encounter_exception):
if self.real_save_folder:
logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value}...")
logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value} (achieved in Epoch:{self.meta['epoch']}, Global Batch:{self.meta['batch']})...")
trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
model_load_fn=self.model_load_fn)
else:
logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value}...")
logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value} (achieved in Epoch:{self.meta['epoch']}, Global Batch:{self.meta['batch']})...")
self.buffer.seek(0)
trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)
if self.delete_after_after:


Loading…
Cancel
Save