diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py
index a1d4adf8..22eac708 100644
--- a/fastNLP/core/controllers/evaluator.py
+++ b/fastNLP/core/controllers/evaluator.py
@@ -107,7 +107,7 @@ class Evaluator:
           ``dropout`` and ``batch normalization`` will be turned off. Defaults to ``True``. If ``False``, ``fastNLP`` does not change the ``evaluate`` state of ``model`` at all. Whatever
           this value is, ``fastNLP`` will set ``model`` back to the ``train`` state after evaluation;
         * *use_dist_sampler* --
-          Whether to evaluate in a distributed way. This parameter only takes effect when ``driver`` is a distributed type. By default it is set according to whether ``driver`` supports
+          True / False, whether to evaluate in a distributed way. This parameter only takes effect when ``driver`` is a distributed type. By default it is set according to whether ``driver`` supports
           distributed execution. If ``True``, the ``dataloader`` on every process automatically reads different data, and the union of the data over all processes is the whole dataset;
         * *output_from_new_proc* -- equivalent to the ``output_from_new_proc`` parameter of ``Trainer``;
         * *progress_bar* -- equivalent to the ``progress_bar`` parameter of ``Trainer``;
diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py
index 41fca6ba..c9674f3a 100644
--- a/fastNLP/core/controllers/trainer.py
+++ b/fastNLP/core/controllers/trainer.py
@@ -290,9 +290,9 @@ class Trainer(TrainerEventTrigger):
              only then will the ``model_device`` of the driver instance be None;
             3. for paddle, this parameter has no effect;
-        * *use_dist_sampler* -- whether to use a distributed ``sampler``. With multiple cards, the distributed ``sampler`` automatically decides which samples each card reads, so that within one epoch
+        * *use_dist_sampler* -- True / False, whether to use a distributed ``sampler``. With multiple cards, the distributed ``sampler`` automatically decides which samples each card reads, so that within one epoch
          the samples on all cards add up to exactly one pass over the whole dataset. By default this is set according to whether the driver is distributed.
-        * *evaluate_use_dist_sampler* -- whether the ``Evaluator`` should replace the dataloader's ``sampler`` with a distributed ``sampler`` when running distributed;
+        * *evaluate_use_dist_sampler* -- True / False, whether the ``Evaluator`` should replace the dataloader's ``sampler`` with a distributed ``sampler`` when running distributed;
          when this value is not passed, it follows the ``use_dist_sampler`` parameter;
        * *output_from_new_proc* -- should be a string describing how the output streams of the other processes in a multi-process driver are handled; its value should be one of
          ["all", "ignore", "only_error"]; when the value is none of these, it should be the name of a directory, and we will redirect the output streams of the other ranks to
diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py
index 43c6bc36..008df0d0 100644
--- a/fastNLP/core/drivers/torch_driver/ddp.py
+++ b/fastNLP/core/drivers/torch_driver/ddp.py
@@ -565,6 +565,13 @@ class TorchDDPDriver(TorchDriver):
                 )
                 return replace_sampler(dataloader, sampler)
             else:
+                if type(args.batch_sampler) is not BatchSampler or (type(args.sampler) not in {torch.utils.data.RandomSampler,
+                                                                                               torch.utils.data.SequentialSampler}):
+                    raise TypeError("Using a customized ``batch_sampler`` or ``sampler`` with 'DDP' may cause unexpected problems, because "
+                                    "we will replace your dataloader's sampler with our ``fastNLP.RandomSampler``. You should make "
+                                    "your customized sampler usable in a distributed setting yourself before initializing ``Trainer``, "
+                                    "and then set the parameter ``use_dist_sampler`` of ``Trainer`` to ``False``.")
+
                 sampler = RandomSampler(
                     dataset=args.dataset,
                     shuffle=args.shuffle,
@@ -582,6 +589,7 @@
             if isinstance(args.sampler, ReproducibleSampler):
                 sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
             elif not isinstance(args.sampler, UnrepeatedSampler):
+                # todo: add the same customized-sampler check here as in the "dist" branch above
                 sampler = UnrepeatedSequentialSampler(
                     dataset=args.dataset
                 )
diff --git a/fastNLP/core/drivers/torch_driver/utils.py b/fastNLP/core/drivers/torch_driver/utils.py
index 13281cfe..f0704dd5 100644
--- a/fastNLP/core/drivers/torch_driver/utils.py
+++ b/fastNLP/core/drivers/torch_driver/utils.py
@@ -14,7 +14,7 @@ from fastNLP.envs import (
     FASTNLP_BACKEND_LAUNCH,
     FASTNLP_GLOBAL_SEED,
 )
-from fastNLP.core.samplers import re_instantiate_sampler
+from fastNLP.core.samplers import re_instantiate_sampler, ReproducibleBatchSampler
 from fastNLP.core.utils import auto_param_call
 from fastNLP.core.log import logger
 
@@ -23,7 +23,6 @@ if _NEED_IMPORT_TORCH:
     # import torch.nn as nn
     from torch.nn import Module
     from torch.utils.data import DataLoader, BatchSampler
-    from torch.utils.data.sampler import Sampler
 else:
     from fastNLP.core.utils.dummy_class import DummyClass as Module
 
@@ -201,7 +200,10 @@ def replace_sampler(dataloader: "DataLoader", sampler):
         non_default_params.add("dataset")
 
     reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
-    reconstruct_args.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler))
+
+    batch_sampler = getattr(dataloader, "batch_sampler")
+    if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
+        raise RuntimeError("It should not be running here, please report a bug to us.")
 
     required_args = {
         p.name
@@ -243,28 +245,6 @@
     return type(dataloader)(**reconstruct_args)
 
 
-def _dataloader_init_kwargs_resolve_sampler(
-    dataloader: "DataLoader", sampler: Optional["Sampler"]
-) -> Dict[str, Any]:
-    r"""
-    Handle re-instantiating the sampler / batch_sampler arguments associated with a DataLoader;
-    """
-    batch_sampler = getattr(dataloader, "batch_sampler")
-    # checking the batch sampler type is different than PyTorch default.
-    if batch_sampler is not None and not isinstance(batch_sampler, BatchSampler):
-        batch_sampler = re_instantiate_sampler(batch_sampler)
-
-        return {
-            "sampler": None,
-            "shuffle": False,
-            "batch_sampler": batch_sampler,
-            "batch_size": 1,
-            "drop_last": False,
-        }
-
-    return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
-
-
 def replace_batch_sampler(dataloader, new_batch_sampler):
     r"""
     Replace the batch_sampler of a dataloader;
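
The new TypeError in TorchDDPDriver points users with a customized sampler to make it distributed-aware themselves and to switch the automatic replacement off. Below is a minimal user-side sketch of that workflow; the toy dataset, model, sampler class and every Trainer argument other than ``use_dist_sampler`` are illustrative assumptions, not part of this patch.

import os

import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, Sampler

from fastNLP import Trainer  # assumed top-level import path, not shown in the patch


class ToyDataset(Dataset):
    def __len__(self):
        return 100

    def __getitem__(self, idx):
        return {"x": idx}


class MyShardedSampler(Sampler):
    """A user-defined sampler that already splits indices across ranks, so it must not be
    silently swapped for ``fastNLP.RandomSampler`` by the driver."""

    def __init__(self, dataset, rank, world_size):
        self.dataset, self.rank, self.world_size = dataset, rank, world_size

    def __iter__(self):
        # every rank yields only its own slice of the indices
        return iter(range(self.rank, len(self.dataset), self.world_size))

    def __len__(self):
        return len(range(self.rank, len(self.dataset), self.world_size))


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 4)
        self.fc = nn.Linear(4, 2)

    def train_step(self, x):
        # assumed fastNLP convention: train_step returns a dict containing "loss"
        logits = self.fc(self.embed(x))
        return {"loss": logits.mean()}


rank = int(os.environ.get("RANK", 0))
world_size = int(os.environ.get("WORLD_SIZE", 1))
dataset = ToyDataset()
train_dataloader = DataLoader(dataset, batch_size=8,
                              sampler=MyShardedSampler(dataset, rank, world_size))

model = ToyModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# With the default use_dist_sampler (True for a distributed driver) the new check in
# TorchDDPDriver would raise the TypeError above for this dataloader; keeping the
# customized sampler means opting out of the automatic replacement explicitly.
trainer = Trainer(model=model, driver="torch", device=[0, 1],   # two cards -> DDP driver
                  train_dataloader=train_dataloader, optimizers=optimizer,
                  n_epochs=1, use_dist_sampler=False)
trainer.run()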