diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py
index 22eac708..84ca03bd 100644
--- a/fastNLP/core/controllers/evaluator.py
+++ b/fastNLP/core/controllers/evaluator.py
@@ -107,8 +107,11 @@ class Evaluator:
       ``dropout`` and ``batch normalization`` are switched off. Defaults to ``True``. If ``False``, ``fastNLP`` does not touch the evaluate state of ``model`` at all. Whatever
       the value is, ``fastNLP`` always sets the state of ``model`` back to ``train`` after evaluation;
     * *use_dist_sampler* --
-      True / False, whether to run evaluation in a distributed fashion. Only effective when ``driver`` is a distributed type. Defaults according to whether ``driver`` supports
-      distributed running. If ``True``, the ``dataloader`` on every process automatically uses different data, and the union of the data of all processes is the whole dataset;
+      Whether, when running distributed, the ``Evaluator`` should replace the ``sampler`` of the dataloader with a
+      distributed ``sampler``, whose property is that the data on different cards do not overlap and the data of all cards together form the whole dataset. If the sampler of the given dataloader
+      is (a) the default sampler of the deep learning framework, or (b) a fastNLP Sampler, it will be replaced with
+      :class:`~fastNLP.UnrepeatedSequentialSampler`; if that is not the behaviour you expect, set this parameter to ``False`` and control the data each card
+      may use yourself. For any other kind of sampler, fastNLP raises an error.
     * *output_from_new_proc* -- equivalent to the ``output_from_new_proc`` parameter of ``Trainer``;
     * *progress_bar* -- equivalent to the ``progress_bar`` parameter of ``Trainer``;
     * *check_dataloader_legality* -- whether to check that the ``DataLoader`` is legal, defaults to ``True``.
diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py
index c9674f3a..f92611dd 100644
--- a/fastNLP/core/controllers/trainer.py
+++ b/fastNLP/core/controllers/trainer.py
@@ -65,6 +65,13 @@ class Trainer(TrainerEventTrigger):
         the model inside the ``Driver`` instance that you pass in;
     :param train_dataloader: the training dataset; note that it must be a single dataset, not a List or a Dict;
+
+        .. warning::
+
+            When training is distributed, ``fastNLP`` by default processes the ``Sampler`` of the ``dataloader`` so that within one
+            ``epoch`` the data used for training on different cards does not overlap. If you handle the sampler in a special way
+            yourself, set the ``use_dist_sampler`` parameter to ``False``; it is then up to you to guarantee that each card uses different data.
+
     :param optimizers: the optimizer(s) needed for training; either a single optimizer instance or a List of several optimizers;
     :param device: this parameter specifies the machine used for training; note that it may be None only when you launch through
         `torch.distributed.launch/run`, in which case fastNLP does not move the model or data between devices, but you can use the
         parameters `input_mapping` and `output_mapping` to move them between
@@ -93,9 +100,9 @@
 
         .. warning::
 
-            Note that the ``device`` parameter may be ``None`` only when you launch ddp training through the parallel launch script of pytorch or another training framework itself!
+            Note that the ``device`` parameter may be ``None`` only when you launch ddp training through the parallel launch script of the training framework itself!
 
-            For example, when you use::
+            For example, in pytorch, when you use::
 
                 python -m torch.distributed.launch --nproc_per_node 2 train.py
 
@@ -290,16 +297,22 @@ class Trainer(TrainerEventTrigger):
        will the ``model_device`` of the driver instance be None;
     3. For paddle, this parameter has no effect;
-    * *use_dist_sampler* -- True / False, whether to use a distributed ``sampler``. With multiple cards, the distributed ``sampler`` automatically decides which samples each card reads, so that within one epoch
-      the samples of all cards together make up one whole pass over the dataset. Defaults according to whether the driver is distributed.
-    * *evaluate_use_dist_sampler* -- True / False, whether the ``Evaluator`` should replace the ``sampler`` of the dataloader with a distributed ``sampler`` when running distributed;
-      when this value is not passed in, it stays consistent with the ``use_dist_sampler`` parameter;
+    * *use_dist_sampler* -- whether to use a distributed ``sampler``. With multiple cards, the distributed ``sampler`` automatically decides which samples each card reads, so that within one epoch
+      the samples of all cards together make up one whole pass over the dataset. To guarantee that every card holds the same number of samples, some cards may receive duplicated samples; for example,
+      with 8-card training, only 9 samples and a batch_size of 1, seven cards would have no sample left for the second batch, so samples have to be reused to pad that second
+      batch. If you do not want fastNLP to make any special arrangement for the sampler of the dataloader, set this value to False; if you still need distributed training in that case, specially process
+      the data of train_dataloader outside the Trainer so that its samples do not overlap between the different cards.
+    * *evaluate_use_dist_sampler* -- whether, when running distributed, the ``Evaluator`` should replace the ``sampler`` of the dataloader with the
+      distributed ``sampler`` used for evaluation, whose property is that the data on different cards do not overlap and the data of all cards together form the whole dataset. If the sampler of the given dataloader
+      is (a) the default sampler of the deep learning framework, or (b) a fastNLP Sampler, it will be replaced with
+      :class:`~fastNLP.UnrepeatedSequentialSampler`; if that is not the behaviour you expect, set this parameter to ``False`` and control the data each card
+      may use yourself.
     * *output_from_new_proc* -- should be a string specifying how the output streams of the other processes of a multi-process driver are handled; it should be one of:
       ["all", "ignore", "only_error"]; any other value is treated as the name of a folder, and the output streams of the other ranks are redirected to log files
       which are then saved in the folder given by this value; defaults to "only_error";
       note that this parameter only takes effect when a distributed ``driver`` is used, e.g. ``TorchDDPDriver``;
-    * *progress_bar* -- how progress is displayed; currently [None, 'raw', 'rich', 'auto', 'tqdm'] or objects such as :class:`~.fastNLP.RichCallback`, :class:`~fastNLP.RawTextCallback` are supported,
+    * *progress_bar* -- how progress is displayed; currently [None, 'raw', 'rich', 'auto', 'tqdm'] or objects such as :class:`~fastNLP.RichCallback`, :class:`~fastNLP.RawTextCallback` are supported,
       defaulting to auto; auto means that :class:`~fastNLP.RichCallback` is used when the current terminal is detected to be interactive and :class:`~fastNLP.RawTextCallback` otherwise. If you
       need to customise the progress bar, e.g. the print frequency, you can pass in a :class:`~fastNLP.RichCallback`, :class:`~fastNLP.RawTextCallback` etc. object.
     * *train_input_mapping* -- same as input_mapping, but only used inside ``Trainer``. Mutually exclusive with input_mapping.
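A minimal sketch of the two modes documented for ``use_dist_sampler`` above. It assumes ``model``, ``train_dataset`` and ``optimizer`` come from your own setup and that the script is launched with ``python -m torch.distributed.launch/run`` (so ``WORLD_SIZE`` and ``RANK`` are set); the explicit ``DistributedSampler`` is just one possible way to manage the per-rank split yourself::

    import os

    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    from fastNLP import Trainer

    # Default behaviour: let fastNLP replace the sampler, so every rank reads a
    # disjoint (possibly padded) share of the dataset in each epoch.
    trainer = Trainer(
        model=model,
        driver="torch",
        device=None,  # may be None only under the framework's own launch script
        train_dataloader=DataLoader(train_dataset, batch_size=32, shuffle=True),
        optimizers=optimizer,
        use_dist_sampler=True,
    )

    # Opting out: keep a sampler you manage yourself and tell fastNLP not to touch it;
    # you are then responsible for keeping the per-rank data non-overlapping.
    dist_sampler = DistributedSampler(
        train_dataset,
        num_replicas=int(os.environ["WORLD_SIZE"]),
        rank=int(os.environ["RANK"]),
        shuffle=True,
    )
    trainer = Trainer(
        model=model,
        driver="torch",
        device=None,
        train_dataloader=DataLoader(train_dataset, batch_size=32, sampler=dist_sampler),
        optimizers=optimizer,
        use_dist_sampler=False,            # fastNLP leaves the sampler untouched
        evaluate_use_dist_sampler=False,   # same opt-out for evaluation dataloaders
    )
    trainer.run()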
diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py
index 98c07495..d19da9fe 100644
--- a/fastNLP/core/drivers/paddle_driver/fleet.py
+++ b/fastNLP/core/drivers/paddle_driver/fleet.py
@@ -422,8 +422,7 @@ class PaddleFleetDriver(PaddleDriver):
         # trainer, evaluator
         if dist is None:
             if reproducible:
-                raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize fleet out of our "
-                                   "control.")
+                raise RuntimeError("It is not allowed to save a checkpoint if the sampler is not allowed to be replaced.")
             else:
                 args = self.get_dataloader_args(dataloader)
                 if isinstance(args.batch_sampler, ReproducibleBatchSampler):
diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py
index 008df0d0..45a1a61a 100644
--- a/fastNLP/core/drivers/torch_driver/ddp.py
+++ b/fastNLP/core/drivers/torch_driver/ddp.py
@@ -140,6 +140,9 @@ if _NEED_IMPORT_TORCH:
     import torch.distributed as dist
     from torch.nn.parallel import DistributedDataParallel
     from torch.utils.data import BatchSampler
+    from torch.utils.data import RandomSampler as TorchRandomSampler
+    from torch.utils.data import SequentialSampler as TorchSequentialSampler
+    from torch.utils.data import BatchSampler as TorchBatchSampler
 
 __all__ = [
     'TorchDDPDriver'
 ]
@@ -159,6 +162,7 @@
 from fastNLP.core.samplers import ReproducibleSampler, RandomSampler, Unrepeated
 from fastNLP.envs import FASTNLP_DISTRIBUTED_CHECK, FASTNLP_GLOBAL_RANK, FASTNLP_GLOBAL_SEED, FASTNLP_NO_SYNC
 from fastNLP.core.log import logger
 from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather, fastnlp_torch_broadcast_object
+from .utils import _check_dataloader_args_for_distributed
 
 class TorchDDPDriver(TorchDriver):
@@ -535,8 +539,7 @@
         # trainer, evaluator
         if dist is None:
             if reproducible:
-                raise RuntimeError("It is not allowed to use checkpoint retraining when you initialize ddp out of our "
-                                   "control.")
+                raise RuntimeError("It is not allowed to save a checkpoint if the sampler is not allowed to be replaced.")
             else:
                 args = self.get_dataloader_args(dataloader)
                 if isinstance(args.batch_sampler, ReproducibleBatchSampler):
@@ -565,13 +568,7 @@
                 )
                 return replace_sampler(dataloader, sampler)
             else:
-                if type(args.batch_sampler) is not BatchSampler or (type(args.sampler) not in {torch.utils.data.RandomSampler,
-                                                                                               torch.utils.data.SequentialSampler}):
-                    raise TypeError("Using customized ``batch_sampler`` or ``sampler`` with 'DDP' may cause unseen problems, cause"
-                                    "we will substitute your dataloader's sampler into our ``fastNLP.RandomSampler``. You should make"
-                                    "your customized sampler being able to be used in distributed setting before you initialize ``Trainer`` by yourself,"
-                                    "and then set the parameter ``use_dist_sampler`` of ``Trainer`` to ``False``.")
-
+                _check_dataloader_args_for_distributed(args, controller='Trainer')
                 sampler = RandomSampler(
                     dataset=args.dataset,
                     shuffle=args.shuffle,
@@ -589,7 +586,7 @@
             if isinstance(args.sampler, ReproducibleSampler):
                 sampler = conversion_between_reproducible_and_unrepeated_sampler(args.sampler)
             elif not isinstance(args.sampler, UnrepeatedSampler):
-                # todo same as dist
+                _check_dataloader_args_for_distributed(args, controller='Evaluator')
                 sampler = UnrepeatedSequentialSampler(
                     dataset=args.dataset
                 )
diff --git a/fastNLP/core/drivers/torch_driver/single_device.py b/fastNLP/core/drivers/torch_driver/single_device.py
index c36e0f8d..263cf712 100644
--- a/fastNLP/core/drivers/torch_driver/single_device.py
+++ b/fastNLP/core/drivers/torch_driver/single_device.py
@@ -8,6 +8,7 @@ if _NEED_IMPORT_TORCH:
     from torch.nn.parallel import DistributedDataParallel
     from torch.utils.data import RandomSampler as TorchRandomSampler
     from torch.utils.data import SequentialSampler as TorchSequentialSampler
+    from torch.utils.data import BatchSampler as TorchBatchSampler
 
 __all__ = [
     'TorchSingleDriver'
 ]
@@ -123,19 +124,20 @@ class TorchSingleDriver(TorchDriver):
             return replace_sampler(dataloader, sampler)
 
         if reproducible:
-            if isinstance(args.sampler, TorchRandomSampler):
-                if getattr(args.sampler, '_num_samples', None) is None \
-                        and getattr(args.sampler, 'replacements', False) is False \
-                        and getattr(args.sampler, 'generator', None) is None:
-                    # If it was already random and not customized, just replace it.
-                    sampler = RandomSampler(args.sampler.data_source, shuffle=True)
-                    logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
+            if type(args.batch_sampler) is TorchBatchSampler:
+                if type(args.sampler) is TorchRandomSampler:
+                    if getattr(args.sampler, '_num_samples', None) is None \
+                            and getattr(args.sampler, 'replacements', False) is False \
+                            and getattr(args.sampler, 'generator', None) is None:
+                        # If it was already random and not customized, just replace it.
+                        sampler = RandomSampler(args.sampler.data_source, shuffle=True)
+                        logger.debug("Replace torch RandomSampler into fastNLP RandomSampler.")
+                        return replace_sampler(dataloader, sampler)
+                elif type(args.sampler) is TorchSequentialSampler:
+                    # Needs to be replaced with a non-shuffling one.
+                    sampler = RandomSampler(args.sampler.data_source, shuffle=False)
+                    logger.debug("Replace torch SequentialSampler into fastNLP RandomSampler.")
                     return replace_sampler(dataloader, sampler)
-            elif isinstance(args.sampler, TorchSequentialSampler):
-                # Needs to be replaced with a non-shuffling one.
-                sampler = RandomSampler(args.sampler.data_source, shuffle=False)
-                logger.debug("Replace torch SequentialSampler into fastNLP RandomSampler.")
-                return replace_sampler(dataloader, sampler)
 
         batch_sampler = ReproduceBatchSampler(
             batch_sampler=args.batch_sampler,
             batch_size=args.batch_size,
diff --git a/fastNLP/core/drivers/torch_driver/utils.py b/fastNLP/core/drivers/torch_driver/utils.py
index f0704dd5..300bf196 100644
--- a/fastNLP/core/drivers/torch_driver/utils.py
+++ b/fastNLP/core/drivers/torch_driver/utils.py
@@ -22,7 +22,11 @@ if _NEED_IMPORT_TORCH:
     import torch
     # import torch.nn as nn
     from torch.nn import Module
-    from torch.utils.data import DataLoader, BatchSampler
+    from torch.utils.data import DataLoader
+    from torch.utils.data import RandomSampler as TorchRandomSampler
+    from torch.utils.data import SequentialSampler as TorchSequentialSampler
+    from torch.utils.data import BatchSampler as TorchBatchSampler
+
 else:
     from fastNLP.core.utils.dummy_class import DummyClass as Module
 
@@ -200,6 +204,7 @@ def replace_sampler(dataloader: "DataLoader", sampler):
             non_default_params.add("dataset")
 
     reconstruct_args = {k: v for k, v in instance_attrs.items() if k in non_default_params}
+    reconstruct_args.update({"sampler": sampler, "shuffle": False, "batch_sampler": None})
 
     batch_sampler = getattr(dataloader, "batch_sampler")
     if batch_sampler is not None and isinstance(batch_sampler, ReproducibleBatchSampler):
@@ -277,5 +282,13 @@ def optimizer_state_to_device(state, device):
     return new_state
 
 
-
-
+def _check_dataloader_args_for_distributed(args, controller='Trainer'):
+    if type(args.batch_sampler) is not TorchBatchSampler and (type(args.sampler) not in {TorchRandomSampler,
+                                                                                         TorchSequentialSampler}):
+        mode = 'training' if controller == 'Trainer' else 'evaluation'
+        substitution = 'fastNLP.RandomSampler' if controller == 'Trainer' else 'fastNLP.UnrepeatedSequentialSampler'
+        raise TypeError(f"Using a customized ``batch_sampler`` or ``sampler`` for distributed {mode} may cause "
+                        f"unpredictable problems, because fastNLP will substitute the dataloader's sampler with "
+                        f"``{substitution}``. Set up your customized sampler for distributed running yourself "
+                        f"before initializing ``{controller}``, and then set the "
+                        f"parameter ``use_dist_sampler`` of ``{controller}`` to ``False``.")
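The new ``_check_dataloader_args_for_distributed`` helper only inspects ``args.batch_sampler`` and ``args.sampler``. A minimal sketch of what it accepts and rejects, where ``SimpleNamespace`` is a hypothetical stand-in for the parsed dataloader arguments that the drivers obtain via ``get_dataloader_args``, and ``MySampler`` / ``MyBatchSampler`` are made-up examples of customized classes::

    from types import SimpleNamespace

    from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler

    from fastNLP.core.drivers.torch_driver.utils import _check_dataloader_args_for_distributed

    dataset = list(range(10))

    # Plain torch defaults: the check passes and fastNLP may replace the sampler.
    default_args = SimpleNamespace(
        sampler=RandomSampler(dataset),
        batch_sampler=BatchSampler(RandomSampler(dataset), batch_size=2, drop_last=False),
    )
    _check_dataloader_args_for_distributed(default_args, controller='Trainer')  # no error

    # Customized sampler and batch_sampler: the check raises TypeError, telling the user
    # to prepare the sampler for distributed running and to pass use_dist_sampler=False.
    class MySampler(SequentialSampler):
        pass

    class MyBatchSampler(BatchSampler):
        pass

    custom_args = SimpleNamespace(
        sampler=MySampler(dataset),
        batch_sampler=MyBatchSampler(MySampler(dataset), batch_size=2, drop_last=False),
    )
    _check_dataloader_args_for_distributed(custom_args, controller='Trainer')  # raises TypeError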
diff --git a/tests/core/drivers/torch_driver/test_ddp.py b/tests/core/drivers/torch_driver/test_ddp.py
index 3f3dde74..74f44c04 100644
--- a/tests/core/drivers/torch_driver/test_ddp.py
+++ b/tests/core/drivers/torch_driver/test_ddp.py
@@ -12,8 +12,9 @@ from fastNLP.core.samplers import (
 )
 from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
 from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchNormalXYDataset
-from tests.helpers.utils import magic_argv_env_context
+from tests.helpers.utils import magic_argv_env_context, recover_logger, Capturing
 from fastNLP.envs.distributed import rank_zero_rm
+from fastNLP import logger
 from fastNLP.core.drivers.torch_driver.dist_utils import fastnlp_torch_all_gather
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 if _NEED_IMPORT_TORCH:
@@ -936,12 +937,118 @@ def test_batch_sampler_dataloader(shuffle, batch_size, drop_last, reproducible=T
         for i in range(1, len(data)-1):
             flags.append(data[i]