@@ -79,7 +79,7 @@ class RichCallback(ProgressCallback): | |||
def on_train_begin(self, trainer): | |||
self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, | |||
completed=trainer.global_forward_batches/(trainer.total_batches+1e-6)) | |||
completed=trainer.global_forward_batches/(trainer.n_batches+1e-6)) | |||
def on_train_epoch_begin(self, trainer): | |||
self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6) | |||
@@ -190,7 +190,7 @@ class RawTextCallback(ProgressCallback): | |||
self.loss = 0 | |||
text = f'Epoch:{trainer.cur_epoch_idx}/{trainer.n_epochs}, Batch:{trainer.batch_idx_in_epoch}, ' \ | |||
f'loss:{round(loss, self.loss_round_ndigit)}, ' \ | |||
f'finished {round(trainer.global_forward_batches/trainer.total_batches*100, 2)}%.' | |||
f'finished {round(trainer.global_forward_batches/trainer.n_batches*100, 2)}%.' | |||
logger.info(text) | |||
def on_evaluate_end(self, trainer, results): | |||
@@ -251,7 +251,7 @@ class TqdmCallback(ProgressCallback): | |||
def on_train_begin(self, trainer): | |||
self.task2id['epoch'] = self.progress_bar.add_task(description='Epoch:0', total=trainer.n_epochs, | |||
bar_format='{desc}: {percentage:3.0f}%|{bar}| [{elapsed}<{remaining}, {rate_fmt}, {postfix}]', | |||
initial=trainer.global_forward_batches/(trainer.total_batches+1e-6)) | |||
initial=trainer.global_forward_batches/(trainer.n_batches+1e-6)) | |||
def on_train_epoch_begin(self, trainer): | |||
self.epoch_bar_update_advance = self.print_every/(trainer.num_batches_per_epoch + 1e-6) | |||
@@ -41,7 +41,7 @@ class TorchWarmupCallback(Callback): | |||
return max((progress - 1.) / (self.warmup - 1.), 0.) | |||
def on_train_begin(self, trainer): | |||
self.t_steps = trainer.total_batches | |||
self.t_steps = trainer.n_batches | |||
if self.warmup >1: | |||
self.warmup = self.warmup / self.t_steps | |||
self.t_steps = max(2, self.t_steps) # 不能小于2 | |||
@@ -460,14 +460,15 @@ class _MetricsWrapper: | |||
for metric in self._metrics: | |||
args = [] | |||
if not isinstance(batch, dict): | |||
logger.warning_once( | |||
logger.rank_zero_warning( | |||
f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on " | |||
f"the output of model to update metric.") | |||
f"the output of model to update metric.", once=True) | |||
else: | |||
args.append(batch) | |||
if not isinstance(outputs, dict): | |||
raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly" | |||
f" return a dict from your model or use `output_mapping` to convert it into dict type.") | |||
f" return a dict from your model or use `output_mapping` to convert it into dict " | |||
f"type.") | |||
if isinstance(metric, Metric): | |||
# 这样在 auto_param_call 报错的时候才清晰。 | |||
auto_param_call(metric.update, outputs, *args, signature_fn=metric.update.__wrapped__) | |||
@@ -110,7 +110,7 @@ class Trainer(TrainerEventTrigger): | |||
对于使用 ``TorchDDPDriver`` 的更多细节,请见 :class:`~fastNLP.core.drivers.torch_driver.TorchDDPDriver`。 | |||
:param n_epochs: 训练总共的 epoch 的数量,默认为 20; | |||
:param n_epochs: 训练总共的 epoch 的数量,默认为 20;也可以通过 ``n_batches`` 参数设置总共迭代多少个 ``batch`` 。 | |||
:param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 | |||
为 None; | |||
:param batch_step_fn: 定制每次训练时前向运行一个 batch 的数据所执行的函数。该函数应接受两个参数为 ``trainer`` 和 ``batch``, | |||
@@ -237,6 +237,8 @@ class Trainer(TrainerEventTrigger): | |||
注意该参数仅当 ``Trainer`` 内置的 ``Evaluator`` 不为 None 时且有需要该参数但是没有设置该参数的 *callback* 实例才有效; | |||
:param n_batches: 迭代多少个 ``batch`` 的训练结束。当该值不为 -1 时,将直接忽略 ``n_epochs`` 的值。 | |||
:param marker: 用于标记一个 ``Trainer`` 实例,从而在用户调用 ``Trainer.on`` 函数时,标记该函数属于哪一个具体的 ``Trainer`` 实例;默认为 None; | |||
.. note:: | |||
@@ -356,6 +358,7 @@ class Trainer(TrainerEventTrigger): | |||
fp16: bool = False, | |||
monitor: Union[str, Callable] = None, | |||
larger_better: bool = True, | |||
n_batches: int = -1, | |||
marker: Optional[str] = None, | |||
**kwargs | |||
): | |||
@@ -426,6 +429,7 @@ class Trainer(TrainerEventTrigger): | |||
model_wo_auto_param_call=model_wo_auto_param_call, | |||
accumulation_steps=accumulation_steps, | |||
fp16=fp16, | |||
n_batches=n_batches, | |||
marker=marker, | |||
**kwargs | |||
) | |||
@@ -444,12 +448,12 @@ class Trainer(TrainerEventTrigger): | |||
# 初始化 state,包括提供给用户的接口和我们自己使用的接口; | |||
self.state = State() | |||
self.trainer_state = TrainerState( | |||
n_epochs=n_epochs, | |||
n_epochs=n_epochs if n_batches!=-1 else None, | |||
cur_epoch_idx=0, | |||
global_forward_batches=0, | |||
batch_idx_in_epoch=0, | |||
num_batches_per_epoch=None, # 会在具体的 train_batch_loop 中进行初始化; | |||
total_batches=None | |||
n_batches=n_batches | |||
) | |||
if metrics is None and evaluate_dataloaders is not None: | |||
@@ -598,14 +602,18 @@ class Trainer(TrainerEventTrigger): | |||
self.dataloader = _TruncatedDataLoader(self.dataloader, num_train_batch_per_epoch) | |||
self.num_batches_per_epoch = len(self.dataloader) | |||
self.total_batches = self.num_batches_per_epoch * self.n_epochs | |||
if self.n_batches == -1: | |||
self.n_batches = self.num_batches_per_epoch * self.n_epochs | |||
else: | |||
self.n_epochs = (self.n_batches+self.num_batches_per_epoch-1)//self.num_batches_per_epoch | |||
self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch | |||
try: | |||
self.on_train_begin() | |||
self.driver.barrier() | |||
self.driver.zero_grad() | |||
while self.cur_epoch_idx < self.n_epochs: | |||
while self.cur_epoch_idx < self.n_epochs and self.global_forward_batches < self.n_batches: | |||
# 这个是防止在 Trainer.load_checkpoint 之后还没结束当前 epoch 又继续 save | |||
self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch | |||
self.driver.set_model_mode("train") | |||
@@ -1367,15 +1375,15 @@ class Trainer(TrainerEventTrigger): | |||
self.trainer_state.num_batches_per_epoch = num_batches_per_epoch | |||
@property | |||
def total_batches(self) -> int: | |||
def n_batches(self) -> int: | |||
r""" | |||
:return: 返回整体的训练中实际会训练多少个 batch 的数据; | |||
""" | |||
return self.trainer_state.total_batches | |||
return self.trainer_state.n_batches | |||
@total_batches.setter | |||
def total_batches(self, total_batches: int): | |||
self.trainer_state.total_batches = total_batches | |||
@n_batches.setter | |||
def n_batches(self, n_batches: int): | |||
self.trainer_state.n_batches = n_batches | |||
""" driver property """ | |||
@@ -50,7 +50,7 @@ class TrainerState: | |||
:param global_forward_batches: 当前模型总共 forward 了多少个 step; | |||
:param batch_idx_in_epoch: 训练中在当前 epoch 的第几个 step; | |||
:param num_batches_per_epoch: 每一个 epoch 会 forward 多少个 step; | |||
:param total_batches: 完整训练过程会 forward 的 step 数量,注意 total_batches = total_batches * n_epochs; | |||
:param n_batches: 完整训练过程会 forward 的 step 数量,注意未显式指定时 n_batches = num_batches_per_epoch * n_epochs; | |||
""" | |||
n_epochs: Optional[int] = None # 无论如何重新算 | |||
@@ -61,7 +61,7 @@ class TrainerState: | |||
num_batches_per_epoch: Optional[int] = None # 无论如何重新算 | |||
total_batches: Optional[int] = None # 无论如何重新算 | |||
n_batches: Optional[int] = None # 无论如何重新算 | |||
def state_dict(self) -> Dict: | |||
r""" | |||
@@ -156,7 +156,6 @@ import _pickle as pickle | |||
from copy import deepcopy | |||
from typing import Optional, List, Callable, Union, Dict, Any, Mapping | |||
from types import LambdaType | |||
from subprocess import DEVNULL | |||
import sys | |||
import time | |||
@@ -170,6 +169,7 @@ from fastNLP.core.utils.rich_progress import f_rich_progress, DummyFRichProgress | |||
from fastNLP.core.utils.tqdm_progress import f_tqdm_progress | |||
from ..log import logger | |||
from fastNLP.core.utils.dummy_class import DummyClass | |||
from ..utils.utils import _get_fun_msg | |||
progress_bars = { | |||
@@ -780,8 +780,8 @@ class DataSet: | |||
apply_out = self._apply_process(num_proc, func, progress_desc=progress_desc, | |||
progress_bar=progress_bar) | |||
# 只检测第一个数据是否为dict类型,若是则默认所有返回值为dict;否则报错。 | |||
if not isinstance(apply_out[0], dict): | |||
raise Exception("The result of func is not a dict") | |||
if not isinstance(apply_out[0], Mapping): | |||
raise Exception(f"The result of func:{_get_fun_msg(func)} is not a dict, but of type {type(apply_out[0])}") | |||
for key, value in apply_out[0].items(): | |||
results[key] = [value] | |||
@@ -789,7 +789,8 @@ class DataSet: | |||
try: | |||
for idx, per_out in enumerate(apply_out[1:]): | |||
if len(set(results.keys()) - set(per_out.keys())): | |||
raise ApplyResultException("apply results have different fields", idx + 1) | |||
raise ApplyResultException(f"Apply results have different fields:{set(results.keys())} and " | |||
f"{set(per_out.keys())}", idx + 1) | |||
for key, value in per_out.items(): | |||
results[key].append(value) | |||
@@ -120,7 +120,7 @@ class FRichProgress(Progress, metaclass=Singleton): | |||
def add_task( | |||
self, | |||
description: str, | |||
description: str = 'Progress', | |||
start: bool = True, | |||
total: float = 100.0, | |||
completed: int = 0, | |||
@@ -7,7 +7,7 @@ __all__ = [] | |||
import json | |||
import csv | |||
# from ..core import log | |||
from ..core import logger | |||
def _read_csv(path, encoding='utf-8', headers=None, sep=',', dropna=True): | |||
@@ -81,7 +81,7 @@ def _read_json(path, encoding='utf-8', fields=None, dropna=True): | |||
yield line_idx, _res | |||
def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True): | |||
def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True, drophash=True): | |||
r""" | |||
Construct a generator to read conll items. | |||
@@ -91,6 +91,7 @@ def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True): | |||
:param indexes: conll object's column indexes that needed, if None, all columns are needed. default: None | |||
:param dropna: whether to ignore and drop invalid data, | |||
:if False, raise ValueError when reading invalid data. default: True | |||
:param drophash: 是否丢掉以 # 开头的 line 。 | |||
:return: generator, every time yield (line number, conll item) | |||
""" | |||
@@ -121,7 +122,7 @@ def _read_conll(path, encoding='utf-8',sep=None, indexes=None, dropna=True): | |||
sample = [] | |||
continue | |||
raise ValueError('Invalid instance which ends at line: {}'.format(line_idx)) | |||
elif line.startswith('#'): | |||
elif line.startswith('#') and drophash: | |||
continue | |||
else: | |||
sample.append(line.split(sep)) if sep else sample.append(line.split()) | |||
@@ -52,13 +52,14 @@ class ConllLoader(Loader): | |||
""" | |||
def __init__(self, headers, sep=None, indexes=None, dropna=True): | |||
def __init__(self, headers, sep=None, indexes=None, dropna=True, drophash=True): | |||
r""" | |||
:param list headers: 每一列数据的名称,需为List or Tuple of str。``header`` 与 ``indexes`` 一一对应 | |||
:param list sep: 指定分隔符,默认为制表符 | |||
:param list indexes: 需要保留的数据列下标,从0开始。若为 ``None`` ,则所有列都保留。Default: ``None`` | |||
:param bool dropna: 是否忽略非法数据,若 ``False`` ,遇到非法数据时抛出 ``ValueError`` 。Default: ``True`` | |||
:param bool drophash: 是否忽略以 ``#`` 开头的行。Default: ``True`` | |||
""" | |||
super(ConllLoader, self).__init__() | |||
if not isinstance(headers, (list, tuple)): | |||
@@ -66,6 +67,7 @@ class ConllLoader(Loader): | |||
'invalid headers: {}, should be list of strings'.format(headers)) | |||
self.headers = headers | |||
self.dropna = dropna | |||
self.drophash = drophash | |||
self.sep=sep | |||
if indexes is None: | |||
self.indexes = list(range(len(self.headers))) | |||
@@ -82,7 +84,8 @@ class ConllLoader(Loader): | |||
:return: DataSet | |||
""" | |||
ds = DataSet() | |||
for idx, data in _read_conll(path,sep=self.sep, indexes=self.indexes, dropna=self.dropna): | |||
for idx, data in _read_conll(path,sep=self.sep, indexes=self.indexes, dropna=self.dropna, | |||
drophash=self.drophash): | |||
ins = {h: data[i] for i, h in enumerate(self.headers)} | |||
ds.append(Instance(**ins)) | |||
return ds | |||
@@ -32,4 +32,4 @@ def test_torch_warmup_callback(warmup, schedule, accumulation_steps): | |||
elif schedule == 'constant': | |||
assert np.allclose(0.1, kwargs['optimizers'].param_groups[0]['lr']) | |||
assert len(r_callback.lrs)<=trainer.total_batches//accumulation_steps+1 | |||
assert len(r_callback.lrs)<=trainer.n_batches//accumulation_steps+1 |
@@ -55,4 +55,4 @@ class RecordAccumulationStepsCallback_Torch(Callback): | |||
def on_train_end(self, trainer): | |||
print(f"\n equal num: {self.equal}.\n") | |||
print(f"\ntotal_batch_num: {trainer.total_batches}.\n") | |||
print(f"\ntotal_batch_num: {trainer.n_batches}.\n") |
@@ -50,24 +50,24 @@ | |||
"\n", | |||
"```python\n", | |||
"trainer = Trainer(\n", | |||
" model=model, # 模型基于 torch.nn.Module\n", | |||
" train_dataloader=train_dataloader, # 加载模块基于 torch.utils.data.DataLoader \n", | |||
" optimizers=optimizer, # 优化模块基于 torch.optim.*\n", | |||
"\t...\n", | |||
"\tdriver=\"torch\", # 使用 pytorch 模块进行训练 \n", | |||
"\tdevice='cuda', # 使用 GPU:0 显卡执行训练\n", | |||
"\t...\n", | |||
")\n", | |||
" model=model, # 模型基于 torch.nn.Module\n", | |||
" train_dataloader=train_dataloader, # 加载模块基于 torch.utils.data.DataLoader \n", | |||
" optimizers=optimizer, # 优化模块基于 torch.optim.*\n", | |||
" ...\n", | |||
" driver=\"torch\", # 使用 pytorch 模块进行训练 \n", | |||
" device='cuda', # 使用 GPU:0 显卡执行训练\n", | |||
" ...\n", | |||
" )\n", | |||
"...\n", | |||
"evaluator = Evaluator(\n", | |||
" model=model, # 模型基于 torch.nn.Module\n", | |||
" dataloaders=evaluate_dataloader, # 加载模块基于 torch.utils.data.DataLoader\n", | |||
" metrics={'acc': Accuracy()}, # 测评方法使用 fastNLP.core.metrics.Accuracy \n", | |||
" ...\n", | |||
" driver=trainer.driver, # 保持同 trainer 的 driver 一致\n", | |||
"\tdevice=None,\n", | |||
" ...\n", | |||
")\n", | |||
" model=model, # 模型基于 torch.nn.Module\n", | |||
" dataloaders=evaluate_dataloader, # 加载模块基于 torch.utils.data.DataLoader\n", | |||
" metrics={'acc': Accuracy()}, # 测评方法使用 fastNLP.core.metrics.Accuracy \n", | |||
" ...\n", | |||
" driver=trainer.driver, # 保持同 trainer 的 driver 一致\n", | |||
" device=None,\n", | |||
" ...\n", | |||
" )\n", | |||
"```" | |||
] | |||
}, | |||
@@ -84,7 +84,7 @@ | |||
"\n", | |||
"在`fastNLP 0.8`中,**`Trainer`和`Evaluator`都依赖于具体的`driver`来完成整体的工作流程**\n", | |||
"\n", | |||
"  具体`driver`与`Trainer`以及`Evaluator`之间的关系请参考`fastNLP 0.8`的框架设计\n", | |||
"  具体`driver`与`Trainer`以及`Evaluator`之间的关系之后`tutorial 4`中的详细介绍\n", | |||
"\n", | |||
"注:这里给出一条建议:**在同一脚本中**,**所有的`Trainer`和`Evaluator`使用的`driver`应当保持一致**\n", | |||
"\n", | |||
@@ -106,17 +106,17 @@ | |||
"\n", | |||
"```python\n", | |||
"trainer = Trainer(\n", | |||
" model=model,\n", | |||
" train_dataloader=train_dataloader,\n", | |||
" optimizers=optimizer,\n", | |||
"\t...\n", | |||
"\tdriver=\"torch\",\n", | |||
"\tdevice='cuda',\n", | |||
"\t...\n", | |||
" evaluate_dataloaders=evaluate_dataloader, # 传入参数 evaluator_dataloaders\n", | |||
" metrics={'acc': Accuracy()}, # 传入参数 metrics\n", | |||
"\t...\n", | |||
")\n", | |||
" model=model,\n", | |||
" train_dataloader=train_dataloader,\n", | |||
" optimizers=optimizer,\n", | |||
" ...\n", | |||
" driver=\"torch\",\n", | |||
" device='cuda',\n", | |||
" ...\n", | |||
" evaluate_dataloaders=evaluate_dataloader, # 传入参数 evaluator_dataloaders\n", | |||
" metrics={'acc': Accuracy()}, # 传入参数 metrics\n", | |||
" ...\n", | |||
" )\n", | |||
"```" | |||
] | |||
}, | |||
@@ -570,7 +570,7 @@ | |||
"outputs": [], | |||
"source": [ | |||
"from fastNLP import Evaluator\n", | |||
"from fastNLP.core.metrics import Accuracy\n", | |||
"from fastNLP import Accuracy\n", | |||
"\n", | |||
"evaluator = Evaluator(\n", | |||
" model=model,\n", | |||
@@ -1312,219 +1312,6 @@ | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 13, | |||
"id": "db784d5b", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"data": { | |||
"text/plain": [ | |||
"['__annotations__',\n", | |||
" '__class__',\n", | |||
" '__delattr__',\n", | |||
" '__dict__',\n", | |||
" '__dir__',\n", | |||
" '__doc__',\n", | |||
" '__eq__',\n", | |||
" '__format__',\n", | |||
" '__ge__',\n", | |||
" '__getattribute__',\n", | |||
" '__gt__',\n", | |||
" '__hash__',\n", | |||
" '__init__',\n", | |||
" '__init_subclass__',\n", | |||
" '__le__',\n", | |||
" '__lt__',\n", | |||
" '__module__',\n", | |||
" '__ne__',\n", | |||
" '__new__',\n", | |||
" '__reduce__',\n", | |||
" '__reduce_ex__',\n", | |||
" '__repr__',\n", | |||
" '__setattr__',\n", | |||
" '__sizeof__',\n", | |||
" '__str__',\n", | |||
" '__subclasshook__',\n", | |||
" '__weakref__',\n", | |||
" '_check_callback_called_legality',\n", | |||
" '_check_train_batch_loop_legality',\n", | |||
" '_custom_callbacks',\n", | |||
" '_driver',\n", | |||
" '_evaluate_dataloaders',\n", | |||
" '_fetch_matched_fn_callbacks',\n", | |||
" '_set_num_eval_batch_per_dl',\n", | |||
" '_train_batch_loop',\n", | |||
" '_train_dataloader',\n", | |||
" '_train_step',\n", | |||
" '_train_step_signature_fn',\n", | |||
" 'accumulation_steps',\n", | |||
" 'add_callback_fn',\n", | |||
" 'backward',\n", | |||
" 'batch_idx_in_epoch',\n", | |||
" 'batch_step_fn',\n", | |||
" 'callback_manager',\n", | |||
" 'check_batch_step_fn',\n", | |||
" 'cur_epoch_idx',\n", | |||
" 'data_device',\n", | |||
" 'dataloader',\n", | |||
" 'device',\n", | |||
" 'driver',\n", | |||
" 'driver_name',\n", | |||
" 'epoch_evaluate',\n", | |||
" 'evaluate_batch_step_fn',\n", | |||
" 'evaluate_dataloaders',\n", | |||
" 'evaluate_every',\n", | |||
" 'evaluate_fn',\n", | |||
" 'evaluator',\n", | |||
" 'extract_loss_from_outputs',\n", | |||
" 'fp16',\n", | |||
" 'get_no_sync_context',\n", | |||
" 'global_forward_batches',\n", | |||
" 'has_checked_train_batch_loop',\n", | |||
" 'input_mapping',\n", | |||
" 'kwargs',\n", | |||
" 'larger_better',\n", | |||
" 'load_checkpoint',\n", | |||
" 'load_model',\n", | |||
" 'marker',\n", | |||
" 'metrics',\n", | |||
" 'model',\n", | |||
" 'model_device',\n", | |||
" 'monitor',\n", | |||
" 'move_data_to_device',\n", | |||
" 'n_epochs',\n", | |||
" 'num_batches_per_epoch',\n", | |||
" 'on',\n", | |||
" 'on_after_backward',\n", | |||
" 'on_after_optimizers_step',\n", | |||
" 'on_after_trainer_initialized',\n", | |||
" 'on_after_zero_grad',\n", | |||
" 'on_before_backward',\n", | |||
" 'on_before_optimizers_step',\n", | |||
" 'on_before_zero_grad',\n", | |||
" 'on_evaluate_begin',\n", | |||
" 'on_evaluate_end',\n", | |||
" 'on_exception',\n", | |||
" 'on_fetch_data_begin',\n", | |||
" 'on_fetch_data_end',\n", | |||
" 'on_load_checkpoint',\n", | |||
" 'on_load_model',\n", | |||
" 'on_sanity_check_begin',\n", | |||
" 'on_sanity_check_end',\n", | |||
" 'on_save_checkpoint',\n", | |||
" 'on_save_model',\n", | |||
" 'on_train_batch_begin',\n", | |||
" 'on_train_batch_end',\n", | |||
" 'on_train_begin',\n", | |||
" 'on_train_end',\n", | |||
" 'on_train_epoch_begin',\n", | |||
" 'on_train_epoch_end',\n", | |||
" 'optimizers',\n", | |||
" 'output_mapping',\n", | |||
" 'progress_bar',\n", | |||
" 'run',\n", | |||
" 'run_evaluate',\n", | |||
" 'save_checkpoint',\n", | |||
" 'save_model',\n", | |||
" 'start_batch_idx_in_epoch',\n", | |||
" 'state',\n", | |||
" 'step',\n", | |||
" 'step_evaluate',\n", | |||
" 'total_batches',\n", | |||
" 'train_batch_loop',\n", | |||
" 'train_dataloader',\n", | |||
" 'train_fn',\n", | |||
" 'train_step',\n", | |||
" 'trainer_state',\n", | |||
" 'zero_grad']" | |||
] | |||
}, | |||
"execution_count": 13, | |||
"metadata": {}, | |||
"output_type": "execute_result" | |||
} | |||
], | |||
"source": [ | |||
"dir(trainer)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": 14, | |||
"id": "953533c4", | |||
"metadata": {}, | |||
"outputs": [ | |||
{ | |||
"name": "stdout", | |||
"output_type": "stream", | |||
"text": [ | |||
"Help on method run in module fastNLP.core.controllers.trainer:\n", | |||
"\n", | |||
"run(num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, catch_KeyboardInterrupt=None) method of fastNLP.core.controllers.trainer.Trainer instance\n", | |||
" 该函数是在 ``Trainer`` 初始化后用于真正开始训练的函数;\n", | |||
" \n", | |||
" 注意如果是断点重训的第一次训练,即还没有保存任何用于断点重训的文件,那么其应当置 resume_from 为 None,并且使用 ``CheckpointCallback``\n", | |||
" 去保存断点重训的文件;\n", | |||
" \n", | |||
" :param num_train_batch_per_epoch: 每个 epoch 训练多少个 batch 后停止,*-1* 表示使用 train_dataloader 本身的长度;\n", | |||
" :param num_eval_batch_per_dl: 每个 evaluate_dataloader 验证多少个 batch 停止,*-1* 表示使用 evaluate_dataloader 本身的长度;\n", | |||
" :param num_eval_sanity_batch: 在训练之前运行多少个 evaluation batch 来检测一下 evaluation 的过程是否有错误。为 0 表示不检测;\n", | |||
" :param resume_from: 从哪个路径下恢复 trainer 的状态,注意该值需要为一个文件夹,例如使用 ``CheckpointCallback`` 时帮助您创建的保存的子文件夹;\n", | |||
" :param resume_training: 是否按照 checkpoint 中训练状态恢复。如果为 False,则只恢复 model 和 optimizers 的状态;该参数如果为 ``True``,\n", | |||
" 在下一次断点重训的时候我们会精确到上次训练截止的具体的 sample 进行训练;否则我们只会恢复 model 和 optimizers 的状态,而 ``Trainer`` 中的\n", | |||
" 其余状态都是保持初始化时的状态不会改变;\n", | |||
" :param catch_KeyboardInterrupt: 是否捕获 KeyboardInterrupt;如果该参数为 ``True``,在训练时如果您使用 ``ctrl+c`` 来终止程序,\n", | |||
" ``Trainer`` 不会抛出异常,但是会提前退出,然后 ``trainer.run()`` 之后的代码会继续运行。注意该参数在您使用分布式训练的 ``Driver``\n", | |||
" 时无效,例如 ``TorchDDPDriver``;非分布式训练的 ``Driver`` 下该参数默认为 True;\n", | |||
" \n", | |||
" .. warning::\n", | |||
" \n", | |||
" 注意初始化的 ``Trainer`` 只能调用一次 ``run`` 函数,即之后的调用 ``run`` 函数实际不会运行,因为此时\n", | |||
" ``trainer.cur_epoch_idx == trainer.n_epochs``;\n", | |||
" \n", | |||
" 这意味着如果您需要再次调用 ``run`` 函数,您需要重新再初始化一个 ``Trainer``;\n", | |||
" \n", | |||
" .. note::\n", | |||
" \n", | |||
" 您可以使用 ``num_train_batch_per_epoch`` 来简单地对您的训练过程进行验证,例如,当您指定 ``num_train_batch_per_epoch=10`` 后,\n", | |||
" 每一个 epoch 下实际训练的 batch 的数量则会被修改为 10。您可以先使用该值来设定一个较小的训练长度,在验证整体的训练流程没有错误后,再将\n", | |||
" 该值设定为 **-1** 开始真正的训练;\n", | |||
" \n", | |||
" ``num_eval_batch_per_dl`` 的意思和 ``num_train_batch_per_epoch`` 类似,即您可以通过设定 ``num_eval_batch_per_dl`` 来验证\n", | |||
" 整体的验证流程是否正确;\n", | |||
" \n", | |||
" ``num_eval_sanity_batch`` 的作用可能会让人产生迷惑,其本质和 ``num_eval_batch_per_dl`` 作用一致,但是其只被 ``Trainer`` 使用;\n", | |||
" 并且其只会在训练的一开始使用,意思为:我们在训练的开始时会先使用 ``Evaluator``(如果其不为 ``None``) 进行验证,此时验证的 batch 的\n", | |||
" 数量只有 ``num_eval_sanity_batch`` 个;但是对于 ``num_eval_batch_per_dl`` 而言,其表示在实际的整体的训练过程中,每次 ``Evaluator``\n", | |||
" 进行验证时会验证的 batch 的数量。\n", | |||
" \n", | |||
" 并且,在实际真正的训练中,``num_train_batch_per_epoch`` 和 ``num_eval_batch_per_dl`` 应当都被设置为 **-1**,但是 ``num_eval_sanity_batch``\n", | |||
" 应当为一个很小的正整数,例如 2;\n", | |||
" \n", | |||
" .. note::\n", | |||
" \n", | |||
" 参数 ``resume_from`` 和 ``resume_training`` 的设立是为了支持断点重训功能;仅当 ``resume_from`` 不为 ``None`` 时,``resume_training`` 才有效;\n", | |||
" \n", | |||
" 断点重训的意思为将上一次训练过程中的 ``Trainer`` 的状态保存下来,包括模型和优化器的状态、当前训练过的 epoch 的数量、对于当前的 epoch\n", | |||
" 已经训练过的 batch 的数量、callbacks 的状态等等;然后在下一次训练时直接加载这些状态,从而直接恢复到上一次训练过程的某一个具体时间点的状态开始训练;\n", | |||
" \n", | |||
" fastNLP 将断点重训分为了 **保存状态** 和 **恢复断点重训** 两部分:\n", | |||
" \n", | |||
" 1. 您需要使用 ``CheckpointCallback`` 来保存训练过程中的 ``Trainer`` 的状态;具体详见 :class:`~fastNLP.core.callbacks.CheckpointCallback`;\n", | |||
" ``CheckpointCallback`` 会帮助您把 ``Trainer`` 的状态保存到一个具体的文件夹下,这个文件夹的名字由 ``CheckpointCallback`` 自己生成;\n", | |||
" 2. 在第二次训练开始时,您需要找到您想要加载的 ``Trainer`` 状态所存放的文件夹,然后传入给参数 ``resume_from``;\n", | |||
" \n", | |||
" 需要注意的是 **保存状态** 和 **恢复断点重训** 是互不影响的。\n", | |||
"\n" | |||
] | |||
} | |||
], | |||
"source": [ | |||
"help(trainer.run)" | |||
] | |||
}, | |||
{ | |||
"cell_type": "code", | |||
"execution_count": null, | |||
"id": "1bc7cb4a", | |||
"metadata": {}, | |||