
Add a TimerCallback for timing

tags/v1.0.0alpha
yhcc · 3 years ago · commit 0bf0dab347
8 changed files with 188 additions and 14 deletions
  1. fastNLP/core/__init__.py (+1, -0)
  2. fastNLP/core/callbacks/__init__.py (+4, -2)
  3. fastNLP/core/callbacks/has_monitor_callback.py (+1, -1)
  4. fastNLP/core/callbacks/load_best_model_callback.py (+2, -2)
  5. fastNLP/core/callbacks/progress_callback.py (+22, -8)
  6. fastNLP/core/callbacks/timer_callback.py (+152, -0)
  7. fastNLP/core/controllers/loops/train_batch_loop.py (+3, -1)
  8. fastNLP/core/dataloaders/torch_dataloader/fdl.py (+3, -0)

fastNLP/core/__init__.py (+1, -0)

@@ -16,6 +16,7 @@ __all__ = [
     "ResultsMonitor",
     'HasMonitorCallback',
     "FitlogCallback",
+    "TimerCallback",

     # collators
     'Collator',


fastNLP/core/callbacks/__init__.py (+4, -2)

@@ -21,7 +21,9 @@ __all__ = [
     "ResultsMonitor",
     'HasMonitorCallback',

-    "FitlogCallback"
+    "FitlogCallback",
+
+    "TimerCallback"
 ]


@@ -37,4 +39,4 @@ from .torch_callbacks import *
from .more_evaluate_callback import MoreEvaluateCallback
from .has_monitor_callback import ResultsMonitor, HasMonitorCallback
from .fitlog_callback import FitlogCallback
from .timer_callback import TimerCallback

fastNLP/core/callbacks/has_monitor_callback.py (+1, -1)

@@ -171,7 +171,7 @@ class ResultsMonitor:
     @property
     def log_name(self) -> str:
         """
-        Used internally when printing messages.
+        Used internally when printing information about the current class.

         :return:
         """


fastNLP/core/callbacks/load_best_model_callback.py (+2, -2)

@@ -106,11 +106,11 @@ class LoadBestModelCallback(HasMonitorCallback):
     def on_train_end(self, trainer):
         if abs(self.monitor_value) != float('inf'):  # if it is still inf, no evaluation ever ran.
             if self.real_save_folder:
-                logger.info(f"Loading best model from {self.real_save_folder} with {self.monitor_name}: {self.monitor_value}...")
+                logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value}...")
                 trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict,
                                    model_load_fn=self.model_load_fn)
             else:
-                logger.info(f"Loading best model from buffer with {self.monitor_name}: {self.monitor_value}...")
+                logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value}...")
                 self.buffer.seek(0)
                 trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict)
         if self.delete_after_after:
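Note on this rename: the monitor string a user passes may only approximately match a key in the evaluation results; ResultsMonitor resolves it to the actual matched key and stores that in `_real_monitor`, so logging `_real_monitor` reports the metric that was really used. A toy illustration of the idea (the result key and the substring matching below are hypothetical; fastNLP's actual matching is fuzzier):

user_monitor = 'acc'
results = {'acc#accuracy#dl1': 0.92}  # hypothetical evaluation-results key
# ResultsMonitor would resolve 'acc' to the concrete key and keep it around:
real_monitor = next(k for k in results if user_monitor in k)
print(real_monitor)  # -> 'acc#accuracy#dl1', the name now written to the log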


fastNLP/core/callbacks/progress_callback.py (+22, -8)

@@ -1,5 +1,4 @@
 import json
-import sys
 from typing import Union

 __all__ = [
@@ -16,8 +15,21 @@ from fastNLP.core.log import logger


 class ProgressCallback(HasMonitorCallback):
+    def __init__(self, monitor, larger_better, must_have_monitor=False):
+        super(ProgressCallback, self).__init__(monitor=monitor, larger_better=larger_better,
+                                               must_have_monitor=must_have_monitor)
+        self.best_monitor_epoch = -1
+        self.best_monitor_step = -1
+
+    def record_better_monitor(self, trainer):
+        self.best_monitor_step = trainer.global_forward_batches
+        self.best_monitor_epoch = trainer.cur_epoch_idx
+
     def on_train_end(self, trainer):
         f_rich_progress.stop()
+        if self.best_monitor_epoch != -1:
+            msg = f"The best performance for monitor {self._real_monitor}:{self.monitor_value} was achieved in" \
+                  f" Epoch:{self.best_monitor_epoch}, Global Batch:{self.best_monitor_step}."
+            logger.info(msg)

     @property
     def name(self):  # name of the progress bar
@@ -97,6 +109,7 @@ class RichCallback(ProgressCallback):
                                      advance=None, completed=trainer.cur_epoch_idx, refresh=True)

     def on_train_end(self, trainer):
+        super(RichCallback, self).on_train_end(trainer)
         self.clear_tasks()

     def on_before_backward(self, trainer, outputs):
@@ -121,8 +134,8 @@ class RichCallback(ProgressCallback):
         text_style = ''
         characters = '-'
         if self.monitor is not None:
-            monitor_value = self.get_monitor_value(results)
-            if self.is_better_monitor_value(monitor_value, keep_if_better=True):
+            if self.is_better_results(results, keep_if_better=True):
+                self.record_better_monitor(trainer)
                 if abs(self.monitor_value) != float('inf'):
                     rule_style = 'spring_green3'
                     text_style = '[bold]'
@@ -201,8 +214,8 @@ class RawTextCallback(ProgressCallback):
         base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}'
         text = ''
         if self.monitor is not None:
-            monitor_value = self.get_monitor_value(results)
-            if self.is_better_monitor_value(monitor_value, keep_if_better=True):
+            if self.is_better_results(results, keep_if_better=True):
+                self.record_better_monitor(trainer)
                 if abs(self.monitor_value) != float('inf'):
                     text = '+'*self.num_signs + base_text + '+'*self.num_signs
         if len(text) == 0:
@@ -266,6 +279,7 @@ class TqdmCallback(ProgressCallback):
         self.progress_bar.set_description_str(self.task2id['epoch'], f'Epoch:{trainer.cur_epoch_idx}', refresh=True)

     def on_train_end(self, trainer):
+        super(TqdmCallback, self).on_train_end(trainer)
         self.clear_tasks()

     def on_before_backward(self, trainer, outputs):
@@ -287,8 +301,8 @@ class TqdmCallback(ProgressCallback):
         base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}'
         text = ''
         if self.monitor is not None:
-            monitor_value = self.get_monitor_value(results)
-            if self.is_better_monitor_value(monitor_value, keep_if_better=True):
+            if self.is_better_results(results, keep_if_better=True):
+                self.record_better_monitor(trainer)
                 if abs(self.monitor_value) != float('inf'):
                     text = '+'*self.num_signs + base_text + '+'*self.num_signs
         if len(text) == 0:
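The refactor above moves best-result bookkeeping into ProgressCallback: subclasses now call `is_better_results(...)` and, on improvement, `record_better_monitor(trainer)`, and the base `on_train_end` reports the epoch and step where the best value occurred. A minimal sketch of a custom progress callback following the new pattern (`PlainLogCallback` is hypothetical; the hooks it calls are the ones added in this commit):

from fastNLP.core.callbacks.progress_callback import ProgressCallback
from fastNLP.core.log import logger


class PlainLogCallback(ProgressCallback):
    def __init__(self, monitor=None, larger_better=True):
        # The new base __init__ initializes best_monitor_epoch/step to -1.
        super().__init__(monitor=monitor, larger_better=larger_better,
                         must_have_monitor=False)

    def on_evaluate_end(self, trainer, results):
        if self.monitor is not None:
            # One call replaces the old get_monitor_value() /
            # is_better_monitor_value() pair.
            if self.is_better_results(results, keep_if_better=True):
                self.record_better_monitor(trainer)  # remember epoch and step
        logger.info(results)

    def on_train_end(self, trainer):
        # The base class now logs when the best monitor value was reached.
        super().on_train_end(trainer)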


fastNLP/core/callbacks/timer_callback.py (+152, -0)

@@ -0,0 +1,152 @@
import time

from .callback import Callback
from ..log import logger

__all__ = ['TimerCallback']


class _Timer:
    """Timer."""

    def __init__(self, name):
        self.name_ = name
        self.elapsed_ = 0.0
        self.started_ = False
        self.start_time = time.time()

    def start(self):
        """Start the timer."""
        assert not self.started_, f'{self.name_} timer has already been started'
        self.start_time = time.time()
        self.started_ = True

    def stop(self):
        """Stop the timer."""
        assert self.started_, f'{self.name_} timer is not started'
        self.elapsed_ += (time.time() - self.start_time)
        self.started_ = False

    def reset(self):
        """Reset timer."""
        self.elapsed_ = 0.0
        self.started_ = False

    def elapsed(self, reset=True):
        """Calculate the elapsed time."""
        started_ = self.started_
        # If timing is in progress, stop it first.
        if self.started_:
            self.stop()
        # Get the elapsed time.
        elapsed_ = self.elapsed_
        # Reset the elapsed time.
        if reset:
            self.reset()
        # If timing was in progress, set it back.
        if started_:
            self.start()
        return elapsed_


class Timers:
    """Group of timers."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def __contains__(self, item):
        return item in self.timers

    def reset(self):
        for timer in self.timers.values():
            timer.reset()


class TimerCallback(Callback):
    """
    This callback prints timing information collected during training, such as the
    time spent training, the time spent evaluating, and the total elapsed time.
    """
    def __init__(self, print_every=-1, time_ndigit=3):
        """

        :param print_every: when to print timing information.

            * *negative*: print at the end of every ``abs(print_every)`` epochs;
            * *0*: print only once, after the whole training run finishes;
            * *positive*: print every ``print_every`` steps;

        :param time_ndigit: number of decimal places to keep
        """
        assert isinstance(print_every, int), "print_every must be an int number."
        self.timers = Timers()
        self.print_every = print_every
        self.time_ndigit = time_ndigit

    def on_train_begin(self, trainer):
        self.timers('total').start()
        self.timers('train').start()

    def on_fetch_data_begin(self, trainer):
        self.timers('fetch-data').start()

    def on_fetch_data_end(self, trainer):
        self.timers('fetch-data').stop()

    def on_train_batch_begin(self, trainer, batch, indices):
        self.timers('forward').start()

    def on_before_backward(self, trainer, outputs):
        self.timers('forward').stop()
        self.timers('backward').start()

    def on_after_backward(self, trainer):
        self.timers('backward').stop()

    def on_before_optimizers_step(self, trainer, optimizers):
        self.timers('optimize').start()

    def on_after_optimizers_step(self, trainer, optimizers):
        self.timers('optimize').stop()

    def on_evaluate_begin(self, trainer):
        self.timers('train').stop()
        self.timers('evaluate').start()

    def on_evaluate_end(self, trainer, results):
        self.timers('evaluate').stop()
        self.timers('train').start()

    def format_timer(self, reset=True):
        line = ''
        timers = ['fetch-data', 'forward', 'backward', 'optimize', 'evaluate', 'train', 'total']
        for timer_name in timers:
            if timer_name not in self.timers:
                continue
            timer = self.timers(timer_name)
            elapsed = round(timer.elapsed(reset=reset), self.time_ndigit)
            if elapsed != 0:
                line = line + f', {timer_name}: {elapsed}s'
        return line

    def on_train_batch_end(self, trainer):
        if self.print_every > 0 and trainer.global_forward_batches % self.print_every == 0:
            line = self.format_timer()
            logger.info(f"Running {self.print_every} batches{line}")

    def on_train_epoch_end(self, trainer):
        if self.print_every < 0 and trainer.cur_epoch_idx % abs(self.print_every) == 0:
            line = self.format_timer()
            logger.info(f"Running {abs(self.print_every)} epochs{line}")

    def on_train_end(self, trainer):
        if self.print_every == 0:
            line = self.format_timer()
            logger.info(f"Training finished{line}")




fastNLP/core/controllers/loops/train_batch_loop.py (+3, -1)

@@ -41,10 +41,12 @@ class TrainBatchLoop(Loop):
                 batch = next(dataloader)
                 indices = get_batch_indices()
             except StopIteration:
+                trainer.on_fetch_data_end()
                 break

+            trainer.on_fetch_data_end()
+
             try:
-                trainer.on_fetch_data_end()
                 batch = match_and_substitute_params(trainer.input_mapping, batch)
                 batch = trainer.move_data_to_device(batch)
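The loop change above guarantees that on_fetch_data_end fires exactly once per on_fetch_data_begin, including when the dataloader is exhausted; otherwise TimerCallback's paired _Timer.start()/stop() assertions could trip. A stripped-down sketch of the invariant (hypothetical hook functions mirroring the assertions in _Timer):

started = False

def on_fetch_data_begin():
    global started
    assert not started, 'fetch-data timer has already been started'
    started = True

def on_fetch_data_end():
    global started
    assert started, 'fetch-data timer is not started'
    started = False

# The fixed loop calls on_fetch_data_end() on both the normal path and the
# StopIteration path, so begin/end always alternate and neither assertion fires.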



fastNLP/core/dataloaders/torch_dataloader/fdl.py (+3, -0)

@@ -108,6 +108,9 @@ class TorchDataLoader(DataLoader):
         if not isinstance(dataset, _FDataSet):
             dataset = _FDataSet(dataset)

+        if num_workers > 0 and multiprocessing_context is None:
+            multiprocessing_context = 'fork'  # default to 'fork' for starting worker processes
+
         if batch_sampler is not None:
             batch_size = 1
             shuffle = False
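Context for this default: PyTorch's DataLoader accepts a multiprocessing_context argument, and TorchDataLoader now falls back to 'fork' whenever workers are requested and no context was given. A standalone sketch of the same setting in plain PyTorch (assumes a Unix platform; 'fork' is unavailable on Windows):

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(100, dtype=torch.float32).unsqueeze(1))
dl = DataLoader(ds, batch_size=8, num_workers=2,
                multiprocessing_context='fork')  # what TorchDataLoader now defaults to
for (batch,) in dl:
    pass  # workers are started via fork rather than the platform default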

