Browse Source

Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
MorningForest 3 years ago
parent
commit
745d254aa5
3 changed files with 11 additions and 10 deletions
  1. +2
    -2
      fastNLP/core/callbacks/callback.py
  2. +9
    -6
      fastNLP/core/callbacks/progress_callback.py
  3. +0
    -2
      fastNLP/core/drivers/torch_driver/torch_driver.py

+ 2
- 2
fastNLP/core/callbacks/callback.py View File

@@ -178,8 +178,8 @@ class Callback:


def on_load_checkpoint(self, trainer, states: Optional[Dict]):
r"""
当 Trainer 要恢复 checkpoint 的时候触发(即调用 Trainer.load_checkpoint() 函数时 Trainer 与 Driver 已经加载好自身的状态),
参数 states 为 on_save_checkpoint() 的返回值。
当 Trainer 要恢复 checkpoint 的时候触发(即调用 Trainer.load_checkpoint() 函数时, 此刻 Trainer 与 Driver 已经加载好自身
的状态), 参数 states 为 Callback 在调用 on_save_checkpoint() 的返回值。


:param trainer: ``Trainer`` 实例;
:param states:


+ 9
- 6
fastNLP/core/callbacks/progress_callback.py View File

@@ -78,15 +78,17 @@ class RichCallback(ProgressCallback):
super(RichCallback, self).on_after_trainer_initialized(trainer, driver)


def on_train_begin(self, trainer):
self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs,
completed=trainer.global_forward_batches/(trainer.n_batches+1e-6))
self.task2id['epoch'] = self.progress_bar.add_task(description=f'Epoch:{trainer.cur_epoch_idx}',
total=trainer.n_epochs,
completed=trainer.global_forward_batches/(trainer.n_batches+1e-6)*
trainer.n_epochs)


def on_train_epoch_begin(self, trainer):
self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6)
if 'batch' in self.task2id:
self.progress_bar.reset(self.task2id['batch'], completed=trainer.batch_idx_in_epoch)
else:
self.task2id['batch'] = self.progress_bar.add_task(description='Batch:0',
self.task2id['batch'] = self.progress_bar.add_task(description=f'Batch:{trainer.batch_idx_in_epoch}',
total=trainer.num_batches_per_epoch,
completed=trainer.batch_idx_in_epoch)


@@ -249,9 +251,10 @@ class TqdmCallback(ProgressCallback):
self.num_signs = 10


def on_train_begin(self, trainer):
self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs,
self.task2id['epoch'] = self.progress_bar.add_task(description=f'Epoch:{trainer.cur_epoch_idx}',
total=trainer.n_epochs,
bar_format='{desc}: {percentage:3.0f}%|{bar}| [{elapsed}&lt;{remaining}, {rate_fmt}, {postfix}]',
initial=trainer.global_forward_batches/(trainer.n_batches+1e-6))
initial=trainer.global_forward_batches/(trainer.n_batches+1e-6)*trainer.n_epochs)


def on_train_epoch_begin(self, trainer):
self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6)
@@ -279,7 +282,7 @@ class TqdmCallback(ProgressCallback):
self.progress_bar.update(self.task2id['epoch'], advance=self.epoch_bar_update_advance, refresh=True)


def on_evaluate_end(self, trainer, results):
if len(results)==0:
if len(results) == 0:
return
base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}'
text = ''


+ 0
- 2
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -4,8 +4,6 @@ from functools import partial
import numpy as np
import random
from dataclasses import dataclass

from py import process
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
from pathlib import Path
if _NEED_IMPORT_TORCH:


Loading…
Cancel
Save