Browse Source

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
MorningForest 3 years ago
parent
commit
745d254aa5
3 changed files with 11 additions and 10 deletions
  1. +2
    -2
      fastNLP/core/callbacks/callback.py
  2. +9
    -6
      fastNLP/core/callbacks/progress_callback.py
  3. +0
    -2
      fastNLP/core/drivers/torch_driver/torch_driver.py

+ 2
- 2
fastNLP/core/callbacks/callback.py View File

@@ -178,8 +178,8 @@ class Callback:

def on_load_checkpoint(self, trainer, states: Optional[Dict]):
r"""
当 Trainer 要恢复 checkpoint 的时候触发(即调用 Trainer.load_checkpoint() 函数时 Trainer 与 Driver 已经加载好自身的状态),
参数 states 为 on_save_checkpoint() 的返回值。
当 Trainer 要恢复 checkpoint 的时候触发(即调用 Trainer.load_checkpoint() 函数时, 此刻 Trainer 与 Driver 已经加载好自身
的状态), 参数 states 为 Callback 在调用 on_save_checkpoint() 的返回值。

:param trainer: ``Trainer`` 实例;
:param states:


+ 9
- 6
fastNLP/core/callbacks/progress_callback.py View File

@@ -78,15 +78,17 @@ class RichCallback(ProgressCallback):
super(RichCallback, self).on_after_trainer_initialized(trainer, driver)

def on_train_begin(self, trainer):
self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs,
completed=trainer.global_forward_batches/(trainer.n_batches+1e-6))
self.task2id['epoch'] = self.progress_bar.add_task(description=f'Epoch:{trainer.cur_epoch_idx}',
total=trainer.n_epochs,
completed=trainer.global_forward_batches/(trainer.n_batches+1e-6)*
trainer.n_epochs)

def on_train_epoch_begin(self, trainer):
self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6)
if 'batch' in self.task2id:
self.progress_bar.reset(self.task2id['batch'], completed=trainer.batch_idx_in_epoch)
else:
self.task2id['batch'] = self.progress_bar.add_task(description='Batch:0',
self.task2id['batch'] = self.progress_bar.add_task(description=f'Batch:{trainer.batch_idx_in_epoch}',
total=trainer.num_batches_per_epoch,
completed=trainer.batch_idx_in_epoch)

@@ -249,9 +251,10 @@ class TqdmCallback(ProgressCallback):
self.num_signs = 10

def on_train_begin(self, trainer):
self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs,
self.task2id['epoch'] = self.progress_bar.add_task(description=f'Epoch:{trainer.cur_epoch_idx}',
total=trainer.n_epochs,
bar_format='{desc}: {percentage:3.0f}%|{bar}| [{elapsed}<{remaining}, {rate_fmt}, {postfix}]',
initial=trainer.global_forward_batches/(trainer.n_batches+1e-6))
initial=trainer.global_forward_batches/(trainer.n_batches+1e-6)*trainer.n_epochs)

def on_train_epoch_begin(self, trainer):
self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6)
@@ -279,7 +282,7 @@ class TqdmCallback(ProgressCallback):
self.progress_bar.update(self.task2id['epoch'], advance=self.epoch_bar_update_advance, refresh=True)

def on_evaluate_end(self, trainer, results):
if len(results)==0:
if len(results) == 0:
return
base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}'
text = ''


+ 0
- 2
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -4,8 +4,6 @@ from functools import partial
import numpy as np
import random
from dataclasses import dataclass

from py import process
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from pathlib import Path
if _NEED_IMPORT_TORCH:


Loading…
Cancel
Save