@@ -16,6 +16,7 @@ __all__ = [ | |||
"ResultsMonitor", | |||
'HasMonitorCallback', | |||
"FitlogCallback", | |||
"TimerCallback", | |||
# collators | |||
'Collator', | |||
@@ -21,7 +21,9 @@ __all__ = [ | |||
"ResultsMonitor", | |||
'HasMonitorCallback', | |||
"FitlogCallback" | |||
"FitlogCallback", | |||
"TimerCallback" | |||
] | |||
@@ -37,4 +39,4 @@ from .torch_callbacks import * | |||
from .more_evaluate_callback import MoreEvaluateCallback | |||
from .has_monitor_callback import ResultsMonitor, HasMonitorCallback | |||
from .fitlog_callback import FitlogCallback | |||
from .timer_callback import TimerCallback |
@@ -171,7 +171,7 @@ class ResultsMonitor: | |||
@property | |||
def log_name(self) -> str: | |||
""" | |||
内部用于打印信息使用 | |||
内部用于打印当前类别信息使用 | |||
:return: | |||
""" | |||
@@ -106,11 +106,11 @@ class LoadBestModelCallback(HasMonitorCallback): | |||
def on_train_end(self, trainer): | |||
if abs(self.monitor_value) != float('inf'): # 如果是 inf 说明从来没有运行过。 | |||
if self.real_save_folder: | |||
logger.info(f"Loading best model from {self.real_save_folder} with {self.monitor_name}: {self.monitor_value}...") | |||
logger.info(f"Loading best model from {self.real_save_folder} with {self._real_monitor}: {self.monitor_value}...") | |||
trainer.load_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, | |||
model_load_fn=self.model_load_fn) | |||
else: | |||
logger.info(f"Loading best model from buffer with {self.monitor_name}: {self.monitor_value}...") | |||
logger.info(f"Loading best model from buffer with {self._real_monitor}: {self.monitor_value}...") | |||
self.buffer.seek(0) | |||
trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) | |||
if self.delete_after_after: | |||
@@ -1,5 +1,4 @@ | |||
import json | |||
import sys | |||
from typing import Union | |||
__all__ = [ | |||
@@ -16,8 +15,21 @@ from fastNLP.core.log import logger | |||
class ProgressCallback(HasMonitorCallback): | |||
def __init__(self, monitor, larger_better, must_have_monitor=False): | |||
super(ProgressCallback, self).__init__(monitor=monitor, larger_better=larger_better, | |||
must_have_monitor=must_have_monitor) | |||
self.best_monitor_epoch = -1 | |||
self.best_monitor_step = -1 | |||
def record_better_monitor(self, trainer): | |||
self.best_monitor_step = trainer.global_forward_batches | |||
self.best_monitor_epoch = trainer.cur_epoch_idx | |||
def on_train_end(self, trainer): | |||
f_rich_progress.stop() | |||
if self.best_monitor_epoch != -1: | |||
msg = f"The best performance for monitor {self._real_monitor}:{self.monitor_value} was achieved in" \ | |||
f" Epoch:{self.best_monitor_epoch}, Global Batch:{self.best_monitor_step}." | |||
logger.info(msg) | |||
@property | |||
def name(self): # progress bar的名称 | |||
@@ -97,6 +109,7 @@ class RichCallback(ProgressCallback): | |||
advance=None, completed=trainer.cur_epoch_idx, refresh=True) | |||
def on_train_end(self, trainer): | |||
super(RichCallback, self).on_train_end(trainer) | |||
self.clear_tasks() | |||
def on_before_backward(self, trainer, outputs): | |||
@@ -121,8 +134,8 @@ class RichCallback(ProgressCallback): | |||
text_style = '' | |||
characters = '-' | |||
if self.monitor is not None: | |||
monitor_value = self.get_monitor_value(results) | |||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||
if self.is_better_results(results, keep_if_better=True): | |||
self.record_better_monitor(trainer) | |||
if abs(self.monitor_value) != float('inf'): | |||
rule_style = 'spring_green3' | |||
text_style = '[bold]' | |||
@@ -201,8 +214,8 @@ class RawTextCallback(ProgressCallback): | |||
base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | |||
text = '' | |||
if self.monitor is not None: | |||
monitor_value = self.get_monitor_value(results) | |||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||
if self.is_better_results(results, keep_if_better=True): | |||
self.record_better_monitor(trainer) | |||
if abs(self.monitor_value) != float('inf'): | |||
text = '+'*self.num_signs + base_text + '+'*self.num_signs | |||
if len(text) == 0: | |||
@@ -266,6 +279,7 @@ class TqdmCallback(ProgressCallback): | |||
self.progress_bar.set_description_str(self.task2id['epoch'], f'Epoch:{trainer.cur_epoch_idx}', refresh=True) | |||
def on_train_end(self, trainer): | |||
super(TqdmCallback, self).on_train_end(trainer) | |||
self.clear_tasks() | |||
def on_before_backward(self, trainer, outputs): | |||
@@ -287,8 +301,8 @@ class TqdmCallback(ProgressCallback): | |||
base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' | |||
text = '' | |||
if self.monitor is not None: | |||
monitor_value = self.get_monitor_value(results) | |||
if self.is_better_monitor_value(monitor_value, keep_if_better=True): | |||
if self.is_better_results(results, keep_if_better=True): | |||
self.record_better_monitor(trainer) | |||
if abs(self.monitor_value) != float('inf'): | |||
text = '+'*self.num_signs + base_text + '+'*self.num_signs | |||
if len(text) == 0: | |||
@@ -0,0 +1,152 @@ | |||
import time | |||
from .callback import Callback | |||
from ..log import logger | |||
__all__ = ['TimerCallback'] | |||
class _Timer: | |||
"""Timer.""" | |||
def __init__(self, name): | |||
self.name_ = name | |||
self.elapsed_ = 0.0 | |||
self.started_ = False | |||
self.start_time = time.time() | |||
def start(self): | |||
"""Start the timer.""" | |||
assert not self.started_, f'{self.name_} timer has already been started' | |||
self.start_time = time.time() | |||
self.started_ = True | |||
def stop(self): | |||
"""Stop the timer.""" | |||
assert self.started_, f'{self.name_} timer is not started' | |||
self.elapsed_ += (time.time() - self.start_time) | |||
self.started_ = False | |||
def reset(self): | |||
"""Reset timer.""" | |||
self.elapsed_ = 0.0 | |||
self.started_ = False | |||
def elapsed(self, reset=True): | |||
"""Calculate the elapsed time.""" | |||
started_ = self.started_ | |||
# If the timing in progress, end it first. | |||
if self.started_: | |||
self.stop() | |||
# Get the elapsed time. | |||
elapsed_ = self.elapsed_ | |||
# Reset the elapsed time | |||
if reset: | |||
self.reset() | |||
# If timing was in progress, set it back. | |||
if started_: | |||
self.start() | |||
return elapsed_ | |||
class Timers:
    """Registry of named ``_Timer`` objects, created lazily on first access."""

    def __init__(self):
        self.timers = {}

    def __call__(self, name):
        """Return the timer registered under ``name``, creating it on first use."""
        timer = self.timers.get(name)
        if timer is None:
            timer = _Timer(name)
            self.timers[name] = timer
        return timer

    def __contains__(self, item):
        return item in self.timers

    def reset(self):
        """Reset every registered timer."""
        for timer in self.timers.values():
            timer.reset()
class TimerCallback(Callback):
    """
    Print timing information collected during training: time spent fetching data,
    in the forward/backward passes, in optimizer steps, in evaluation, and overall.
    """

    def __init__(self, print_every=-1, time_ndigit=3):
        """
        :param print_every: when to print timing information.

            * *negative*: print once every ``abs(print_every)`` epochs;
            * *0*: print only once, when training ends;
            * *positive*: print once every ``print_every`` steps;
        :param time_ndigit: number of decimal digits kept for each reported duration
        :raises TypeError: if ``print_every`` is not an int
        """
        # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
        if not isinstance(print_every, int):
            raise TypeError("print_every must be an int number.")
        self.timers = Timers()
        self.print_every = print_every
        self.time_ndigit = time_ndigit

    def on_train_begin(self, trainer):
        # 'total' runs for the whole fit; 'train' is paused while evaluating.
        self.timers('total').start()
        self.timers('train').start()

    def on_fetch_data_begin(self, trainer):
        self.timers('fetch-data').start()

    def on_fetch_data_end(self, trainer):
        self.timers('fetch-data').stop()

    def on_train_batch_begin(self, trainer, batch, indices):
        self.timers('forward').start()

    def on_before_backward(self, trainer, outputs):
        self.timers('forward').stop()
        self.timers('backward').start()

    def on_after_backward(self, trainer):
        self.timers('backward').stop()

    def on_before_optimizers_step(self, trainer, optimizers):
        self.timers('optimize').start()

    def on_after_optimizers_step(self, trainer, optimizers):
        self.timers('optimize').stop()

    def on_evaluate_begin(self, trainer):
        # Evaluation time is accounted separately from training time.
        self.timers('train').stop()
        self.timers('evaluate').start()

    def on_evaluate_end(self, trainer, results):
        self.timers('evaluate').stop()
        self.timers('train').start()

    def format_timer(self, reset=True):
        """
        Render all known timers as a ``', name: 1.234s'`` suffix string.

        :param reset: whether each timer is reset after being read
        :return: formatted string; empty if every timer reads zero
        """
        line = ''
        timers = ['fetch-data', 'forward', 'backward', 'optimize', 'evaluate', 'train', 'total']
        for timer_name in timers:
            if timer_name not in self.timers:
                continue
            timer = self.timers(timer_name)
            elapsed = round(timer.elapsed(reset=reset), self.time_ndigit)
            if elapsed != 0:
                line = line + f', {timer_name}: {elapsed}s'
        return line

    def on_train_batch_end(self, trainer):
        if self.print_every > 0 and trainer.global_forward_batches % self.print_every == 0:
            line = self.format_timer()
            logger.info(f"Running {self.print_every} batches{line}")

    def on_train_epoch_end(self, trainer):
        # NOTE(review): assumes trainer.cur_epoch_idx counts finished epochs — confirm.
        if self.print_every < 0 and trainer.cur_epoch_idx % abs(self.print_every) == 0:
            line = self.format_timer()
            logger.info(f"Running {abs(self.print_every)} epochs{line}")

    def on_train_end(self, trainer):
        if self.print_every == 0:
            line = self.format_timer()
            logger.info(f"Training finished{line}")
@@ -41,10 +41,12 @@ class TrainBatchLoop(Loop): | |||
batch = next(dataloader) | |||
indices = get_batch_indices() | |||
except StopIteration: | |||
trainer.on_fetch_data_end() | |||
break | |||
trainer.on_fetch_data_end() | |||
try: | |||
trainer.on_fetch_data_end() | |||
batch = match_and_substitute_params(trainer.input_mapping, batch) | |||
batch = trainer.move_data_to_device(batch) | |||
@@ -108,6 +108,9 @@ class TorchDataLoader(DataLoader): | |||
if not isinstance(dataset, _FDataSet): | |||
dataset = _FDataSet(dataset) | |||
if num_workers>0 and multiprocessing_context is None: | |||
multiprocessing_context = 'fork' # 这里默认使用fork的方式来启动多进程 | |||
if batch_sampler is not None: | |||
batch_size = 1 | |||
shuffle = False | |||