Browse Source

progres_bar统一

tags/v1.0.0alpha
yhcc 2 years ago
parent
commit
e1fbc2cfeb
4 changed files with 11 additions and 6 deletions
  1. +1
    -0
      fastNLP/core/__init__.py
  2. +2
    -1
      fastNLP/core/callbacks/__init__.py
  3. +2
    -1
      fastNLP/core/callbacks/progress_callback.py
  4. +6
    -4
      fastNLP/core/controllers/trainer.py

+ 1
- 0
fastNLP/core/__init__.py View File

@@ -7,6 +7,7 @@ __all__ = [
'ProgressCallback', 'ProgressCallback',
'RichCallback', 'RichCallback',
'TqdmCallback', 'TqdmCallback',
'RawTextCallback',
"LRSchedCallback", "LRSchedCallback",
'LoadBestModelCallback', 'LoadBestModelCallback',
"EarlyStopCallback", "EarlyStopCallback",


+ 2
- 1
fastNLP/core/callbacks/__init__.py View File

@@ -8,6 +8,7 @@ __all__ = [
'ProgressCallback', 'ProgressCallback',
'RichCallback', 'RichCallback',
'TqdmCallback', 'TqdmCallback',
'RawTextCallback',


"LRSchedCallback", "LRSchedCallback",
'LoadBestModelCallback', 'LoadBestModelCallback',
@@ -31,7 +32,7 @@ from .callback import Callback
from .callback_event import Event, Filter from .callback_event import Event, Filter
from .callback_manager import CallbackManager from .callback_manager import CallbackManager
from .checkpoint_callback import CheckpointCallback from .checkpoint_callback import CheckpointCallback
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback, TqdmCallback
from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback, TqdmCallback, RawTextCallback
from .lr_scheduler_callback import LRSchedCallback from .lr_scheduler_callback import LRSchedCallback
from .load_best_model_callback import LoadBestModelCallback from .load_best_model_callback import LoadBestModelCallback
from .early_stop_callback import EarlyStopCallback from .early_stop_callback import EarlyStopCallback


+ 2
- 1
fastNLP/core/callbacks/progress_callback.py View File

@@ -5,7 +5,8 @@ __all__ = [
'choose_progress_callback', 'choose_progress_callback',
'ProgressCallback', 'ProgressCallback',
'RichCallback', 'RichCallback',
'TqdmCallback'
'TqdmCallback',
'RawTextCallback'
] ]






+ 6
- 4
fastNLP/core/controllers/trainer.py View File

@@ -36,6 +36,7 @@ from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.utils.exceptions import EarlyStopException from fastNLP.core.utils.exceptions import EarlyStopException
from fastNLP.core.dataloaders import OverfitDataLoader from fastNLP.core.dataloaders import OverfitDataLoader
from fastNLP.core.callbacks.progress_callback import ProgressCallback




class Trainer(TrainerEventTrigger): class Trainer(TrainerEventTrigger):
@@ -554,15 +555,16 @@ class Trainer(TrainerEventTrigger):
evaluate_dataloaders = self.dataloader evaluate_dataloaders = self.dataloader
if evaluate_dataloaders is not None: if evaluate_dataloaders is not None:
check_evaluate_every(evaluate_every) check_evaluate_every(evaluate_every)
progress_bar = kwargs.get('progress_bar', 'auto') # 如果不为
if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。
progress_bar = progress_bar.name
progress_bar_name = None
for callback in self.callback_manager.class_callbacks:
if isinstance(callback, ProgressCallback):
progress_bar_name = callback.name
self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics, self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn, driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn,
evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping, evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0, output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler), use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler),
progress_bar=progress_bar,
progress_bar=progress_bar_name,
check_dataloader_legality=kwargs.get('check_dataloader_legality', True)) check_dataloader_legality=kwargs.get('check_dataloader_legality', True))
else: else:
raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.") raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.")


Loading…
Cancel
Save