From 0bf0dab347987cb677d4087cefc9c0ba1090d3df Mon Sep 17 00:00:00 2001
From: yhcc
Date: Tue, 7 Jun 2022 14:52:09 +0800
Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E4=B8=80=E4=B8=AATimerCallba?=
 =?UTF-8?q?ck=E7=94=A8=E4=BA=8E=E8=AE=A1=E6=97=B6?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/__init__.py | 1 +
 fastNLP/core/callbacks/__init__.py | 6 +-
 fastNLP/core/callbacks/has_monitor_callback.py | 2 +-
 fastNLP/core/callbacks/load_best_model_callback.py | 4 +-
 fastNLP/core/callbacks/progress_callback.py | 30 ++--
 fastNLP/core/callbacks/timer_callback.py | 152 +++++++++++++++++++++
 fastNLP/core/controllers/loops/train_batch_loop.py | 4 +-
 fastNLP/core/dataloaders/torch_dataloader/fdl.py | 3 +
 8 files changed, 188 insertions(+), 14 deletions(-)
 create mode 100644 fastNLP/core/callbacks/timer_callback.py

diff --git a/fastNLP/core/__init__.py b/fastNLP/core/__init__.py
index 4f5ee3d8..6cf73d3b 100644
--- a/fastNLP/core/__init__.py
+++ b/fastNLP/core/__init__.py
@@ -16,6 +16,7 @@ __all__ = [
     "ResultsMonitor",
     'HasMonitorCallback',
     "FitlogCallback",
+    "TimerCallback",
 
     # collators
     'Collator',
diff --git a/fastNLP/core/callbacks/__init__.py b/fastNLP/core/callbacks/__init__.py
index 48699b68..d1f19b96 100644
--- a/fastNLP/core/callbacks/__init__.py
+++ b/fastNLP/core/callbacks/__init__.py
@@ -21,7 +21,9 @@ __all__ = [
 
     "ResultsMonitor",
     'HasMonitorCallback',
-    "FitlogCallback"
+    "FitlogCallback",
+
+    "TimerCallback"
 ]
 
 
@@ -37,4 +39,4 @@ from .torch_callbacks import *
 from .more_evaluate_callback import MoreEvaluateCallback
 from .has_monitor_callback import ResultsMonitor, HasMonitorCallback
 from .fitlog_callback import FitlogCallback
-
+from .timer_callback import TimerCallback
diff --git a/fastNLP/core/callbacks/has_monitor_callback.py b/fastNLP/core/callbacks/has_monitor_callback.py
index 0b57bf53..4fadc3d7 100644
--- a/fastNLP/core/callbacks/has_monitor_callback.py
+++ b/fastNLP/core/callbacks/has_monitor_callback.py
@@ -171,7 +171,7 @@ class ResultsMonitor:
     @property
     def log_name(self) -> str:
         """
-        内部用于打印信息使用
+        内部用于打印当前类别信息使用
 
         :return:
         """
diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py
index 4f52720f..ec6579a6 100644
--- a/fastNLP/core/callbacks/load_best_model_callback.py
+++ b/fastNLP/core/callbacks/load_best_model_callback.py
@@ -106,11 +106,11 @@ class LoadBestModelCallback(HasMonitorCallback):
     def on_train_end(self, trainer):
         if abs(self.monitor_value) != float('inf'):  # 如果是 inf 说明从来没有运行过。
             if self.real_save_folder:
-                logger.info(f"Loading best model from {self.real_save_folder} with {self.monitor_name}: {self.monitor_value}...")
+                logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value}...")
                 trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
                                    model_load_fn=self.model_load_fn)
             else:
-                logger.info(f"Loading best model from buffer with {self.monitor_name}: {self.monitor_value}...")
+                logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value}...")
                 self.buffer.seek(0)
                 trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)
             if self.delete_after_after:
diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py
index 36524a6b..2fa62c87 100644
--- a/fastNLP/core/callbacks/progress_callback.py
+++ b/fastNLP/core/callbacks/progress_callback.py
@@ -1,5 +1,4 @@
 import json
-import sys
 from typing import Union
 
 __all__ = [
@@ -16,8 +15,21 @@ from fastNLP.core.log import logger
 
 
 class ProgressCallback(HasMonitorCallback):
+    def __init__(self, monitor, larger_better, must_have_monitor=False):
+        super(ProgressCallback, self).__init__(monitor=monitor, larger_better=larger_better,
+                                               must_have_monitor=must_have_monitor)
+        self.best_monitor_epoch = -1
+        self.best_monitor_step = -1
+
+    def record_better_monitor(self, trainer):
+        self.best_monitor_step = trainer.global_forward_batches
+        self.best_monitor_epoch = trainer.cur_epoch_idx
+
     def on_train_end(self, trainer):
-        f_rich_progress.stop()
+        if self.best_monitor_epoch != -1:
+            msg = f"The best performance for monitor {self._real_monitor}:{self.monitor_value} was achieved in" \
+                  f" Epoch:{self.best_monitor_epoch}, Global Batch:{self.best_monitor_step}."
+            logger.info(msg)
 
     @property
     def name(self):  # progress bar的名称
@@ -97,6 +109,7 @@ class RichCallback(ProgressCallback):
                                advance=None, completed=trainer.cur_epoch_idx, refresh=True)
 
     def on_train_end(self, trainer):
+        super(RichCallback, self).on_train_end(trainer)
         self.clear_tasks()
 
     def on_before_backward(self, trainer, outputs):
@@ -121,8 +134,8 @@ class RichCallback(ProgressCallback):
         text_style = ''
         characters = '-'
         if self.monitor is not None:
-            monitor_value = self.get_monitor_value(results)
-            if self.is_better_monitor_value(monitor_value, keep_if_better=True):
+            if self.is_better_results(results, keep_if_better=True):
+                self.record_better_monitor(trainer)
                 if abs(self.monitor_value) != float('inf'):
                     rule_style = 'spring_green3'
                     text_style = '[bold]'
@@ -201,8 +214,8 @@ class RawTextCallback(ProgressCallback):
         base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}'
         text = ''
         if self.monitor is not None:
-            monitor_value = self.get_monitor_value(results)
-            if self.is_better_monitor_value(monitor_value, keep_if_better=True):
+            if self.is_better_results(results, keep_if_better=True):
+                self.record_better_monitor(trainer)
                 if abs(self.monitor_value) != float('inf'):
                     text = '+'*self.num_signs + base_text + '+'*self.num_signs
         if len(text) == 0:
@@ -266,6 +279,7 @@ class TqdmCallback(ProgressCallback):
         self.progress_bar.set_description_str(self.task2id['epoch'], f'Epoch:{trainer.cur_epoch_idx}', refresh=True)
 
     def on_train_end(self, trainer):
+        super(TqdmCallback, self).on_train_end(trainer)
         self.clear_tasks()
 
     def on_before_backward(self, trainer, outputs):
@@ -287,8 +301,8 @@ class TqdmCallback(ProgressCallback):
         base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}'
         text = ''
         if self.monitor is not None:
-            monitor_value = self.get_monitor_value(results)
-            if self.is_better_monitor_value(monitor_value, keep_if_better=True):
+            if self.is_better_results(results, keep_if_better=True):
+                self.record_better_monitor(trainer)
                 if abs(self.monitor_value) != float('inf'):
                     text = '+'*self.num_signs + base_text + '+'*self.num_signs
         if len(text) == 0:
diff --git a/fastNLP/core/callbacks/timer_callback.py b/fastNLP/core/callbacks/timer_callback.py
new file mode 100644
index 00000000..f0dafcb6
--- /dev/null
+++ b/fastNLP/core/callbacks/timer_callback.py
@@ -0,0 +1,152 @@
+import time
+from .callback import Callback
+from ..log import logger
+__all__ = ['TimerCallback']
+
+
+class _Timer:
+    """Timer."""
+
+    def __init__(self, name):
+        self.name_ = name
+        self.elapsed_ = 0.0
+        self.started_ = False
+        self.start_time = time.time()
+
+    def start(self):
+        """Start the timer."""
+        assert not self.started_, f'{self.name_} timer has already been started'
+        self.start_time = time.time()
+        self.started_ = True
+
+    def stop(self):
+        """Stop the timer."""
+        assert self.started_, f'{self.name_} timer is not started'
+        self.elapsed_ += (time.time() - self.start_time)
+        self.started_ = False
+
+    def reset(self):
+        """Reset timer."""
+        self.elapsed_ = 0.0
+        self.started_ = False
+
+    def elapsed(self, reset=True):
+        """Calculate the elapsed time."""
+        started_ = self.started_
+        # If the timing in progress, end it first.
+        if self.started_:
+            self.stop()
+        # Get the elapsed time.
+        elapsed_ = self.elapsed_
+        # Reset the elapsed time
+        if reset:
+            self.reset()
+        # If timing was in progress, set it back.
+        if started_:
+            self.start()
+        return elapsed_
+
+
+class Timers:
+    """Group of timers."""
+
+    def __init__(self):
+        self.timers = {}
+
+    def __call__(self, name):
+        if name not in self.timers:
+            self.timers[name] = _Timer(name)
+        return self.timers[name]
+
+    def __contains__(self, item):
+        return item in self.timers
+
+    def reset(self):
+        for timer in self.timers.values():
+            timer.reset()
+
+
+class TimerCallback(Callback):
+    """
+    这个 callback 的作用是打印训练过程中的相关时间信息,例如训练时长,评测时长,总的时长等
+
+    """
+    def __init__(self, print_every=-1, time_ndigit=3):
+        """
+
+        :param print_every: 在哪个时候打印时间信息。
+
+            * *负数*: 表示每隔多少 epoch 结束打印一次;
+            * *0*: 表示整个训练结束才打印;
+            * *正数*: 每隔多少个 step 打印一次;
+
+        :param time_ndigit: 保留多少位的小数
+        """
+        assert isinstance(print_every, int), "print_every must be an int number."
+        self.timers = Timers()
+        self.print_every = print_every
+        self.time_ndigit = time_ndigit
+
+    def on_train_begin(self, trainer):
+        self.timers('total').start()
+        self.timers('train').start()
+
+    def on_fetch_data_begin(self, trainer):
+        self.timers('fetch-data').start()
+
+    def on_fetch_data_end(self, trainer):
+        self.timers('fetch-data').stop()
+
+    def on_train_batch_begin(self, trainer, batch, indices):
+        self.timers('forward').start()
+
+    def on_before_backward(self, trainer, outputs):
+        self.timers('forward').stop()
+        self.timers('backward').start()
+
+    def on_after_backward(self, trainer):
+        self.timers('backward').stop()
+
+    def on_before_optimizers_step(self, trainer, optimizers):
+        self.timers('optimize').start()
+
+    def on_after_optimizers_step(self, trainer, optimizers):
+        self.timers('optimize').stop()
+
+    def on_evaluate_begin(self, trainer):
+        self.timers('train').stop()
+        self.timers('evaluate').start()
+
+    def on_evaluate_end(self, trainer, results):
+        self.timers('evaluate').stop()
+        self.timers('train').start()
+
+    def format_timer(self, reset=True):
+        line = ''
+        timers = ['fetch-data', 'forward', 'backward', 'optimize', 'evaluate', 'train', 'total']
+        for timer_name in timers:
+            if not timer_name in self.timers:
+                continue
+            timer = self.timers(timer_name)
+            elapsed = round(timer.elapsed(reset=reset), self.time_ndigit)
+            if elapsed != 0:
+                line = line + f', {timer_name}: {elapsed}s'
+        return line
+
+    def on_train_batch_end(self, trainer):
+        if self.print_every>0 and trainer.global_forward_batches % self.print_every == 0:
+            line = self.format_timer()
+            logger.info(f"Running {self.print_every} batches{line}")
+
+    def on_train_epoch_end(self, trainer):
+        if self.print_every < 0 and trainer.cur_epoch_idx % abs(self.print_every) == 0:
+            line = self.format_timer()
+            logger.info(f"Running {abs(self.print_every)} epochs{line}")
+
+    def on_train_end(self, trainer):
+        if self.print_every == 0:
+            line = self.format_timer()
+            logger.info(f"Training finished{line}")
+
+
+
diff --git a/fastNLP/core/controllers/loops/train_batch_loop.py b/fastNLP/core/controllers/loops/train_batch_loop.py
index 645f4224..ca8389b1 100644
--- a/fastNLP/core/controllers/loops/train_batch_loop.py
+++ b/fastNLP/core/controllers/loops/train_batch_loop.py
@@ -41,10 +41,12 @@ class TrainBatchLoop(Loop):
                 batch = next(dataloader)
                 indices = get_batch_indices()
             except StopIteration:
+                trainer.on_fetch_data_end()
                 break
 
+            trainer.on_fetch_data_end()
+
             try:
-                trainer.on_fetch_data_end()
                 batch = match_and_substitute_params(trainer.input_mapping, batch)
                 batch = trainer.move_data_to_device(batch)
 
diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
index 09fa2ff6..2a119260 100644
--- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
@@ -108,6 +108,9 @@ class TorchDataLoader(DataLoader):
         if not isinstance(dataset, _FDataSet):
             dataset = _FDataSet(dataset)
 
+        if num_workers>0 and multiprocessing_context is None:
+            multiprocessing_context = 'fork'  # 这里默认使用fork的方式来启动多进程
+
         if batch_sampler is not None:
             batch_size = 1
             shuffle = False