Browse Source

为 ddp 在用户使用自己的sampler 和 batch sampler 是添加禁止

tags/v1.0.0alpha
YWMditto 3 years ago
parent
commit
70dea71cdb
4 changed files with 16 additions and 28 deletions
  1. +1
    -1
      fastNLP/core/controllers/evaluator.py
  2. +2
    -2
      fastNLP/core/controllers/trainer.py
  3. +8
    -0
      fastNLP/core/drivers/torch_driver/ddp.py
  4. +5
    -25
      fastNLP/core/drivers/torch_driver/utils.py

+ 1
- 1
fastNLP/core/controllers/evaluator.py View File

@@ -107,7 +107,7 @@ class Evaluator:
``dropout`` 与 ``batch normalization`` 将会关闭。默认为 ``True``。如果为 ``False``,``fastNLP`` 不会对 ``model`` 的 ``evaluate`` 状态做任何设置。无论
该值是什么,``fastNLP`` 都会在评测后将 ``model`` 的状态设置为 ``train``;
* *use_dist_sampler* --
是否使用分布式评测的方式。仅当 ``driver`` 为分布式类型时,该参数才有效。默认为根据 ``driver`` 是否支持
True / False, 是否使用分布式评测的方式。仅当 ``driver`` 为分布式类型时,该参数才有效。默认为根据 ``driver`` 是否支持
分布式进行设置。如果为 ``True``,将使得每个进程上的 ``dataloader`` 自动使用不同数据,所有进程的数据并集是整个数据集;
* *output_from_new_proc* -- 等价于 ``Trainer`` 中的 ``output_from_new_proc`` 参数;
* *progress_bar* -- 等价于 ``Trainer`` 中的 ``progress_bar`` 参数;


+ 2
- 2
fastNLP/core/controllers/trainer.py View File

@@ -290,9 +290,9 @@ class Trainer(TrainerEventTrigger):
driver 实例的 ``model_device`` 才会为 None;
3. 对于 paddle,该参数无效;

* *use_dist_sampler* -- 表示是否使用分布式的 ``sampler``。在多卡时,分布式 ``sampler`` 将自动决定每张卡上读取的 sample ,使得一个 epoch
* *use_dist_sampler* -- True / False, 表示是否使用分布式的 ``sampler``。在多卡时,分布式 ``sampler`` 将自动决定每张卡上读取的 sample ,使得一个 epoch
内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。
* *evaluate_use_dist_sampler* -- 表示在 ``Evaluator`` 中在使用分布式的时候是否将 dataloader 的 ``sampler`` 替换为分布式的 ``sampler``;
* *evaluate_use_dist_sampler* -- True / False, 表示在 ``Evaluator`` 中在使用分布式的时候是否将 dataloader 的 ``sampler`` 替换为分布式的 ``sampler``;
不传入该值时,该值与 ``use_dist_sampler`` 参数保持一致;
* *output_from_new_proc* -- 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一:
["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到


+ 8
- 0
fastNLP/core/drivers/torch_driver/ddp.py View File

@@ -565,6 +565,13 @@ class TorchDDPDriver(TorchDriver):
)
return replace_sampler(dataloader, sampler)
else:
if type(args.batch_sampler) is not BatchSampler or (type(args.sampler) not in {torch.utils.data.RandomSampler,
torch.utils.data.SequentialSampler}):
raise TypeError("Using customized ``batch_sampler`` or ``sampler`` with 'DDP' may cause unseen problems, cause"
"we will substitute your dataloader's sampler into our ``fastNLP.RandomSampler``. You should make"
"your customized sampler being able to be used in distributed setting before you initialize ``Trainer`` by yourself,"
"and then set the parameter ``use_dist_sampler`` of ``Trainer`` to ``False``.")

sampler = RandomSampler(
dataset=args.dataset,
shuffle=args.shuffle,
@@ -582,6 +589,7 @@ class TorchDDPDriver(TorchDriver):
if isinstance(args.sampler, ReproducibleSampler):
sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
elif not isinstance(args.sampler, UnrepeatedSampler):
# todo same as dist
sampler = UnrepeatedSequentialSampler(
dataset=args.dataset
)


+ 5
- 25
fastNLP/core/drivers/torch_driver/utils.py View File

@@ -14,7 +14,7 @@ from fastNLP.envs import (
FASTNLP_BACKEND_LAUNCH,
FASTNLP_GLOBAL_SEED,
)
from fastNLP.core.samplers import re_instantiate_sampler
from fastNLP.core.samplers import re_instantiate_sampler, ReproducibleBatchSampler
from fastNLP.core.utils import auto_param_call
from fastNLP.core.log import logger

@@ -23,7 +23,6 @@ if _NEED_IMPORT_TORCH:
# import torch.nn as nn
from torch.nn import Module
from torch.utils.data import DataLoader, BatchSampler
from torch.utils.data.sampler import Sampler
else:
from fastNLP.core.utils.dummy_class import DummyClass as Module

@@ -201,7 +200,10 @@ def replace_sampler(dataloader: "DataLoader", sampler):
non_default_params.add("dataset")

reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
reconstruct_args.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler))

batch_sampler = getattr(dataloader, "batch_sampler")
if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
raise RuntimeError("It should not be running here, please report a bug to us.")

required_args = {
p.name
@@ -243,28 +245,6 @@ def replace_sampler(dataloader: "DataLoader", sampler):
return type(dataloader)(**reconstruct_args)


def _dataloader_init_kwargs_resolve_sampler(
dataloader: "DataLoader", sampler: Optional["Sampler"]
) -> Dict[str, Any]:
r"""
此函数用于处理与 DataLoader 关联的采样器、batch_sampler 参数重新实例化;
"""
batch_sampler = getattr(dataloader, "batch_sampler")
# checking the batch sampler type is different than PyTorch default.
if batch_sampler is not None and not isinstance(batch_sampler, BatchSampler):
batch_sampler = re_instantiate_sampler(batch_sampler)

return {
"sampler": None,
"shuffle": False,
"batch_sampler": batch_sampler,
"batch_size": 1,
"drop_last": False,
}

return {"sampler": sampler, "shuffle": False, "batch_sampler": None}


def replace_batch_sampler(dataloader, new_batch_sampler):
r"""
替换一个 dataloader 的 batch_sampler;


Loading…
Cancel
Save