
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
x54-729 3 years ago
commit a39d3011e8
10 changed files with 203 additions and 60 deletions
  1. fastNLP/core/callbacks/progress_callback.py (+12, -7)
  2. fastNLP/core/controllers/trainer.py (+43, -22)
  3. fastNLP/core/dataloaders/__init__.py (+5, -2)
  4. fastNLP/core/dataloaders/utils.py (+33, -2)
  5. fastNLP/core/drivers/torch_driver/ddp.py (+0, -3)
  6. fastNLP/core/drivers/torch_driver/torch_driver.py (+2, -1)
  7. fastNLP/core/drivers/torch_driver/utils.py (+24, -22)
  8. fastNLP/core/metrics/metric.py (+9, -1)
  9. tests/core/controllers/test_trainer_w_evaluator_torch.py (+45, -0)
  10. tests/core/controllers/test_trainer_wo_evaluator_torch.py (+30, -0)

+12 -7 fastNLP/core/callbacks/progress_callback.py

@@ -20,6 +20,7 @@ class ProgressCallback(HasMonitorCallback):
must_have_monitor=must_have_monitor)
self.best_monitor_epoch = -1
self.best_monitor_step = -1
self.best_results = None

def record_better_monitor(self, trainer):
self.best_monitor_step = trainer.global_forward_batches
@@ -29,6 +30,8 @@ class ProgressCallback(HasMonitorCallback):
if self.best_monitor_epoch != -1:
msg = f"The best performance for monitor {self._real_monitor}:{self.monitor_value} was achieved in" \
f" Epoch:{self.best_monitor_epoch}, Global Batch:{self.best_monitor_step}."
if self.best_results is not None:
msg = msg + ' The evaluation result: \n' + str(self.best_results)
logger.info(msg)

@property
@@ -147,9 +150,11 @@ class RichCallback(ProgressCallback):
results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
not key.startswith('_')}
if self.format_json:
self.progress_bar.console.print_json(json.dumps(results))
results = json.dumps(results)
self.progress_bar.console.print_json(results)
else:
self.progress_bar.print(results)
self.best_results = results

def clear_tasks(self):
for key, taskid in self.task2id.items():
@@ -227,9 +232,9 @@ class RawTextCallback(ProgressCallback):
results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
not key.startswith('_')}
if self.format_json:
logger.info(json.dumps(results))
else:
logger.info(results)
results = json.dumps(results)
logger.info(results)
self.best_results = results

@property
def name(self): # name of the progress bar
@@ -316,9 +321,9 @@ class TqdmCallback(ProgressCallback):
results = {key:trainer.driver.tensor_to_numeric(value) for key, value in results.items() if
not key.startswith('_')}
if self.format_json:
logger.info(json.dumps(results))
else:
logger.info(results)
results = json.dumps(results)
logger.info(results)
self.best_results = results

def clear_tasks(self):
for key, taskid in self.task2id.items():


+43 -22 fastNLP/core/controllers/trainer.py

@@ -35,6 +35,7 @@ from fastNLP.envs import rank_zero_call
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.utils.exceptions import EarlyStopException
from fastNLP.core.dataloaders import OverfitDataLoader


class Trainer(TrainerEventTrigger):
@@ -244,7 +245,20 @@ class Trainer(TrainerEventTrigger):

Note that this parameter only takes effect when the ``Evaluator`` built into the ``Trainer`` is not None and there is a *callback* instance that needs this parameter but has not set it;

:param n_batches: How many ``batch`` es to iterate before training stops. When this value is not -1, the value of ``n_epochs`` is ignored.
:param n_batches: How many ``batch`` es in total to iterate before training stops. When this value is not -1, the value of ``n_epochs`` is ignored.
:param overfit_batches: Use this parameter to enable the 'overfitting' feature; supported values are ``-1``, ``0``, or an integer greater than 0, indicating how many batches of data
are used for overfit training; 0 means nothing is done; -1 means all of the data is used for training;

.. note::

You can use this parameter as a quick check of whether your model is 'correct', i.e. whether it can converge rapidly on a small amount of data, which suggests the loss function, optimizer, etc.
are fine. When this parameter is used, we extract a fixed number of batches directly from ``train_dataloader`` and then use only these data
for training in every epoch;

.. warning::

When using this parameter you may still specify ``metrics`` to run a simple evaluation; when this parameter and ``metrics`` are both given, evaluate_dataloaders
is replaced directly by the training data used for overfitting; you therefore need to make sure your ``metrics`` can be computed on ``train_dataloader``;

:param marker: Used to mark a ``Trainer`` instance, so that when the user calls ``Trainer.on`` it is clear which specific ``Trainer`` instance the registered function belongs to; defaults to None;

@@ -372,6 +386,7 @@ class Trainer(TrainerEventTrigger):
monitor: Union[str, Callable] = None,
larger_better: bool = True,
n_batches: int = -1,
overfit_batches: int = 0,
marker: Optional[str] = None,
**kwargs
):
@@ -469,9 +484,6 @@ class Trainer(TrainerEventTrigger):
n_batches=n_batches
)

if metrics is None and evaluate_dataloaders is not None:
raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.")

if metrics is not None and evaluate_dataloaders is None:
raise ValueError("You have set 'metrics' but forget to set 'evaluate_dataloaders'.")

@@ -495,33 +507,42 @@ class Trainer(TrainerEventTrigger):
else:
_dist_sampler = None

self.dataloader = self.train_dataloader
self.driver.set_deterministic_dataloader(self.dataloader)

self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
reproducible=self.callback_manager._need_reproducible_sampler)
# overfit-related setup;
if overfit_batches != 0:
self.dataloader = OverfitDataLoader(self.dataloader, overfit_batches)
self.overfit_batches = overfit_batches

self.evaluator = None
self.monitor = monitor
self.larger_better = larger_better
if metrics is not None and evaluate_dataloaders is not None:
check_evaluate_every(evaluate_every)
progress_bar = kwargs.get('progress_bar', 'auto') # if it is not
if not (isinstance(progress_bar, str) or progress_bar is None): # should be a ProgressCallback; get its name.
progress_bar = progress_bar.name
self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn,
evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler),
progress_bar=progress_bar,
check_dataloader_legality=kwargs.get('check_dataloader_legality', True))
if metrics is not None:
if overfit_batches != 0:
evaluate_dataloaders = self.dataloader
if evaluate_dataloaders is not None:
check_evaluate_every(evaluate_every)
progress_bar = kwargs.get('progress_bar', 'auto') # if it is not
if not (isinstance(progress_bar, str) or progress_bar is None): # should be a ProgressCallback; get its name.
progress_bar = progress_bar.name
self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics,
driver=self.driver, evaluate_batch_step_fn=evaluate_batch_step_fn,
evaluate_fn=evaluate_fn, input_mapping=evaluate_input_mapping,
output_mapping=evaluate_output_mapping, fp16=fp16, verbose=0,
use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", use_dist_sampler),
progress_bar=progress_bar,
check_dataloader_legality=kwargs.get('check_dataloader_legality', True))
else:
raise ValueError("You have set 'evaluate_dataloaders' but forget to set 'metrics'.")

if train_fn is not None and not isinstance(train_fn, str):
raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.")
self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn)
self.train_fn = train_fn

self.dataloader = self.train_dataloader
self.driver.set_deterministic_dataloader(self.dataloader)

self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler,
reproducible=self.callback_manager._need_reproducible_sampler)

self.evaluate_batch_step_fn = evaluate_batch_step_fn
self.kwargs = kwargs
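
To make the ``overfit_batches`` parameter documented in the trainer.py diff above concrete, here is a minimal, hedged usage sketch (not part of this commit): the tiny model, dataset, and collate function are illustrative placeholders, and the model follows fastNLP's convention of exposing a ``train_step`` that returns a dict containing ``'loss'``.

```python
# Minimal sketch of the new ``overfit_batches`` option. TinyModel and the
# random dataset are illustrative placeholders only.
import torch
from torch.utils.data import DataLoader, TensorDataset

from fastNLP.core.controllers.trainer import Trainer


class TinyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def train_step(self, x, y):
        # fastNLP's torch driver maps dict batches onto these argument names
        # and expects a dict with a 'loss' entry back.
        logits = self.fc(x)
        return {"loss": torch.nn.functional.cross_entropy(logits, y)}


def collate(batch):
    # Yield dict batches so that automatic parameter matching can feed train_step.
    xs, ys = zip(*batch)
    return {"x": torch.stack(xs), "y": torch.stack(ys)}


dataset = TensorDataset(torch.randn(64, 4), torch.randint(0, 2, (64,)))
train_dl = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate)

model = TinyModel()
trainer = Trainer(
    model=model,
    driver="torch",
    device="cpu",
    optimizers=torch.optim.Adam(model.parameters(), lr=1e-2),
    train_dataloader=train_dl,
    overfit_batches=4,   # cache 4 batches and reuse them in every epoch; -1 caches all batches
    n_epochs=20,
)
trainer.run()
```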



+5 -2 fastNLP/core/dataloaders/__init__.py

@@ -7,10 +7,13 @@ __all__ = [
'prepare_paddle_dataloader',
'prepare_torch_dataloader',

"prepare_dataloader"
"prepare_dataloader",

"OverfitDataLoader"
]

from .jittor_dataloader import JittorDataLoader, prepare_jittor_dataloader
from .torch_dataloader import TorchDataLoader, prepare_torch_dataloader, MixDataLoader
from .paddle_dataloader import PaddleDataLoader, prepare_paddle_dataloader
from .prepare_dataloader import prepare_dataloader
from .prepare_dataloader import prepare_dataloader
from .utils import OverfitDataLoader

+33 -2 fastNLP/core/dataloaders/utils.py

@@ -1,4 +1,5 @@
from typing import Callable, Any, Union
import os
from typing import Callable, Any, Union, Sequence
from abc import ABC
import inspect
import ast
@@ -6,7 +7,8 @@ import ast
from ..log import logger
from ..utils.cache_results import get_func_calls, truncate_start_blanks
__all__ = [
"indice_collate_wrapper"
"indice_collate_wrapper",
"OverfitDataLoader"
]


@@ -111,6 +113,35 @@ class HasLenGetitemType(ABC):
return NotImplemented


class OverfitDataLoader:
"""
Implements a simple iterator that mimics a real dataloader by taking a subset of batches from the given dataloader, so that the Trainer can provide the overfit feature;
"""

def __init__(self, dataloader, overfit_batches: int):
self.dataloader = dataloader # the real dataloader needs to be attached to this object so that operations aimed at the real dataloader still work;
self.batches = []
self.overfit_batches = int(overfit_batches)

if self.overfit_batches > len(dataloader):
logger.warning("Parameter 'overfit_batches' is bigger than the length of 'train_dataloader'.")

for idx, batch in enumerate(dataloader):
if idx < self.overfit_batches or self.overfit_batches <= -1:
self.batches.append(batch)

def __len__(self):
return len(self.batches)

def __iter__(self):
for batch in self.batches:
yield batch

def __getattr__(self, item):
return getattr(self.dataloader, item)



if __name__ == '__main__':
def demo(*args, **kwargs):
pass
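
A quick illustration of the ``OverfitDataLoader`` defined above (the wrapped dataloader is a placeholder): it eagerly caches the first ``overfit_batches`` batches and replays them on every iteration, while unknown attribute lookups fall through to the wrapped dataloader via ``__getattr__``.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# OverfitDataLoader is exported from fastNLP.core.dataloaders by the
# __init__.py change in this commit.
from fastNLP.core.dataloaders import OverfitDataLoader

dl = DataLoader(TensorDataset(torch.arange(100.0).unsqueeze(1)), batch_size=10)
overfit_dl = OverfitDataLoader(dl, overfit_batches=3)

print(len(overfit_dl))        # 3 -- only the cached batches are kept
print(overfit_dl.batch_size)  # 10 -- resolved on the wrapped DataLoader via __getattr__

for epoch in range(2):
    for batch in overfit_dl:  # the same 3 batches are replayed in every epoch
        pass
```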


+0 -3 fastNLP/core/drivers/torch_driver/ddp.py

@@ -140,9 +140,6 @@ if _NEED_IMPORT_TORCH:
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import BatchSampler
from torch.utils.data import RandomSampler as TorchRandomSampler
from torch.utils.data import SequentialSampler as TorchSequentialSampler
from torch.utils.data import BatchSampler as TorchBatchSampler

__all__ = [
'TorchDDPDriver'


+2 -1 fastNLP/core/drivers/torch_driver/torch_driver.py

@@ -31,6 +31,7 @@ from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, ReproduceBatchSampler, RandomSampler
from fastNLP.core.dataloaders import OverfitDataLoader


class TorchDriver(Driver):
@@ -92,7 +93,7 @@ class TorchDriver(Driver):
self.grad_scaler.update()

def check_dataloader_legality(self, dataloader):
if not isinstance(dataloader, DataLoader):
if not isinstance(dataloader, DataLoader) and not isinstance(dataloader, OverfitDataLoader):
raise TypeError(f"{DataLoader} is expected, instead of `{type(dataloader)}`")
if len(dataloader) == 0:
logger.rank_zero_warning("Your dataloader is empty, which is not recommended because it "


+24 -22 fastNLP/core/drivers/torch_driver/utils.py

@@ -181,18 +181,16 @@ def replace_sampler(dataloader: "DataLoader", sampler):
instance_attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith('_')}

# 'multiprocessing_context' is a user-defined function;
instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context
if getattr(dataloader, 'multiprocessing_context', None) is not None:
instance_attrs["multiprocessing_context"] = dataloader.multiprocessing_context

# get the default signature of the dataloader's '__init__' function;
init_params = dict(inspect.signature(dataloader.__init__).parameters)

# The reason for handling this separately is that, when customizing their own dataloader, users may for convenience only define a few
# parameters explicitly and pass the rest via **kwargs. If they added some new parameters when initializing their dataloader instance
# (this step is necessary in the first place, because it is the only way we can add a sampler; on the other hand, the user may indeed have
# passed new parameters through **kwargs), and if the user wrote "super().__init__(**kwargs)", then we can only look the parameters up in DataLoader;
# guard against the case where the user's DataLoader subclasses pytorch's DataLoader but still passes arguments to the parent class via **kwargs
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
if has_variadic_kwargs:
# This is written this way because a parameter with the same name in the user's customized Dataloader may have a different default value; therefore we cannot simply overwrite it with update;
if has_variadic_kwargs and isinstance(dataloader, DataLoader):
# guard against the user having written super().__init__(**kwargs)
for key, value in dict(inspect.signature(DataLoader.__init__).parameters).items():
if key not in init_params and key != 'self':
init_params[key] = value
@@ -204,7 +202,8 @@ def replace_sampler(dataloader: "DataLoader", sampler):
non_default_params.add("dataset")

reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None})
if isinstance(dataloader, DataLoader):
reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None})

batch_sampler = getattr(dataloader, "batch_sampler")
if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
@@ -218,35 +217,31 @@ def replace_sampler(dataloader: "DataLoader", sampler):
and p.name not in reconstruct_args
}

# This error targets the case where an __init__ argument was not stored on self under the same name;
# these arguments were not found among the instance attributes, so there is no way to reinitialize the dataloader
if required_args:
required_args = sorted(required_args)
dataloader_self_name = dataloader.__class__.__name__
raise Exception(
f"Trying to inject `DistributedSampler` into the `{dataloader_self_name}` instance. "
"This would fail as some of the `__init__` arguments are not available as instance attributes. "
f"The missing attributes are {required_args}. "
f"HINT: If you wrote the `{dataloader_self_name}` class, define `self.missing_arg_name` or "
"manually add the `DistributedSampler` as: "
f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`."
f"Need to inject arguments {required_args} into the __init__ of `{dataloader_self_name}`. "
f"But they are not found in the attribute of `{dataloader_self_name}`, fastNLP cannot determine its "
f"value when try to reinitialize `{dataloader_self_name}`, please add `{required_args}` to be "
f"`{dataloader_self_name}`'s attribute."
)

# This error targets the case where the passed-in dataloader is not a plain DataLoader but a customized DataLoader whose __init__ has no **kwargs;
if not has_variadic_kwargs:

# the dataloader signature does not allow keyword arguments that need to be passed
missing_kwargs = reconstruct_args.keys() - init_params.keys()
if missing_kwargs:
missing_kwargs = sorted(missing_kwargs)
dataloader_self_name = dataloader.__class__.__name__
raise Exception(
f"Trying to inject `DistributedSampler` into the `{dataloader_self_name}` instance. "
"This would fail as it doesn't expose all its attributes in the `__init__` signature. "
f"The missing arguments are {missing_kwargs}. "
f"HINT: If you wrote the `{dataloader_self_name}` class, add the `__init__` arguments or "
"manually add the `DistributedSampler` as: "
f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`."
f"The parameter:{missing_kwargs} needed to reinitialize `{dataloader_self_name}` is not found."
)
# if there is no **kwargs, make sure only the required parameters are passed in
if not isinstance(dataloader, DataLoader):
reconstruct_args = {key:value for key,value in reconstruct_args.items() if key in init_params}

return type(dataloader)(**reconstruct_args)


@@ -260,6 +255,13 @@ def replace_batch_sampler(dataloader, new_batch_sampler):
params_keys.remove(k)
params = {k: getattr(dataloader, k) for k in params_keys}
params["batch_sampler"] = new_batch_sampler

if not isinstance(dataloader, DataLoader):
init_params = dict(inspect.signature(dataloader.__init__).parameters)
has_variadic_kwargs = any(v.kind is v.VAR_KEYWORD for k, v in init_params.items())
if not has_variadic_kwargs:
params = {key:value for key,value in params.items() if key in init_params}

return type(dataloader)(**params)
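
The reworded error message above asks users to expose custom ``__init__`` arguments as instance attributes of the same name; here is a small illustrative sketch of that pattern (``PaddedDataLoader`` and ``padding_value`` are hypothetical names, not part of this commit).

```python
from torch.utils.data import DataLoader


class PaddedDataLoader(DataLoader):
    """A customized DataLoader with one extra required __init__ argument."""

    def __init__(self, dataset, padding_value, **kwargs):
        # Store the extra argument under the same name: replace_sampler()
        # rebuilds the dataloader from its instance attributes, so without
        # this line injecting a new sampler would raise the
        # "Need to inject arguments ..." exception above.
        self.padding_value = padding_value
        super().__init__(dataset, **kwargs)
```

Because the subclass also forwards ``**kwargs`` to ``DataLoader.__init__``, fastNLP can pass ``sampler=...``, ``shuffle=False``, and ``batch_sampler=None`` when it reinitializes the instance.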




+9 -1 fastNLP/core/metrics/metric.py

@@ -98,7 +98,7 @@ class Metric:
return _wrap_get_metric

def __setattr__(self, key, value):
if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True:
if getattr(self, '_cannot_change_element', False):
if key in self.elements and isinstance(value, (float, int, bool)):
self.elements[key].fill_value(value)
return
@@ -109,6 +109,14 @@ class Metric:
raise RuntimeError("Please use register_element() function to add Element.")
object.__setattr__(self, key, value)

# This is only triggered when __getattribute__ fails to find the attribute; it is kept merely to suppress IDE warnings
def __getattr__(self, name: str) -> Element:
if 'elements' in self.__dict__:
elements = self.__dict__['elements']
if name in elements:
return elements[name]
raise AttributeError("`{}` object has no attribute `{}`.".format(type(self).__name__, name))

def _wrap_update(self, update):
@functools.wraps(update)
def _wrap_update(*args, **kwargs):
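
As an illustration of the ``__getattr__`` fallback added above, a minimal counting metric modeled on fastNLP's ``Accuracy`` is sketched below; the exact ``register_element`` and ``get_scalar`` signatures are assumptions based on other metrics in the library, not part of this diff.

```python
from fastNLP.core.metrics.metric import Metric


class CountMetric(Metric):
    def __init__(self):
        super().__init__()
        # Registered elements live in self.elements ...
        self.register_element(name="total", value=0, aggregate_method="sum")

    def update(self, n=1):
        # ... and reads such as ``self.total`` resolve to the registered Element,
        # with the new __getattr__ providing a fallback when normal lookup fails.
        self.total += n

    def get_metric(self) -> dict:
        return {"total": self.total.get_scalar()}
```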


+45 -0 tests/core/controllers/test_trainer_w_evaluator_torch.py

@@ -286,6 +286,9 @@ def test_trainer_specific_params_1(
assert trainer.driver.non_blocking is False
assert trainer.driver.wo_auto_param_call is True

if dist.is_initialized():
dist.destroy_process_group()


@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", [0, 1]),("torch", 1)
@@ -332,5 +335,47 @@ def test_trainer_specific_params_2(
assert _ddp_kwargs.get("broadcast_buffers") is True
assert _ddp_kwargs.get("find_unused_parameters") is True

if dist.is_initialized():
dist.destroy_process_group()


@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])]) # ("torch", [0, 1]),("torch", 1)
@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
@magic_argv_env_context
def test_trainer_w_evaluator_overfit_torch(
model_and_optimizers: TrainerParameters,
driver,
device,
overfit_batches,
num_train_batch_per_epoch
):
"""
Test that the overfit_batches parameter works correctly when an evaluator is present;
"""
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
overfit_batches=overfit_batches,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders={"dl": model_and_optimizers.evaluate_dataloaders},
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
n_epochs=2,
output_from_new_proc="all",
evaluate_every=-1,

torch_kwargs={
"non_blocking": False,
"set_grad_to_none": True
}

)

trainer.run(num_train_batch_per_epoch=num_train_batch_per_epoch)

if dist.is_initialized():
dist.destroy_process_group()

+30 -0 tests/core/controllers/test_trainer_wo_evaluator_torch.py

@@ -361,5 +361,35 @@ def test_torch_wo_auto_param_call(
dist.destroy_process_group()


# Test overfit_batches;
@pytest.mark.torch
@pytest.mark.parametrize("driver,device", [("torch", 1), ("torch", [0, 1])])
@pytest.mark.parametrize("overfit_batches,num_train_batch_per_epoch", [(-1, -1), (0, -1), (3, 10), (6, -1)])
@magic_argv_env_context
def test_trainer_overfit_torch(
model_and_optimizers: TrainerParameters,
driver,
device,
overfit_batches,
num_train_batch_per_epoch
):
trainer = Trainer(
model=model_and_optimizers.model,
driver=driver,
device=device,
overfit_batches=overfit_batches,
optimizers=model_and_optimizers.optimizers,
train_dataloader=model_and_optimizers.train_dataloader,
evaluate_dataloaders=model_and_optimizers.evaluate_dataloaders,
input_mapping=model_and_optimizers.input_mapping,
output_mapping=model_and_optimizers.output_mapping,
metrics=model_and_optimizers.metrics,
output_from_new_proc="all",
n_epochs=2,
)

trainer.run(num_train_batch_per_epoch=num_train_batch_per_epoch)

if dist.is_initialized():
dist.destroy_process_group()

