From 281ffc3b919b2df78f69b56a3437c2bc58025399 Mon Sep 17 00:00:00 2001 From: yhcc Date: Fri, 3 Jun 2022 20:09:11 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E4=BF=AE=E5=A4=8Drichcallback=E7=AD=89?= =?UTF-8?q?=E5=9C=A8=E6=96=AD=E7=82=B9=E9=87=8D=E8=AE=AD=E6=97=B6=E6=98=BE?= =?UTF-8?q?=E7=A4=BA=E8=BF=9B=E5=BA=A6=E4=B8=8D=E6=AD=A3=E7=A1=AE=E7=9A=84?= =?UTF-8?q?=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/callback.py | 4 ++-- fastNLP/core/callbacks/progress_callback.py | 15 +++++++++------ 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py index f4dbcb42..d3679572 100644 --- a/fastNLP/core/callbacks/callback.py +++ b/fastNLP/core/callbacks/callback.py @@ -178,8 +178,8 @@ class Callback: def on_load_checkpoint(self, trainer, states: Optional[Dict]): r""" - 当 Trainer 要恢复 checkpoint 的时候触发(即调用 Trainer.load_checkpoint() 函数时 Trainer 与 Driver 已经加载好自身的状态), - 参数 states 为 on_save_checkpoint() 的返回值。 + 当 Trainer 要恢复 checkpoint 的时候触发(即调用 Trainer.load_checkpoint() 函数时, 此刻 Trainer 与 Driver 已经加载好自身 + 的状态), 参数 states 为 Callback 在调用 on_save_checkpoint() 的返回值。 :param trainer: ``Trainer`` 实例; :param states: diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py index 42f703af..36524a6b 100644 --- a/fastNLP/core/callbacks/progress_callback.py +++ b/fastNLP/core/callbacks/progress_callback.py @@ -78,15 +78,17 @@ class RichCallback(ProgressCallback): super(RichCallback, self).on_after_trainer_initialized(trainer, driver) def on_train_begin(self, trainer): - self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, - completed=trainer.global_forward_batches/(trainer.n_batches+1e-6)) + self.task2id['epoch'] = self.progress_bar.add_task(description=f'Epoch:{trainer.cur_epoch_idx}', + total=trainer.n_epochs, + completed=trainer.global_forward_batches/(trainer.n_batches+1e-6)* + trainer.n_epochs) def on_train_epoch_begin(self, trainer): self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6) if 'batch' in self.task2id: self.progress_bar.reset(self.task2id['batch'], completed=trainer.batch_idx_in_epoch) else: - self.task2id['batch'] = self.progress_bar.add_task(description='Batch:0', + self.task2id['batch'] = self.progress_bar.add_task(description=f'Batch:{trainer.batch_idx_in_epoch}', total=trainer.num_batches_per_epoch, completed=trainer.batch_idx_in_epoch) @@ -249,9 +251,10 @@ class TqdmCallback(ProgressCallback): self.num_signs = 10 def on_train_begin(self, trainer): - self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, + self.task2id['epoch'] = self.progress_bar.add_task(description=f'Epoch:{trainer.cur_epoch_idx}', + total=trainer.n_epochs, bar_format='{desc}: {percentage:3.0f}%|{bar}| [{elapsed}<{remaining}, {rate_fmt}, {postfix}]', - initial=trainer.global_forward_batches/(trainer.n_batches+1e-6)) + initial=trainer.global_forward_batches/(trainer.n_batches+1e-6)*trainer.n_epochs) def on_train_epoch_begin(self, trainer): self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6) @@ -279,7 +282,7 @@ class TqdmCallback(ProgressCallback): self.progress_bar.update(self.task2id['epoch'], advance=self.epoch_bar_update_advance, refresh=True) def on_evaluate_end(self, trainer, results): - if len(results)==0: + if len(results) == 0: return base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' text = '' From 21b7b58793b0b6b5f2fa0644a069e58471f6f0aa Mon Sep 17 00:00:00 2001 From: yhcc Date: Fri, 3 Jun 2022 20:21:58 +0800 Subject: [PATCH 2/2] typo fix --- fastNLP/core/drivers/torch_driver/torch_driver.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index 93f607d6..96529073 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -4,8 +4,6 @@ from functools import partial import numpy as np import random from dataclasses import dataclass - -from py import process from fastNLP.envs.imports import _NEED_IMPORT_TORCH from pathlib import Path if _NEED_IMPORT_TORCH: